aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLuuk de Gram <luuk@degram.dev>2022-05-30 21:07:03 +0200
committerLuuk de Gram <luuk@degram.dev>2022-06-24 08:12:17 +0200
commit5ebaf49ebb65f3b201e56ec7d89b7d91f189409e (patch)
tree7dff0022ba12267782afe07a9236ef1e9b2f2756 /src
parent241180216f92c86f833910d73e1c86f54ae940e4 (diff)
downloadzig-5ebaf49ebb65f3b201e56ec7d89b7d91f189409e.tar.gz
zig-5ebaf49ebb65f3b201e56ec7d89b7d91f189409e.zip
wasm: Implement basic f16 support
This implements binary operations and comparisons for floats with bitsize 16. It does this by calling into compiler-rt to first extend the float to 32 bits, perform the operation, and then finally truncate back to 16 bits. When loading and storing the f16, we do this as an unsigned 16-bit integer.
Diffstat (limited to 'src')
-rw-r--r--src/arch/wasm/CodeGen.zig124
1 file changed, 105 insertions, 19 deletions
diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig
index 2f82f1f694..462393a150 100644
--- a/src/arch/wasm/CodeGen.zig
+++ b/src/arch/wasm/CodeGen.zig
@@ -215,8 +215,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
16 => switch (args.valtype1.?) {
.i32 => if (args.signedness.? == .signed) return .i32_load16_s else return .i32_load16_u,
.i64 => if (args.signedness.? == .signed) return .i64_load16_s else return .i64_load16_u,
- .f32 => return .f32_load,
- .f64 => unreachable,
+ .f32, .f64 => unreachable,
},
32 => switch (args.valtype1.?) {
.i64 => if (args.signedness.? == .signed) return .i64_load32_s else return .i64_load32_u,
@@ -246,8 +245,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
16 => switch (args.valtype1.?) {
.i32 => return .i32_store16,
.i64 => return .i64_store16,
- .f32 => return .f32_store,
- .f64 => unreachable,
+ .f32, .f64 => unreachable,
},
32 => switch (args.valtype1.?) {
.i64 => return .i64_store32,
@@ -725,7 +723,8 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype {
return switch (ty.zigTypeTag()) {
.Float => blk: {
const bits = ty.floatBits(target);
- if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32;
+ if (bits == 16) return wasm.Valtype.i32; // stored/loaded as u16
+ if (bits == 32) break :blk wasm.Valtype.f32;
if (bits == 64) break :blk wasm.Valtype.f64;
if (bits == 128) break :blk wasm.Valtype.i64;
return wasm.Valtype.i32; // represented as pointer to stack
@@ -2013,6 +2012,10 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
}
}
+ if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
+ return self.binOpFloat16(lhs, rhs, op);
+ }
+
const opcode: wasm.Opcode = buildOpcode(.{
.op = op,
.valtype1 = typeToValtype(ty, self.target),
@@ -2029,6 +2032,20 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
return bin_local;
}
+fn binOpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: Op) InnerError!WValue {
+ const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
+ const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
+
+ const opcode: wasm.Opcode = buildOpcode(.{ .op = op, .valtype1 = .f32, .signedness = .unsigned });
+ try self.emitWValue(ext_lhs);
+ try self.emitWValue(ext_rhs);
+ try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
+
+ // re-use temporary local
+ try self.addLabel(.local_set, ext_lhs.local);
+ return self.fptrunc(ext_lhs, Type.f32, Type.f16);
+}
+
fn binOpBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
if (ty.intInfo(self.target).bits > 128) {
return self.fail("TODO: Implement binary operation for big integer", .{});
@@ -2310,8 +2327,9 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
},
.Bool => return WValue{ .imm32 = @intCast(u32, val.toUnsignedInt(target)) },
.Float => switch (ty.floatBits(self.target)) {
- 0...32 => return WValue{ .float32 = val.toFloat(f32) },
- 33...64 => return WValue{ .float64 = val.toFloat(f64) },
+ 16 => return WValue{ .imm32 = @bitCast(u16, val.toFloat(f16)) },
+ 32 => return WValue{ .float32 = val.toFloat(f32) },
+ 64 => return WValue{ .float64 = val.toFloat(f64) },
else => unreachable,
},
.Pointer => switch (val.tag()) {
@@ -2389,8 +2407,9 @@ fn emitUndefined(self: *Self, ty: Type) InnerError!WValue {
else => unreachable,
},
.Float => switch (ty.floatBits(self.target)) {
- 0...32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
- 33...64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
+ 16 => return WValue{ .imm32 = 0xaaaaaaaa },
+ 32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
+ 64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
else => unreachable,
},
.Pointer => switch (self.arch()) {
@@ -2562,6 +2581,8 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
}
} else if (isByRef(ty, self.target)) {
return self.cmpBigInt(lhs, rhs, ty, op);
+ } else if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
+ return self.cmpFloat16(lhs, rhs, op);
}
// ensure that when we compare pointers, we emit
@@ -2595,6 +2616,31 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
return cmp_tmp;
}
+fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperator) InnerError!WValue {
+ const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
+ const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
+
+ const opcode: wasm.Opcode = buildOpcode(.{
+ .op = switch (op) {
+ .lt => .lt,
+ .lte => .le,
+ .eq => .eq,
+ .neq => .ne,
+ .gte => .ge,
+ .gt => .gt,
+ },
+ .valtype1 = .f32,
+ .signedness = .unsigned,
+ });
+ try self.emitWValue(ext_lhs);
+ try self.emitWValue(ext_rhs);
+ try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
+
+ const result = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
+ try self.addLabel(.local_set, result.local);
+ return result;
+}
+
fn airCmpVector(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
_ = inst;
return self.fail("TODO implement airCmpVector for wasm", .{});
@@ -3934,19 +3980,44 @@ fn airFpext(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const dest_ty = self.air.typeOfIndex(inst);
- const dest_bits = dest_ty.floatBits(self.target);
- const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
const operand = try self.resolveInst(ty_op.operand);
- if (dest_bits == 64 and src_bits == 32) {
- const result = try self.allocLocal(dest_ty);
+ return self.fpext(operand, self.air.typeOf(ty_op.operand), dest_ty);
+}
+
+fn fpext(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
+ const given_bits = given.floatBits(self.target);
+ const wanted_bits = wanted.floatBits(self.target);
+
+ if (wanted_bits == 64 and given_bits == 32) {
+ const result = try self.allocLocal(wanted);
try self.emitWValue(operand);
try self.addTag(.f64_promote_f32);
try self.addLabel(.local_set, result.local);
return result;
+ } else if (given_bits == 16) {
+ // call __extendhfsf2(f16) f32
+ const f32_result = try self.callIntrinsic(
+ "__extendhfsf2",
+ &.{Type.f16},
+ Type.f32,
+ &.{operand},
+ );
+
+ if (wanted_bits == 32) {
+ return f32_result;
+ }
+ if (wanted_bits == 64) {
+ const result = try self.allocLocal(wanted);
+ try self.emitWValue(f32_result);
+ try self.addTag(.f64_promote_f32);
+ try self.addLabel(.local_set, result.local);
+ return result;
+ }
+ return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
} else {
// TODO: Emit a call to compiler-rt to extend the float. e.g. __extendhfsf2
- return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{dest_bits});
+ return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
}
}
@@ -3955,19 +4026,34 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const dest_ty = self.air.typeOfIndex(inst);
- const dest_bits = dest_ty.floatBits(self.target);
- const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
const operand = try self.resolveInst(ty_op.operand);
+ return self.fptrunc(operand, self.air.typeOf(ty_op.operand), dest_ty);
+}
- if (dest_bits == 32 and src_bits == 64) {
- const result = try self.allocLocal(dest_ty);
+fn fptrunc(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
+ const given_bits = given.floatBits(self.target);
+ const wanted_bits = wanted.floatBits(self.target);
+
+ if (wanted_bits == 32 and given_bits == 64) {
+ const result = try self.allocLocal(wanted);
try self.emitWValue(operand);
try self.addTag(.f32_demote_f64);
try self.addLabel(.local_set, result.local);
return result;
+ } else if (wanted_bits == 16) {
+ const op: WValue = if (given_bits == 64) blk: {
+ const tmp = try self.allocLocal(Type.f32);
+ try self.emitWValue(operand);
+ try self.addTag(.f32_demote_f64);
+ try self.addLabel(.local_set, tmp.local);
+ break :blk tmp;
+ } else operand;
+
+ // call __truncsfhf2(f32) f16
+ return self.callIntrinsic("__truncsfhf2", &.{Type.f32}, Type.f16, &.{op});
} else {
// TODO: Emit a call to compiler-rt to trunc the float. e.g. __truncdfhf2
- return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{dest_bits});
+ return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{wanted_bits});
}
}