diff options
| author | Luuk de Gram <luuk@degram.dev> | 2022-05-30 21:07:03 +0200 |
|---|---|---|
| committer | Luuk de Gram <luuk@degram.dev> | 2022-06-24 08:12:17 +0200 |
| commit | 5ebaf49ebb65f3b201e56ec7d89b7d91f189409e (patch) | |
| tree | 7dff0022ba12267782afe07a9236ef1e9b2f2756 | |
| parent | 241180216f92c86f833910d73e1c86f54ae940e4 (diff) | |
| download | zig-5ebaf49ebb65f3b201e56ec7d89b7d91f189409e.tar.gz zig-5ebaf49ebb65f3b201e56ec7d89b7d91f189409e.zip | |
wasm: Implement basic f16 support
This implements binary operations and comparisons
for floats with bitsize 16. It does this by calling into
compiler-rt to first extend the float to 32 bits, perform the operation,
and then finally truncate back to 16 bits. When loading and storing the f16,
we do this as an unsigned 16bit integer.
| -rw-r--r-- | src/arch/wasm/CodeGen.zig | 124 |
1 files changed, 105 insertions, 19 deletions
diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index 2f82f1f694..462393a150 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -215,8 +215,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode { 16 => switch (args.valtype1.?) { .i32 => if (args.signedness.? == .signed) return .i32_load16_s else return .i32_load16_u, .i64 => if (args.signedness.? == .signed) return .i64_load16_s else return .i64_load16_u, - .f32 => return .f32_load, - .f64 => unreachable, + .f32, .f64 => unreachable, }, 32 => switch (args.valtype1.?) { .i64 => if (args.signedness.? == .signed) return .i64_load32_s else return .i64_load32_u, @@ -246,8 +245,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode { 16 => switch (args.valtype1.?) { .i32 => return .i32_store16, .i64 => return .i64_store16, - .f32 => return .f32_store, - .f64 => unreachable, + .f32, .f64 => unreachable, }, 32 => switch (args.valtype1.?) { .i64 => return .i64_store32, @@ -725,7 +723,8 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype { return switch (ty.zigTypeTag()) { .Float => blk: { const bits = ty.floatBits(target); - if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32; + if (bits == 16) return wasm.Valtype.i32; // stored/loaded as u16 + if (bits == 32) break :blk wasm.Valtype.f32; if (bits == 64) break :blk wasm.Valtype.f64; if (bits == 128) break :blk wasm.Valtype.i64; return wasm.Valtype.i32; // represented as pointer to stack @@ -2013,6 +2012,10 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa } } + if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) { + return self.binOpFloat16(lhs, rhs, op); + } + const opcode: wasm.Opcode = buildOpcode(.{ .op = op, .valtype1 = typeToValtype(ty, self.target), @@ -2029,6 +2032,20 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa return bin_local; } +fn binOpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: Op) InnerError!WValue { + const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32); + const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32); + + const opcode: wasm.Opcode = buildOpcode(.{ .op = op, .valtype1 = .f32, .signedness = .unsigned }); + try self.emitWValue(ext_lhs); + try self.emitWValue(ext_rhs); + try self.addTag(Mir.Inst.Tag.fromOpcode(opcode)); + + // re-use temporary local + try self.addLabel(.local_set, ext_lhs.local); + return self.fptrunc(ext_lhs, Type.f32, Type.f16); +} + fn binOpBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue { if (ty.intInfo(self.target).bits > 128) { return self.fail("TODO: Implement binary operation for big integer", .{}); @@ -2310,8 +2327,9 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue { }, .Bool => return WValue{ .imm32 = @intCast(u32, val.toUnsignedInt(target)) }, .Float => switch (ty.floatBits(self.target)) { - 0...32 => return WValue{ .float32 = val.toFloat(f32) }, - 33...64 => return WValue{ .float64 = val.toFloat(f64) }, + 16 => return WValue{ .imm32 = @bitCast(u16, val.toFloat(f16)) }, + 32 => return WValue{ .float32 = val.toFloat(f32) }, + 64 => return WValue{ .float64 = val.toFloat(f64) }, else => unreachable, }, .Pointer => switch (val.tag()) { @@ -2389,8 +2407,9 @@ fn emitUndefined(self: *Self, ty: Type) InnerError!WValue { else => unreachable, }, .Float => switch (ty.floatBits(self.target)) { - 0...32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) }, - 33...64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) }, + 16 => return WValue{ .imm32 = 0xaaaaaaaa }, + 32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) }, + 64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) }, else => unreachable, }, .Pointer => switch (self.arch()) { @@ -2562,6 +2581,8 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper } } else if (isByRef(ty, self.target)) { return self.cmpBigInt(lhs, rhs, ty, op); + } else if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) { + return self.cmpFloat16(lhs, rhs, op); } // ensure that when we compare pointers, we emit @@ -2595,6 +2616,31 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper return cmp_tmp; } +fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperator) InnerError!WValue { + const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32); + const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32); + + const opcode: wasm.Opcode = buildOpcode(.{ + .op = switch (op) { + .lt => .lt, + .lte => .le, + .eq => .eq, + .neq => .ne, + .gte => .ge, + .gt => .gt, + }, + .valtype1 = .f32, + .signedness = .unsigned, + }); + try self.emitWValue(ext_lhs); + try self.emitWValue(ext_rhs); + try self.addTag(Mir.Inst.Tag.fromOpcode(opcode)); + + const result = try self.allocLocal(Type.initTag(.i32)); // bool is always i32 + try self.addLabel(.local_set, result.local); + return result; +} + fn airCmpVector(self: *Self, inst: Air.Inst.Index) InnerError!WValue { _ = inst; return self.fail("TODO implement airCmpVector for wasm", .{}); @@ -3934,19 +3980,44 @@ fn airFpext(self: *Self, inst: Air.Inst.Index) InnerError!WValue { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const dest_ty = self.air.typeOfIndex(inst); - const dest_bits = dest_ty.floatBits(self.target); - const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target); const operand = try self.resolveInst(ty_op.operand); - if (dest_bits == 64 and src_bits == 32) { - const result = try self.allocLocal(dest_ty); + return self.fpext(operand, self.air.typeOf(ty_op.operand), dest_ty); +} + +fn fpext(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue { + const given_bits = given.floatBits(self.target); + const wanted_bits = wanted.floatBits(self.target); + + if (wanted_bits == 64 and given_bits == 32) { + const result = try self.allocLocal(wanted); try self.emitWValue(operand); try self.addTag(.f64_promote_f32); try self.addLabel(.local_set, result.local); return result; + } else if (given_bits == 16) { + // call __extendhfsf2(f16) f32 + const f32_result = try self.callIntrinsic( + "__extendhfsf2", + &.{Type.f16}, + Type.f32, + &.{operand}, + ); + + if (wanted_bits == 32) { + return f32_result; + } + if (wanted_bits == 64) { + const result = try self.allocLocal(wanted); + try self.emitWValue(f32_result); + try self.addTag(.f64_promote_f32); + try self.addLabel(.local_set, result.local); + return result; + } + return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits}); } else { // TODO: Emit a call to compiler-rt to extend the float. e.g. __extendhfsf2 - return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{dest_bits}); + return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits}); } } @@ -3955,19 +4026,34 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const dest_ty = self.air.typeOfIndex(inst); - const dest_bits = dest_ty.floatBits(self.target); - const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target); const operand = try self.resolveInst(ty_op.operand); + return self.fptrunc(operand, self.air.typeOf(ty_op.operand), dest_ty); +} - if (dest_bits == 32 and src_bits == 64) { - const result = try self.allocLocal(dest_ty); +fn fptrunc(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue { + const given_bits = given.floatBits(self.target); + const wanted_bits = wanted.floatBits(self.target); + + if (wanted_bits == 32 and given_bits == 64) { + const result = try self.allocLocal(wanted); try self.emitWValue(operand); try self.addTag(.f32_demote_f64); try self.addLabel(.local_set, result.local); return result; + } else if (wanted_bits == 16) { + const op: WValue = if (given_bits == 64) blk: { + const tmp = try self.allocLocal(Type.f32); + try self.emitWValue(operand); + try self.addTag(.f32_demote_f64); + try self.addLabel(.local_set, tmp.local); + break :blk tmp; + } else operand; + + // call __truncsfhf2(f32) f16 + return self.callIntrinsic("__truncsfhf2", &.{Type.f32}, Type.f16, &.{op}); } else { // TODO: Emit a call to compiler-rt to trunc the float. e.g. __truncdfhf2 - return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{dest_bits}); + return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{wanted_bits}); } } |
