| field | value | date |
|---|---|---|
| author | Jakub Konka <kubkon@jakubkonka.com> | 2022-05-07 23:30:08 +0200 |
| committer | GitHub <noreply@github.com> | 2022-05-07 23:30:08 +0200 |
| commit | f161d3875ad341971c97384587e2e6c2b50bc09c (patch) | |
| tree | ae6e441e067fa881982f568de54b6c9e53f2e777 /src/arch/wasm/CodeGen.zig | |
| parent | e8c85450feac961a147d37394920081214ba9396 (diff) | |
| parent | a11097958271562fe7e64716356c07d9996fad5f (diff) | |
| download | zig-f161d3875ad341971c97384587e2e6c2b50bc09c.tar.gz zig-f161d3875ad341971c97384587e2e6c2b50bc09c.zip | |
Merge pull request #11605 from Luukdegram/wasm-mul-overflow
stage2: wasm - Improve `@mulWithOverflow` implementation
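
For context, the builtin this backend change lowers returns both a wrapped product and an overflow flag. A minimal usage sketch, assuming the `@mulWithOverflow(T, a, b, &result) bool` signature that stage2 used in this 2022 time frame (the test name and values are illustrative only):

```zig
const std = @import("std");

test "@mulWithOverflow wraps and reports overflow" {
    var result: u8 = undefined;
    // 200 * 2 = 400 does not fit in a u8; the builtin stores the wrapped
    // product (400 & 0xFF == 144) and returns true to signal overflow.
    const overflowed = @mulWithOverflow(u8, 200, 2, &result);
    try std.testing.expect(overflowed);
    try std.testing.expectEqual(@as(u8, 144), result);
}
```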
Diffstat (limited to 'src/arch/wasm/CodeGen.zig')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | src/arch/wasm/CodeGen.zig | 302 |

1 file changed, 176 insertions, 126 deletions
```diff
diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig
index 5171dfb460..9318b4ecca 100644
--- a/src/arch/wasm/CodeGen.zig
+++ b/src/arch/wasm/CodeGen.zig
@@ -1424,7 +1424,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
         .add_with_overflow => self.airBinOpOverflow(inst, .add),
         .sub_with_overflow => self.airBinOpOverflow(inst, .sub),
         .shl_with_overflow => self.airBinOpOverflow(inst, .shl),
-        .mul_with_overflow => self.airBinOpOverflow(inst, .mul),
+        .mul_with_overflow => self.airMulWithOverflow(inst),
 
         .clz => self.airClz(inst),
         .ctz => self.airCtz(inst),
@@ -1822,7 +1822,7 @@ fn store(self: *Self, lhs: WValue, rhs: WValue, ty: Type, offset: u32) InnerErro
 
     const opcode = buildOpcode(.{
         .valtype1 = valtype,
-        .width = abi_size * 8, // use bitsize instead of byte size
+        .width = abi_size * 8,
         .op = .store,
     });
 
@@ -1852,21 +1852,13 @@ fn airLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 fn load(self: *Self, operand: WValue, ty: Type, offset: u32) InnerError!WValue {
     // load local's value from memory by its stack position
     try self.emitWValue(operand);
-    // Build the opcode with the right bitsize
-    const signedness: std.builtin.Signedness = if (ty.isUnsignedInt() or
-        ty.zigTypeTag() == .ErrorSet or
-        ty.zigTypeTag() == .Bool)
-        .unsigned
-    else
-        .signed;
     const abi_size = @intCast(u8, ty.abiSize(self.target));
-
     const opcode = buildOpcode(.{
         .valtype1 = typeToValtype(ty, self.target),
-        .width = abi_size * 8, // use bitsize instead of byte size
+        .width = abi_size * 8,
         .op = .load,
-        .signedness = signedness,
+        .signedness = .unsigned,
     });
 
     try self.addMemArg(
@@ -1935,7 +1927,14 @@ fn airWrapBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
     const lhs = try self.resolveInst(bin_op.lhs);
     const rhs = try self.resolveInst(bin_op.rhs);
 
-    return self.wrapBinOp(lhs, rhs, self.air.typeOf(bin_op.lhs), op);
+    const ty = self.air.typeOf(bin_op.lhs);
+    if (ty.zigTypeTag() == .Vector) {
+        return self.fail("TODO: Implement wrapping arithmetic for vectors", .{});
+    } else if (ty.abiSize(self.target) > 8) {
+        return self.fail("TODO: Implement wrapping arithmetic for bitsize > 64", .{});
+    }
+
+    return self.wrapBinOp(lhs, rhs, ty, op);
 }
 
 fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
@@ -1948,38 +1947,28 @@ fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError
         .signedness = if (ty.isSignedInt()) .signed else .unsigned,
     });
     try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
-
-    const int_info = ty.intInfo(self.target);
-    const bitsize = int_info.bits;
-    const is_signed = int_info.signedness == .signed;
-    // if target type bitsize is x < 32 and 32 > x < 64, we perform
-    // result & ((1<<N)-1) where N = bitsize or bitsize -1 incase of signed.
-    if (bitsize != 32 and bitsize < 64) {
-        // first check if we can use a single instruction,
-        // wasm provides those if the integers are signed and 8/16-bit.
-        // For arbitrary integer sizes, we use the algorithm mentioned above.
-        if (is_signed and bitsize == 8) {
-            try self.addTag(.i32_extend8_s);
-        } else if (is_signed and bitsize == 16) {
-            try self.addTag(.i32_extend16_s);
-        } else {
-            const result = (@as(u64, 1) << @intCast(u6, bitsize - @boolToInt(is_signed))) - 1;
-            if (bitsize < 32) {
-                try self.addImm32(@bitCast(i32, @intCast(u32, result)));
-                try self.addTag(.i32_and);
-            } else {
-                try self.addImm64(result);
-                try self.addTag(.i64_and);
-            }
-        }
-    } else if (int_info.bits > 64) {
-        return self.fail("TODO wasm: Integer wrapping for bitsizes larger than 64", .{});
-    }
-
-    // save the result in a temporary
     const bin_local = try self.allocLocal(ty);
     try self.addLabel(.local_set, bin_local.local);
-    return bin_local;
+    return self.wrapOperand(bin_local, ty);
+}
+
+/// Wraps an operand based on a given type's bitsize.
+/// Asserts `Type` is <= 64bits.
+fn wrapOperand(self: *Self, operand: WValue, ty: Type) InnerError!WValue {
+    assert(ty.abiSize(self.target) <= 8);
+    const result_local = try self.allocLocal(ty);
+    const bitsize = ty.intInfo(self.target).bits;
+    const result = @intCast(u64, (@as(u65, 1) << @intCast(u7, bitsize)) - 1);
+    try self.emitWValue(operand);
+    if (bitsize <= 32) {
+        try self.addImm32(@bitCast(i32, @intCast(u32, result)));
+        try self.addTag(.i32_and);
+    } else {
+        try self.addImm64(result);
+        try self.addTag(.i64_and);
+    }
+    try self.addLabel(.local_set, result_local.local);
+    return result_local;
 }
 
 fn lowerParentPtr(self: *Self, ptr_val: Value, ptr_child_ty: Type) InnerError!WValue {
@@ -2098,6 +2087,22 @@ fn lowerDeclRefValue(self: *Self, tv: TypedValue, decl_index: Module.Decl.Index)
     } else return WValue{ .memory = target_sym_index };
 }
 
+/// Converts a signed integer to its 2's complement form and returns
+/// an unsigned integer instead.
+/// Asserts bitsize <= 64
+fn toTwosComplement(value: anytype, bits: u7) std.meta.Int(.unsigned, @typeInfo(@TypeOf(value)).Int.bits) {
+    const T = @TypeOf(value);
+    comptime assert(@typeInfo(T) == .Int);
+    comptime assert(@typeInfo(T).Int.signedness == .signed);
+    assert(bits <= 64);
+    const WantedT = std.meta.Int(.unsigned, @typeInfo(T).Int.bits);
+    if (value >= 0) return @bitCast(WantedT, value);
+    const max_value = @intCast(u64, (@as(u65, 1) << bits) - 1);
+    const flipped = (~-value) + 1;
+    const result = @bitCast(WantedT, flipped) & max_value;
+    return @intCast(WantedT, result);
+}
+
 fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
     if (val.isUndefDeep()) return self.emitUndefined(ty);
     if (val.castTag(.decl_ref)) |decl_ref| {
@@ -2114,10 +2119,12 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
     switch (ty.zigTypeTag()) {
         .Int => {
             const int_info = ty.intInfo(self.target);
-            // write constant
             switch (int_info.signedness) {
                 .signed => switch (int_info.bits) {
-                    0...32 => return WValue{ .imm32 = @bitCast(u32, @intCast(i32, val.toSignedInt())) },
+                    0...32 => return WValue{ .imm32 = @intCast(u32, toTwosComplement(
+                        val.toSignedInt(),
+                        @intCast(u6, int_info.bits),
+                    )) },
                     33...64 => return WValue{ .imm64 = @bitCast(u64, val.toSignedInt()) },
                     else => unreachable,
                 },
@@ -2832,30 +2839,38 @@ fn airIntcast(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const ty = self.air.getRefType(ty_op.ty);
     const operand = try self.resolveInst(ty_op.operand);
-    const ref_ty = self.air.typeOf(ty_op.operand);
-    const ref_info = ref_ty.intInfo(self.target);
-    const wanted_info = ty.intInfo(self.target);
+    const operand_ty = self.air.typeOf(ty_op.operand);
+    if (ty.abiSize(self.target) > 8 or operand_ty.abiSize(self.target) > 8) {
+        return self.fail("todo Wasm intcast for bitsize > 64", .{});
+    }
 
-    const op_bits = toWasmBits(ref_info.bits) orelse
-        return self.fail("TODO: Wasm intcast integer types of bitsize: {d}", .{ref_info.bits});
-    const wanted_bits = toWasmBits(wanted_info.bits) orelse
-        return self.fail("TODO: Wasm intcast integer types of bitsize: {d}", .{wanted_info.bits});
+    return self.intcast(operand, operand_ty, ty);
+}
+
+/// Upcasts or downcasts an integer based on the given and wanted types,
+/// and stores the result in a new operand.
+/// Asserts type's bitsize <= 64
+fn intcast(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
+    const given_info = given.intInfo(self.target);
+    const wanted_info = wanted.intInfo(self.target);
+    assert(given_info.bits <= 64);
+    assert(wanted_info.bits <= 64);
 
-    // hot path
+    const op_bits = toWasmBits(given_info.bits).?;
+    const wanted_bits = toWasmBits(wanted_info.bits).?;
     if (op_bits == wanted_bits) return operand;
 
+    try self.emitWValue(operand);
     if (op_bits > 32 and wanted_bits == 32) {
-        try self.emitWValue(operand);
         try self.addTag(.i32_wrap_i64);
     } else if (op_bits == 32 and wanted_bits > 32) {
-        try self.emitWValue(operand);
-        try self.addTag(switch (ref_info.signedness) {
+        try self.addTag(switch (wanted_info.signedness) {
             .signed => .i64_extend_i32_s,
             .unsigned => .i64_extend_i32_u,
         });
     } else unreachable;
 
-    const result = try self.allocLocal(ty);
+    const result = try self.allocLocal(wanted);
     try self.addLabel(.local_set, result.local);
     return result;
 }
@@ -3072,63 +3087,17 @@ fn airSlicePtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 }
 
 fn airTrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
-    if (self.liveness.isUnused(inst)) return WValue.none;
+    if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
 
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const operand = try self.resolveInst(ty_op.operand);
-    const op_ty = self.air.typeOf(ty_op.operand);
-    const int_info = self.air.getRefType(ty_op.ty).intInfo(self.target);
+    const wanted_ty = self.air.getRefType(ty_op.ty);
+    const int_info = wanted_ty.intInfo(self.target);
     const wanted_bits = int_info.bits;
-    const result = try self.allocLocal(self.air.getRefType(ty_op.ty));
-    const op_bits = op_ty.intInfo(self.target).bits;
-    const wasm_bits = toWasmBits(wanted_bits) orelse
+    _ = toWasmBits(wanted_bits) orelse {
         return self.fail("TODO: Implement wasm integer truncation for integer bitsize: {d}", .{wanted_bits});
-
-    // Use wasm's instruction to wrap from 64bit to 32bit integer when possible
-    if (op_bits == 64 and wanted_bits == 32) {
-        try self.emitWValue(operand);
-        try self.addTag(.i32_wrap_i64);
-        try self.addLabel(.local_set, result.local);
-        return result;
-    }
-
-    // Any other truncation must be done manually
-    if (int_info.signedness == .unsigned) {
-        const mask = (@as(u65, 1) << @intCast(u7, wanted_bits)) - 1;
-        try self.emitWValue(operand);
-        switch (wasm_bits) {
-            32 => {
-                try self.addImm32(@bitCast(i32, @intCast(u32, mask)));
-                try self.addTag(.i32_and);
-            },
-            64 => {
-                try self.addImm64(@intCast(u64, mask));
-                try self.addTag(.i64_and);
-            },
-            else => unreachable,
-        }
-    } else {
-        const shift_bits = wasm_bits - wanted_bits;
-        try self.emitWValue(operand);
-        switch (wasm_bits) {
-            32 => {
-                try self.addImm32(@bitCast(i16, shift_bits));
-                try self.addTag(.i32_shl);
-                try self.addImm32(@bitCast(i16, shift_bits));
-                try self.addTag(.i32_shr_s);
-            },
-            64 => {
-                try self.addImm64(shift_bits);
-                try self.addTag(.i64_shl);
-                try self.addImm64(shift_bits);
-                try self.addTag(.i64_shr_s);
-            },
-            else => unreachable,
-        }
-    }
-
-    try self.addLabel(.local_set, result.local);
-    return result;
+    };
+    return self.wrapOperand(operand, wanted_ty);
 }
 
 fn airBoolToInt(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -3418,7 +3387,8 @@ fn airFloatToInt(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 
     const result = try self.allocLocal(dest_ty);
     try self.addLabel(.local_set, result.local);
-    return result;
+
+    return self.wrapOperand(result, dest_ty);
 }
 
 fn airIntToFloat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -3922,6 +3892,10 @@ fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue
     const rhs = try self.resolveInst(extra.rhs);
     const lhs_ty = self.air.typeOf(extra.lhs);
 
+    if (lhs_ty.zigTypeTag() == .Vector) {
+        return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
+    }
+
     // We store the bit if it's overflowed or not in this. As it's zero-initialized
     // we only need to update it if an overflow (or underflow) occured.
     const overflow_bit = try self.allocLocal(Type.initTag(.u1));
@@ -4008,24 +3982,100 @@ fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue
         }
         try self.addLabel(.local_set, tmp_val.local);
         break :blk tmp_val;
-    } else if (op == .mul) blk: {
-        const bin_op = try self.wrapBinOp(lhs, rhs, lhs_ty, op);
-        try self.startBlock(.block, wasm.block_empty);
-        // check if 0. true => Break out of block as cannot over -or underflow.
-        try self.emitWValue(lhs);
-        switch (wasm_bits) {
-            32 => try self.addTag(.i32_eqz),
-            64 => try self.addTag(.i64_eqz),
-            else => unreachable,
+    } else try self.wrapBinOp(lhs, rhs, lhs_ty, op);
+
+    const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
+    try self.store(result_ptr, bin_op, lhs_ty, 0);
+    const offset = @intCast(u32, lhs_ty.abiSize(self.target));
+    try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
+
+    return result_ptr;
+}
+
+fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
+    const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
+    const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+    const lhs = try self.resolveInst(extra.lhs);
+    const rhs = try self.resolveInst(extra.rhs);
+    const lhs_ty = self.air.typeOf(extra.lhs);
+
+    if (lhs_ty.zigTypeTag() == .Vector) {
+        return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
+    }
+
+    // We store the bit if it's overflowed or not in this. As it's zero-initialized
+    // we only need to update it if an overflow (or underflow) occured.
+    const overflow_bit = try self.allocLocal(Type.initTag(.u1));
+    const int_info = lhs_ty.intInfo(self.target);
+    const wasm_bits = toWasmBits(int_info.bits) orelse {
+        return self.fail("TODO: Implement overflow arithmetic for integer bitsize: {d}", .{int_info.bits});
+    };
+
+    if (wasm_bits == 64) {
+        return self.fail("TODO: Implement `@mulWithOverflow` for integer bitsize: {d}", .{int_info.bits});
+    }
+
+    const zero = switch (wasm_bits) {
+        32 => WValue{ .imm32 = 0 },
+        64 => WValue{ .imm64 = 0 },
+        else => unreachable,
+    };
+
+    // for 32 bit integers we upcast it to a 64bit integer
+    const bin_op = if (int_info.bits == 32) blk: {
+        const new_ty = if (int_info.signedness == .signed) Type.i64 else Type.u64;
+        const lhs_upcast = try self.intcast(lhs, lhs_ty, new_ty);
+        const rhs_upcast = try self.intcast(rhs, lhs_ty, new_ty);
+        const bin_op = try self.binOp(lhs_upcast, rhs_upcast, new_ty, .mul);
+        if (int_info.signedness == .unsigned) {
+            const shr = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
+            const wrap = try self.intcast(shr, new_ty, lhs_ty);
+            const cmp_res = try self.cmp(wrap, zero, lhs_ty, .neq);
+            try self.emitWValue(cmp_res);
+            try self.addLabel(.local_set, overflow_bit.local);
+            break :blk try self.intcast(bin_op, new_ty, lhs_ty);
+        } else {
+            const down_cast = try self.intcast(bin_op, new_ty, lhs_ty);
+            const shr = try self.binOp(down_cast, .{ .imm32 = int_info.bits - 1 }, lhs_ty, .shr);
+
+            const shr_res = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
+            const down_shr_res = try self.intcast(shr_res, new_ty, lhs_ty);
+            const cmp_res = try self.cmp(down_shr_res, shr, lhs_ty, .neq);
+            try self.emitWValue(cmp_res);
+            try self.addLabel(.local_set, overflow_bit.local);
+            break :blk down_cast;
         }
-        try self.addLabel(.br_if, 0);
-        const div = try self.binOp(bin_op, lhs, lhs_ty, .div);
-        const cmp_res = try self.cmp(div, rhs, lhs_ty, .neq);
-        try self.emitWValue(cmp_res);
+    } else if (int_info.signedness == .signed) blk: {
+        const shift_imm = if (wasm_bits == 32)
+            WValue{ .imm32 = wasm_bits - int_info.bits }
+        else
+            WValue{ .imm64 = wasm_bits - int_info.bits };
+
+        const lhs_shl = try self.binOp(lhs, shift_imm, lhs_ty, .shl);
+        const lhs_shr = try self.binOp(lhs_shl, shift_imm, lhs_ty, .shr);
+        const rhs_shl = try self.binOp(rhs, shift_imm, lhs_ty, .shl);
+        const rhs_shr = try self.binOp(rhs_shl, shift_imm, lhs_ty, .shr);
+
+        const bin_op = try self.binOp(lhs_shr, rhs_shr, lhs_ty, .mul);
+        const shl = try self.binOp(bin_op, shift_imm, lhs_ty, .shl);
+        const shr = try self.binOp(shl, shift_imm, lhs_ty, .shr);
+
+        const cmp_op = try self.cmp(shr, bin_op, lhs_ty, .neq);
+        try self.emitWValue(cmp_op);
         try self.addLabel(.local_set, overflow_bit.local);
-        try self.endBlock();
-        break :blk bin_op;
-    } else try self.wrapBinOp(lhs, rhs, lhs_ty, op);
+        break :blk try self.wrapOperand(bin_op, lhs_ty);
+    } else blk: {
+        const bin_op = try self.binOp(lhs, rhs, lhs_ty, .mul);
+        const shift_imm = if (wasm_bits == 32)
+            WValue{ .imm32 = int_info.bits }
+        else
+            WValue{ .imm64 = int_info.bits };
+        const shr = try self.binOp(bin_op, shift_imm, lhs_ty, .shr);
+        const cmp_op = try self.cmp(shr, zero, lhs_ty, .neq);
+        try self.emitWValue(cmp_op);
+        try self.addLabel(.local_set, overflow_bit.local);
+        break :blk try self.wrapOperand(bin_op, lhs_ty);
+    };
 
     const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
     try self.store(result_ptr, bin_op, lhs_ty, 0);
```
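
The strategy in the new `airMulWithOverflow` avoids the old divide-based check: 32-bit operands are widened to 64 bits via `intcast`, multiplied there, and the product's upper half (or, for signed types, its disagreement with the sign of the lower half) sets the overflow bit, after which `wrapOperand` masks the result back down with `result & ((1 << N) - 1)`. Below is a rough stand-alone Zig sketch of the unsigned 32-bit path only, not the compiler's code; it assumes the pre-0.11 builtin/cast syntax used throughout this file, and `mulOverflowU32` is a made-up name:

```zig
const std = @import("std");

/// Sketch of the unsigned 32-bit lowering: widen both operands to u64,
/// multiply, then any bits above bit 31 mean the product overflowed u32.
/// The low 32 bits are the wrapped result, i.e. the `(1 << 32) - 1` mask
/// that `wrapOperand` applies.
fn mulOverflowU32(lhs: u32, rhs: u32) struct { wrapped: u32, overflow: bool } {
    const wide = @as(u64, lhs) * @as(u64, rhs); // a u64 product cannot overflow here
    return .{
        .wrapped = @truncate(u32, wide),
        .overflow = (wide >> 32) != 0,
    };
}

test "widened multiply matches the builtin" {
    var expected: u32 = undefined;
    const expected_ov = @mulWithOverflow(u32, 0x8000_0000, 2, &expected);
    const got = mulOverflowU32(0x8000_0000, 2);
    try std.testing.expectEqual(expected_ov, got.overflow);
    try std.testing.expectEqual(expected, got.wrapped);
}
```

For integers narrower than 32 bits, the signed branch in the diff first sign-extends each operand in its wasm register by shifting left and then arithmetic-shifting right by `wasm_bits - int_info.bits`, multiplies, and flags overflow when re-extending the product changes it; 64-bit operands still hit the `TODO` failure path.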
