aboutsummaryrefslogtreecommitdiff
path: root/src/arch/wasm/CodeGen.zig
diff options
context:
space:
mode:
authorJakub Konka <kubkon@jakubkonka.com>2022-05-07 23:30:08 +0200
committerGitHub <noreply@github.com>2022-05-07 23:30:08 +0200
commitf161d3875ad341971c97384587e2e6c2b50bc09c (patch)
treeae6e441e067fa881982f568de54b6c9e53f2e777 /src/arch/wasm/CodeGen.zig
parente8c85450feac961a147d37394920081214ba9396 (diff)
parenta11097958271562fe7e64716356c07d9996fad5f (diff)
downloadzig-f161d3875ad341971c97384587e2e6c2b50bc09c.tar.gz
zig-f161d3875ad341971c97384587e2e6c2b50bc09c.zip
Merge pull request #11605 from Luukdegram/wasm-mul-overflow
stage2: wasm - Improve `@mulWithOverflow` implementation
Diffstat (limited to 'src/arch/wasm/CodeGen.zig')
-rw-r--r--src/arch/wasm/CodeGen.zig302
1 files changed, 176 insertions, 126 deletions
diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig
index 5171dfb460..9318b4ecca 100644
--- a/src/arch/wasm/CodeGen.zig
+++ b/src/arch/wasm/CodeGen.zig
@@ -1424,7 +1424,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.add_with_overflow => self.airBinOpOverflow(inst, .add),
.sub_with_overflow => self.airBinOpOverflow(inst, .sub),
.shl_with_overflow => self.airBinOpOverflow(inst, .shl),
- .mul_with_overflow => self.airBinOpOverflow(inst, .mul),
+ .mul_with_overflow => self.airMulWithOverflow(inst),
.clz => self.airClz(inst),
.ctz => self.airCtz(inst),
@@ -1822,7 +1822,7 @@ fn store(self: *Self, lhs: WValue, rhs: WValue, ty: Type, offset: u32) InnerErro
const opcode = buildOpcode(.{
.valtype1 = valtype,
- .width = abi_size * 8, // use bitsize instead of byte size
+ .width = abi_size * 8,
.op = .store,
});
@@ -1852,21 +1852,13 @@ fn airLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
fn load(self: *Self, operand: WValue, ty: Type, offset: u32) InnerError!WValue {
// load local's value from memory by its stack position
try self.emitWValue(operand);
- // Build the opcode with the right bitsize
- const signedness: std.builtin.Signedness = if (ty.isUnsignedInt() or
- ty.zigTypeTag() == .ErrorSet or
- ty.zigTypeTag() == .Bool)
- .unsigned
- else
- .signed;
const abi_size = @intCast(u8, ty.abiSize(self.target));
-
const opcode = buildOpcode(.{
.valtype1 = typeToValtype(ty, self.target),
- .width = abi_size * 8, // use bitsize instead of byte size
+ .width = abi_size * 8,
.op = .load,
- .signedness = signedness,
+ .signedness = .unsigned,
});
try self.addMemArg(
@@ -1935,7 +1927,14 @@ fn airWrapBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
- return self.wrapBinOp(lhs, rhs, self.air.typeOf(bin_op.lhs), op);
+ const ty = self.air.typeOf(bin_op.lhs);
+ if (ty.zigTypeTag() == .Vector) {
+ return self.fail("TODO: Implement wrapping arithmetic for vectors", .{});
+ } else if (ty.abiSize(self.target) > 8) {
+ return self.fail("TODO: Implement wrapping arithmetic for bitsize > 64", .{});
+ }
+
+ return self.wrapBinOp(lhs, rhs, ty, op);
}
fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
@@ -1948,38 +1947,28 @@ fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError
.signedness = if (ty.isSignedInt()) .signed else .unsigned,
});
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
-
- const int_info = ty.intInfo(self.target);
- const bitsize = int_info.bits;
- const is_signed = int_info.signedness == .signed;
- // if target type bitsize is x < 32 and 32 > x < 64, we perform
- // result & ((1<<N)-1) where N = bitsize or bitsize -1 incase of signed.
- if (bitsize != 32 and bitsize < 64) {
- // first check if we can use a single instruction,
- // wasm provides those if the integers are signed and 8/16-bit.
- // For arbitrary integer sizes, we use the algorithm mentioned above.
- if (is_signed and bitsize == 8) {
- try self.addTag(.i32_extend8_s);
- } else if (is_signed and bitsize == 16) {
- try self.addTag(.i32_extend16_s);
- } else {
- const result = (@as(u64, 1) << @intCast(u6, bitsize - @boolToInt(is_signed))) - 1;
- if (bitsize < 32) {
- try self.addImm32(@bitCast(i32, @intCast(u32, result)));
- try self.addTag(.i32_and);
- } else {
- try self.addImm64(result);
- try self.addTag(.i64_and);
- }
- }
- } else if (int_info.bits > 64) {
- return self.fail("TODO wasm: Integer wrapping for bitsizes larger than 64", .{});
- }
-
- // save the result in a temporary
const bin_local = try self.allocLocal(ty);
try self.addLabel(.local_set, bin_local.local);
- return bin_local;
+ return self.wrapOperand(bin_local, ty);
+}
+
+/// Wraps an operand based on a given type's bitsize.
+/// Asserts `Type` is <= 64bits.
+fn wrapOperand(self: *Self, operand: WValue, ty: Type) InnerError!WValue {
+ assert(ty.abiSize(self.target) <= 8);
+ const result_local = try self.allocLocal(ty);
+ const bitsize = ty.intInfo(self.target).bits;
+ const result = @intCast(u64, (@as(u65, 1) << @intCast(u7, bitsize)) - 1);
+ try self.emitWValue(operand);
+ if (bitsize <= 32) {
+ try self.addImm32(@bitCast(i32, @intCast(u32, result)));
+ try self.addTag(.i32_and);
+ } else {
+ try self.addImm64(result);
+ try self.addTag(.i64_and);
+ }
+ try self.addLabel(.local_set, result_local.local);
+ return result_local;
}
fn lowerParentPtr(self: *Self, ptr_val: Value, ptr_child_ty: Type) InnerError!WValue {
@@ -2098,6 +2087,22 @@ fn lowerDeclRefValue(self: *Self, tv: TypedValue, decl_index: Module.Decl.Index)
} else return WValue{ .memory = target_sym_index };
}
+/// Converts a signed integer to its 2's complement form and returns
+/// an unsigned integer instead.
+/// Asserts bitsize <= 64
+fn toTwosComplement(value: anytype, bits: u7) std.meta.Int(.unsigned, @typeInfo(@TypeOf(value)).Int.bits) {
+ const T = @TypeOf(value);
+ comptime assert(@typeInfo(T) == .Int);
+ comptime assert(@typeInfo(T).Int.signedness == .signed);
+ assert(bits <= 64);
+ const WantedT = std.meta.Int(.unsigned, @typeInfo(T).Int.bits);
+ if (value >= 0) return @bitCast(WantedT, value);
+ const max_value = @intCast(u64, (@as(u65, 1) << bits) - 1);
+ const flipped = (~-value) + 1;
+ const result = @bitCast(WantedT, flipped) & max_value;
+ return @intCast(WantedT, result);
+}
+
fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
if (val.isUndefDeep()) return self.emitUndefined(ty);
if (val.castTag(.decl_ref)) |decl_ref| {
@@ -2114,10 +2119,12 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
switch (ty.zigTypeTag()) {
.Int => {
const int_info = ty.intInfo(self.target);
- // write constant
switch (int_info.signedness) {
.signed => switch (int_info.bits) {
- 0...32 => return WValue{ .imm32 = @bitCast(u32, @intCast(i32, val.toSignedInt())) },
+ 0...32 => return WValue{ .imm32 = @intCast(u32, toTwosComplement(
+ val.toSignedInt(),
+ @intCast(u6, int_info.bits),
+ )) },
33...64 => return WValue{ .imm64 = @bitCast(u64, val.toSignedInt()) },
else => unreachable,
},
@@ -2832,30 +2839,38 @@ fn airIntcast(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const ty = self.air.getRefType(ty_op.ty);
const operand = try self.resolveInst(ty_op.operand);
- const ref_ty = self.air.typeOf(ty_op.operand);
- const ref_info = ref_ty.intInfo(self.target);
- const wanted_info = ty.intInfo(self.target);
+ const operand_ty = self.air.typeOf(ty_op.operand);
+ if (ty.abiSize(self.target) > 8 or operand_ty.abiSize(self.target) > 8) {
+ return self.fail("todo Wasm intcast for bitsize > 64", .{});
+ }
- const op_bits = toWasmBits(ref_info.bits) orelse
- return self.fail("TODO: Wasm intcast integer types of bitsize: {d}", .{ref_info.bits});
- const wanted_bits = toWasmBits(wanted_info.bits) orelse
- return self.fail("TODO: Wasm intcast integer types of bitsize: {d}", .{wanted_info.bits});
+ return self.intcast(operand, operand_ty, ty);
+}
+
+/// Upcasts or downcasts an integer based on the given and wanted types,
+/// and stores the result in a new operand.
+/// Asserts type's bitsize <= 64
+fn intcast(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
+ const given_info = given.intInfo(self.target);
+ const wanted_info = wanted.intInfo(self.target);
+ assert(given_info.bits <= 64);
+ assert(wanted_info.bits <= 64);
- // hot path
+ const op_bits = toWasmBits(given_info.bits).?;
+ const wanted_bits = toWasmBits(wanted_info.bits).?;
if (op_bits == wanted_bits) return operand;
+ try self.emitWValue(operand);
if (op_bits > 32 and wanted_bits == 32) {
- try self.emitWValue(operand);
try self.addTag(.i32_wrap_i64);
} else if (op_bits == 32 and wanted_bits > 32) {
- try self.emitWValue(operand);
- try self.addTag(switch (ref_info.signedness) {
+ try self.addTag(switch (wanted_info.signedness) {
.signed => .i64_extend_i32_s,
.unsigned => .i64_extend_i32_u,
});
} else unreachable;
- const result = try self.allocLocal(ty);
+ const result = try self.allocLocal(wanted);
try self.addLabel(.local_set, result.local);
return result;
}
@@ -3072,63 +3087,17 @@ fn airSlicePtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
}
fn airTrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
- if (self.liveness.isUnused(inst)) return WValue.none;
+ if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const operand = try self.resolveInst(ty_op.operand);
- const op_ty = self.air.typeOf(ty_op.operand);
- const int_info = self.air.getRefType(ty_op.ty).intInfo(self.target);
+ const wanted_ty = self.air.getRefType(ty_op.ty);
+ const int_info = wanted_ty.intInfo(self.target);
const wanted_bits = int_info.bits;
- const result = try self.allocLocal(self.air.getRefType(ty_op.ty));
- const op_bits = op_ty.intInfo(self.target).bits;
- const wasm_bits = toWasmBits(wanted_bits) orelse
+ _ = toWasmBits(wanted_bits) orelse {
return self.fail("TODO: Implement wasm integer truncation for integer bitsize: {d}", .{wanted_bits});
-
- // Use wasm's instruction to wrap from 64bit to 32bit integer when possible
- if (op_bits == 64 and wanted_bits == 32) {
- try self.emitWValue(operand);
- try self.addTag(.i32_wrap_i64);
- try self.addLabel(.local_set, result.local);
- return result;
- }
-
- // Any other truncation must be done manually
- if (int_info.signedness == .unsigned) {
- const mask = (@as(u65, 1) << @intCast(u7, wanted_bits)) - 1;
- try self.emitWValue(operand);
- switch (wasm_bits) {
- 32 => {
- try self.addImm32(@bitCast(i32, @intCast(u32, mask)));
- try self.addTag(.i32_and);
- },
- 64 => {
- try self.addImm64(@intCast(u64, mask));
- try self.addTag(.i64_and);
- },
- else => unreachable,
- }
- } else {
- const shift_bits = wasm_bits - wanted_bits;
- try self.emitWValue(operand);
- switch (wasm_bits) {
- 32 => {
- try self.addImm32(@bitCast(i16, shift_bits));
- try self.addTag(.i32_shl);
- try self.addImm32(@bitCast(i16, shift_bits));
- try self.addTag(.i32_shr_s);
- },
- 64 => {
- try self.addImm64(shift_bits);
- try self.addTag(.i64_shl);
- try self.addImm64(shift_bits);
- try self.addTag(.i64_shr_s);
- },
- else => unreachable,
- }
- }
-
- try self.addLabel(.local_set, result.local);
- return result;
+ };
+ return self.wrapOperand(operand, wanted_ty);
}
fn airBoolToInt(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -3418,7 +3387,8 @@ fn airFloatToInt(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const result = try self.allocLocal(dest_ty);
try self.addLabel(.local_set, result.local);
- return result;
+
+ return self.wrapOperand(result, dest_ty);
}
fn airIntToFloat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -3922,6 +3892,10 @@ fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue
const rhs = try self.resolveInst(extra.rhs);
const lhs_ty = self.air.typeOf(extra.lhs);
+ if (lhs_ty.zigTypeTag() == .Vector) {
+ return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
+ }
+
// We store the bit if it's overflowed or not in this. As it's zero-initialized
// we only need to update it if an overflow (or underflow) occured.
const overflow_bit = try self.allocLocal(Type.initTag(.u1));
@@ -4008,24 +3982,100 @@ fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue
}
try self.addLabel(.local_set, tmp_val.local);
break :blk tmp_val;
- } else if (op == .mul) blk: {
- const bin_op = try self.wrapBinOp(lhs, rhs, lhs_ty, op);
- try self.startBlock(.block, wasm.block_empty);
- // check if 0. true => Break out of block as cannot over -or underflow.
- try self.emitWValue(lhs);
- switch (wasm_bits) {
- 32 => try self.addTag(.i32_eqz),
- 64 => try self.addTag(.i64_eqz),
- else => unreachable,
+ } else try self.wrapBinOp(lhs, rhs, lhs_ty, op);
+
+ const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
+ try self.store(result_ptr, bin_op, lhs_ty, 0);
+ const offset = @intCast(u32, lhs_ty.abiSize(self.target));
+ try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
+
+ return result_ptr;
+}
+
+fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
+ const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
+ const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+ const lhs = try self.resolveInst(extra.lhs);
+ const rhs = try self.resolveInst(extra.rhs);
+ const lhs_ty = self.air.typeOf(extra.lhs);
+
+ if (lhs_ty.zigTypeTag() == .Vector) {
+ return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
+ }
+
+ // We store the bit if it's overflowed or not in this. As it's zero-initialized
+ // we only need to update it if an overflow (or underflow) occured.
+ const overflow_bit = try self.allocLocal(Type.initTag(.u1));
+ const int_info = lhs_ty.intInfo(self.target);
+ const wasm_bits = toWasmBits(int_info.bits) orelse {
+ return self.fail("TODO: Implement overflow arithmetic for integer bitsize: {d}", .{int_info.bits});
+ };
+
+ if (wasm_bits == 64) {
+ return self.fail("TODO: Implement `@mulWithOverflow` for integer bitsize: {d}", .{int_info.bits});
+ }
+
+ const zero = switch (wasm_bits) {
+ 32 => WValue{ .imm32 = 0 },
+ 64 => WValue{ .imm64 = 0 },
+ else => unreachable,
+ };
+
+ // for 32 bit integers we upcast it to a 64bit integer
+ const bin_op = if (int_info.bits == 32) blk: {
+ const new_ty = if (int_info.signedness == .signed) Type.i64 else Type.u64;
+ const lhs_upcast = try self.intcast(lhs, lhs_ty, new_ty);
+ const rhs_upcast = try self.intcast(rhs, lhs_ty, new_ty);
+ const bin_op = try self.binOp(lhs_upcast, rhs_upcast, new_ty, .mul);
+ if (int_info.signedness == .unsigned) {
+ const shr = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
+ const wrap = try self.intcast(shr, new_ty, lhs_ty);
+ const cmp_res = try self.cmp(wrap, zero, lhs_ty, .neq);
+ try self.emitWValue(cmp_res);
+ try self.addLabel(.local_set, overflow_bit.local);
+ break :blk try self.intcast(bin_op, new_ty, lhs_ty);
+ } else {
+ const down_cast = try self.intcast(bin_op, new_ty, lhs_ty);
+ const shr = try self.binOp(down_cast, .{ .imm32 = int_info.bits - 1 }, lhs_ty, .shr);
+
+ const shr_res = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
+ const down_shr_res = try self.intcast(shr_res, new_ty, lhs_ty);
+ const cmp_res = try self.cmp(down_shr_res, shr, lhs_ty, .neq);
+ try self.emitWValue(cmp_res);
+ try self.addLabel(.local_set, overflow_bit.local);
+ break :blk down_cast;
}
- try self.addLabel(.br_if, 0);
- const div = try self.binOp(bin_op, lhs, lhs_ty, .div);
- const cmp_res = try self.cmp(div, rhs, lhs_ty, .neq);
- try self.emitWValue(cmp_res);
+ } else if (int_info.signedness == .signed) blk: {
+ const shift_imm = if (wasm_bits == 32)
+ WValue{ .imm32 = wasm_bits - int_info.bits }
+ else
+ WValue{ .imm64 = wasm_bits - int_info.bits };
+
+ const lhs_shl = try self.binOp(lhs, shift_imm, lhs_ty, .shl);
+ const lhs_shr = try self.binOp(lhs_shl, shift_imm, lhs_ty, .shr);
+ const rhs_shl = try self.binOp(rhs, shift_imm, lhs_ty, .shl);
+ const rhs_shr = try self.binOp(rhs_shl, shift_imm, lhs_ty, .shr);
+
+ const bin_op = try self.binOp(lhs_shr, rhs_shr, lhs_ty, .mul);
+ const shl = try self.binOp(bin_op, shift_imm, lhs_ty, .shl);
+ const shr = try self.binOp(shl, shift_imm, lhs_ty, .shr);
+
+ const cmp_op = try self.cmp(shr, bin_op, lhs_ty, .neq);
+ try self.emitWValue(cmp_op);
try self.addLabel(.local_set, overflow_bit.local);
- try self.endBlock();
- break :blk bin_op;
- } else try self.wrapBinOp(lhs, rhs, lhs_ty, op);
+ break :blk try self.wrapOperand(bin_op, lhs_ty);
+ } else blk: {
+ const bin_op = try self.binOp(lhs, rhs, lhs_ty, .mul);
+ const shift_imm = if (wasm_bits == 32)
+ WValue{ .imm32 = int_info.bits }
+ else
+ WValue{ .imm64 = int_info.bits };
+ const shr = try self.binOp(bin_op, shift_imm, lhs_ty, .shr);
+ const cmp_op = try self.cmp(shr, zero, lhs_ty, .neq);
+ try self.emitWValue(cmp_op);
+ try self.addLabel(.local_set, overflow_bit.local);
+ break :blk try self.wrapOperand(bin_op, lhs_ty);
+ };
const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
try self.store(result_ptr, bin_op, lhs_ty, 0);