diff options
| author | Jacob Young <jacobly0@users.noreply.github.com> | 2023-05-07 05:01:37 -0400 |
|---|---|---|
| committer | Jacob Young <jacobly0@users.noreply.github.com> | 2023-05-08 07:36:20 -0400 |
| commit | ea957c4cff77f045108863cb5552b3511cb455c1 (patch) | |
| tree | a40ee9152c2c8b830e32ee81f0d7df39d2a3681a | |
| parent | 5c5da179fb930c9d8be9366a851eb4a36f4044f1 (diff) | |
| download | zig-ea957c4cff77f045108863cb5552b3511cb455c1.tar.gz zig-ea957c4cff77f045108863cb5552b3511cb455c1.zip | |
x86_64: implement `@sqrt` for `f16` scalars and vectors
| -rw-r--r-- | src/arch/x86_64/CodeGen.zig | 156 | ||||
| -rw-r--r-- | src/arch/x86_64/encodings.zig | 4 | ||||
| -rw-r--r-- | test/behavior/floatop.zig | 1 |
3 files changed, 109 insertions, 52 deletions
diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 19878bae17..6337ad23f5 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -4531,59 +4531,117 @@ fn airSqrt(self: *Self, inst: Air.Inst.Index) !void { const dst_lock = self.register_manager.lockReg(dst_reg); defer if (dst_lock) |lock| self.register_manager.unlockReg(lock); - const tag = if (@as(?Mir.Inst.Tag, switch (ty.zigTypeTag()) { - .Float => switch (ty.childType().floatBits(self.target.*)) { - 32 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss, - 64 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd, - 16, 80, 128 => null, - else => unreachable, - }, - .Vector => switch (ty.childType().zigTypeTag()) { - .Float => switch (ty.childType().floatBits(self.target.*)) { - 32 => switch (ty.vectorLen()) { - 1 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss, - 2...4 => if (self.hasFeature(.avx)) .vsqrtps else .sqrtps, - 5...8 => if (self.hasFeature(.avx)) .vsqrtps else null, - else => null, - }, - 64 => switch (ty.vectorLen()) { - 1 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd, - 2 => if (self.hasFeature(.avx)) .vsqrtpd else .sqrtpd, - 3...4 => if (self.hasFeature(.avx)) .vsqrtpd else null, - else => null, + const result: MCValue = result: { + const tag = if (@as(?Mir.Inst.Tag, switch (ty.zigTypeTag()) { + .Float => switch (ty.floatBits(self.target.*)) { + 16 => if (self.hasFeature(.f16c)) { + const mat_src_reg = if (src_mcv.isRegister()) + src_mcv.getReg().? + else + try self.copyToTmpRegister(ty, src_mcv); + try self.asmRegisterRegister(.vcvtph2ps, dst_reg, mat_src_reg.to128()); + try self.asmRegisterRegisterRegister(.vsqrtss, dst_reg, dst_reg, dst_reg); + try self.asmRegisterRegisterImmediate( + .vcvtps2ph, + dst_reg, + dst_reg, + Immediate.u(0b1_00), + ); + break :result dst_mcv; + } else null, + 32 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss, + 64 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd, + 80, 128 => null, + else => unreachable, + }, + .Vector => switch (ty.childType().zigTypeTag()) { + .Float => switch (ty.childType().floatBits(self.target.*)) { + 16 => if (self.hasFeature(.f16c)) switch (ty.vectorLen()) { + 1 => { + const mat_src_reg = if (src_mcv.isRegister()) + src_mcv.getReg().? + else + try self.copyToTmpRegister(ty, src_mcv); + try self.asmRegisterRegister(.vcvtph2ps, dst_reg, mat_src_reg.to128()); + try self.asmRegisterRegisterRegister(.vsqrtss, dst_reg, dst_reg, dst_reg); + try self.asmRegisterRegisterImmediate( + .vcvtps2ph, + dst_reg, + dst_reg, + Immediate.u(0b1_00), + ); + break :result dst_mcv; + }, + 2...8 => { + const wide_reg = registerAlias(dst_reg, abi_size * 2); + if (src_mcv.isRegister()) try self.asmRegisterRegister( + .vcvtph2ps, + wide_reg, + src_mcv.getReg().?.to128(), + ) else try self.asmRegisterMemory( + .vcvtph2ps, + wide_reg, + src_mcv.mem(Memory.PtrSize.fromSize( + @intCast(u32, @divExact(wide_reg.bitSize(), 16)), + )), + ); + try self.asmRegisterRegister(.vsqrtps, wide_reg, wide_reg); + try self.asmRegisterRegisterImmediate( + .vcvtps2ph, + dst_reg, + wide_reg, + Immediate.u(0b1_00), + ); + break :result dst_mcv; + }, + else => null, + } else null, + 32 => switch (ty.vectorLen()) { + 1 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss, + 2...4 => if (self.hasFeature(.avx)) .vsqrtps else .sqrtps, + 5...8 => if (self.hasFeature(.avx)) .vsqrtps else null, + else => null, + }, + 64 => switch (ty.vectorLen()) { + 1 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd, + 2 => if (self.hasFeature(.avx)) .vsqrtpd else .sqrtpd, + 3...4 => if (self.hasFeature(.avx)) .vsqrtpd else null, + else => null, + }, + 80, 128 => null, + else => unreachable, }, - 16, 80, 128 => null, else => unreachable, }, else => unreachable, - }, - else => unreachable, - })) |tag| tag else return self.fail("TODO implement airSqrt for {}", .{ - ty.fmt(self.bin_file.options.module.?), - }); - switch (tag) { - .vsqrtss, .vsqrtsd => if (src_mcv.isRegister()) try self.asmRegisterRegisterRegister( - tag, - dst_reg, - dst_reg, - registerAlias(src_mcv.getReg().?, abi_size), - ) else try self.asmRegisterRegisterMemory( - tag, - dst_reg, - dst_reg, - src_mcv.mem(Memory.PtrSize.fromSize(abi_size)), - ), - else => if (src_mcv.isRegister()) try self.asmRegisterRegister( - tag, - dst_reg, - registerAlias(src_mcv.getReg().?, abi_size), - ) else try self.asmRegisterMemory( - tag, - dst_reg, - src_mcv.mem(Memory.PtrSize.fromSize(abi_size)), - ), - } - return self.finishAir(inst, dst_mcv, .{ un_op, .none, .none }); + })) |tag| tag else return self.fail("TODO implement airSqrt for {}", .{ + ty.fmt(self.bin_file.options.module.?), + }); + switch (tag) { + .vsqrtss, .vsqrtsd => if (src_mcv.isRegister()) try self.asmRegisterRegisterRegister( + tag, + dst_reg, + dst_reg, + registerAlias(src_mcv.getReg().?, abi_size), + ) else try self.asmRegisterRegisterMemory( + tag, + dst_reg, + dst_reg, + src_mcv.mem(Memory.PtrSize.fromSize(abi_size)), + ), + else => if (src_mcv.isRegister()) try self.asmRegisterRegister( + tag, + dst_reg, + registerAlias(src_mcv.getReg().?, abi_size), + ) else try self.asmRegisterMemory( + tag, + dst_reg, + src_mcv.mem(Memory.PtrSize.fromSize(abi_size)), + ), + } + break :result dst_mcv; + }; + return self.finishAir(inst, result, .{ un_op, .none, .none }); } fn airUnaryMath(self: *Self, inst: Air.Inst.Index) !void { diff --git a/src/arch/x86_64/encodings.zig b/src/arch/x86_64/encodings.zig index 49ebc344fd..78bda4fc76 100644 --- a/src/arch/x86_64/encodings.zig +++ b/src/arch/x86_64/encodings.zig @@ -1047,9 +1047,9 @@ pub const table = [_]Entry{ .{ .vsqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x0f, 0x51 }, 0, .vex_128_wig, .avx }, .{ .vsqrtps, .rm, &.{ .ymm, .ymm_m256 }, &.{ 0x0f, 0x51 }, 0, .vex_256_wig, .avx }, - .{ .vsqrtsd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f }, 0, .vex_lig_wig, .avx }, + .{ .vsqrtsd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f, 0x51 }, 0, .vex_lig_wig, .avx }, - .{ .vsqrtss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f }, 0, .vex_lig_wig, .avx }, + .{ .vsqrtss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f, 0x51 }, 0, .vex_lig_wig, .avx }, // F16C .{ .vcvtph2ps, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_128_w0, .f16c }, diff --git a/test/behavior/floatop.zig b/test/behavior/floatop.zig index ec24407d9f..3f407061f4 100644 --- a/test/behavior/floatop.zig +++ b/test/behavior/floatop.zig @@ -135,7 +135,6 @@ fn testSqrt() !void { test "@sqrt with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO |
