| | | |
|---|---|---|
| author | Robin Voetter <robin@voetter.nl> | 2024-02-05 09:24:49 +0100 |
| committer | GitHub <noreply@github.com> | 2024-02-05 09:24:49 +0100 |
| commit | 7634a115c50ef66edbdd5644c4ba310eb31e6343 (patch) | |
| tree | b8be56f0db16691e2939e87bac1222ba2c9fd4a8 /src/codegen/spirv.zig | |
| parent | aebf20cc9a0469a778d6276d3797525660746e91 (diff) | |
| parent | 25111061504a652bfed45b26252349f363b109af (diff) | |
Merge pull request #18580 from Snektron/spirv-more-vectors
spirv: more vector operations
Diffstat (limited to 'src/codegen/spirv.zig')
-rw-r--r-- src/codegen/spirv.zig | 1324

1 file changed, 891 insertions(+), 433 deletions(-)
```diff
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index 6c058308df..a499f3d8ed 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -373,8 +373,9 @@ const DeclGen = struct {
         /// For `composite_integer` this is 0 (TODO)
         backing_bits: u16,
 
-        /// Whether the type is a vector.
-        is_vector: bool,
+        /// Null if this type is a scalar, or the length
+        /// of the vector otherwise.
+        vector_len: ?u32,
 
         /// Whether the inner type is signed. Only relevant for integers.
         signedness: std.builtin.Signedness,
@@ -597,32 +598,37 @@ const DeclGen = struct {
         return self.backingIntBits(ty) == null;
     }
 
-    fn arithmeticTypeInfo(self: *DeclGen, ty: Type) !ArithmeticTypeInfo {
+    fn arithmeticTypeInfo(self: *DeclGen, ty: Type) ArithmeticTypeInfo {
         const mod = self.module;
         const target = self.getTarget();
-        return switch (ty.zigTypeTag(mod)) {
+        var scalar_ty = ty.scalarType(mod);
+        if (scalar_ty.zigTypeTag(mod) == .Enum) {
+            scalar_ty = scalar_ty.intTagType(mod);
+        }
+        const vector_len = if (ty.isVector(mod)) ty.vectorLen(mod) else null;
+        return switch (scalar_ty.zigTypeTag(mod)) {
             .Bool => ArithmeticTypeInfo{
                 .bits = 1, // Doesn't matter for this class.
                 .backing_bits = self.backingIntBits(1).?,
-                .is_vector = false,
+                .vector_len = vector_len,
                 .signedness = .unsigned, // Technically, but doesn't matter for this class.
                 .class = .bool,
             },
             .Float => ArithmeticTypeInfo{
-                .bits = ty.floatBits(target),
-                .backing_bits = ty.floatBits(target), // TODO: F80?
-                .is_vector = false,
+                .bits = scalar_ty.floatBits(target),
+                .backing_bits = scalar_ty.floatBits(target), // TODO: F80?
+                .vector_len = vector_len,
                 .signedness = .signed, // Technically, but doesn't matter for this class.
                 .class = .float,
             },
             .Int => blk: {
-                const int_info = ty.intInfo(mod);
+                const int_info = scalar_ty.intInfo(mod);
                 // TODO: Maybe it's useful to also return this value.
                 const maybe_backing_bits = self.backingIntBits(int_info.bits);
                 break :blk ArithmeticTypeInfo{
                     .bits = int_info.bits,
                     .backing_bits = maybe_backing_bits orelse 0,
-                    .is_vector = false,
+                    .vector_len = vector_len,
                     .signedness = int_info.signedness,
                     .class = if (maybe_backing_bits) |backing_bits|
                         if (backing_bits == int_info.bits)
@@ -633,22 +639,9 @@ const DeclGen = struct {
                         .composite_integer,
                 };
             },
-            .Enum => return self.arithmeticTypeInfo(ty.intTagType(mod)),
-            // As of yet, there is no vector support in the self-hosted compiler.
-            .Vector => blk: {
-                const child_type = ty.childType(mod);
-                const child_ty_info = try self.arithmeticTypeInfo(child_type);
-                break :blk ArithmeticTypeInfo{
-                    .bits = child_ty_info.bits,
-                    .backing_bits = child_ty_info.backing_bits,
-                    .is_vector = true,
-                    .signedness = child_ty_info.signedness,
-                    .class = child_ty_info.class,
-                };
-            },
-            // TODO: For which types is this the case?
-            // else => self.todo("implement arithmeticTypeInfo for {}", .{ty.fmt(self.module)}),
-            else => unreachable,
+            .Enum => unreachable,
+            .Vector => unreachable,
+            else => unreachable, // Unhandled arithmetic type
         };
     }
@@ -685,6 +678,18 @@ const DeclGen = struct {
         }
     }
 
+    /// Emits a float constant
+    fn constFloat(self: *DeclGen, ty_ref: CacheRef, value: f128) !IdRef {
+        const ty = self.spv.cache.lookup(ty_ref).float_type;
+        return switch (ty.bits) {
+            16 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float16 = @floatCast(value) } } }),
+            32 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float32 = @floatCast(value) } } }),
+            64 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float64 = @floatCast(value) } } }),
+            80, 128 => unreachable, // TODO
+            else => unreachable,
+        };
+    }
+
     /// Construct a struct at runtime.
     /// ty must be a struct type.
     /// Constituents should be in `indirect` representation (as the elements of a struct should be).
@@ -1760,6 +1765,92 @@ const DeclGen = struct {
         return union_layout;
     }
 
+    /// This structure is used as helper for element-wise operations. It is intended
+    /// to be used with both vectors and single elements.
+    const WipElementWise = struct {
+        dg: *DeclGen,
+        result_ty: Type,
+        /// Always in direct representation.
+        result_ty_ref: CacheRef,
+        scalar_ty: Type,
+        /// Always in direct representation.
+        scalar_ty_ref: CacheRef,
+        scalar_ty_id: IdRef,
+        /// True if the input is actually a vector type.
+        is_vector: bool,
+        /// The element-wise operation should fill these results before calling finalize().
+        /// These should all be in **direct** representation! `finalize()` will convert
+        /// them to indirect if required.
+        results: []IdRef,
+
+        fn deinit(wip: *WipElementWise) void {
+            wip.dg.gpa.free(wip.results);
+        }
+
+        /// Utility function to extract the element at a particular index in an
+        /// input vector. This type is expected to be a vector if `wip.is_vector`, and
+        /// a scalar otherwise.
+        fn elementAt(wip: WipElementWise, ty: Type, value: IdRef, index: usize) !IdRef {
+            const mod = wip.dg.module;
+            if (wip.is_vector) {
+                assert(ty.isVector(mod));
+                return try wip.dg.extractField(ty.childType(mod), value, @intCast(index));
+            } else {
+                assert(!ty.isVector(mod));
+                assert(index == 0);
+                return value;
+            }
+        }
+
+        /// Turns the results of this WipElementWise into a result. This can either
+        /// be a vector or single element, depending on `result_ty`.
+        /// After calling this function, this WIP is no longer usable.
+        /// Results is in `direct` representation.
+        fn finalize(wip: *WipElementWise) !IdRef {
+            if (wip.is_vector) {
+                // Convert all the constituents to indirect, as required for the array.
+                for (wip.results) |*result| {
+                    result.* = try wip.dg.convertToIndirect(wip.scalar_ty, result.*);
+                }
+                return try wip.dg.constructArray(wip.result_ty, wip.results);
+            } else {
+                return wip.results[0];
+            }
+        }
+
+        /// Allocate a result id at a particular index, and return it.
+        fn allocId(wip: *WipElementWise, index: usize) IdRef {
+            assert(wip.is_vector or index == 0);
+            wip.results[index] = wip.dg.spv.allocId();
+            return wip.results[index];
+        }
+    };
+
+    /// Create a new element-wise operation.
+    fn elementWise(self: *DeclGen, result_ty: Type) !WipElementWise {
+        const mod = self.module;
+        // For now, this operation also reasons in terms of `.direct` representation.
+        const result_ty_ref = try self.resolveType(result_ty, .direct);
+        const is_vector = result_ty.isVector(mod);
+        const num_results = if (is_vector) result_ty.vectorLen(mod) else 1;
+        const results = try self.gpa.alloc(IdRef, num_results);
+        for (results) |*result| result.* = undefined;
+
+        const scalar_ty = result_ty.scalarType(mod);
+        const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
+
+        return .{
+            .dg = self,
+            .result_ty = result_ty,
+            .result_ty_ref = result_ty_ref,
+            .scalar_ty = scalar_ty,
+            .scalar_ty_ref = scalar_ty_ref,
+            .scalar_ty_id = self.typeId(scalar_ty_ref),
+            .is_vector = is_vector,
+            .results = results,
+        };
+    }
+
     /// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure.
     /// In order to be able to run tests, we "temporarily" lower test kernels into separate entry-
     /// points. The test executor will then be able to invoke these to run the tests.
@@ -2081,25 +2172,31 @@ const DeclGen = struct {
         const air_tags = self.air.instructions.items(.tag);
         const maybe_result_id: ?IdRef = switch (air_tags[@intFromEnum(inst)]) {
             // zig fmt: off
-            .add, .add_wrap => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd, true),
-            .sub, .sub_wrap => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub, true),
-            .mul, .mul_wrap => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul, true),
+            .add, .add_wrap, .add_optimized => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd),
+            .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
+            .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
+
+            .abs => try self.airAbs(inst),
 
             .div_float,
             .div_float_optimized,
             // TODO: Check that this is the right operation.
             .div_trunc,
             .div_trunc_optimized,
-            => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv, false),
+            => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv),
 
             // TODO: Check if this is the right operation
-            // TODO: Make airArithOp for rem not emit a mask for the LHS.
             .rem,
             .rem_optimized,
-            => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem, false),
+            => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem),
 
             .add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan),
             .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
+            .shl_with_overflow => try self.airShlOverflow(inst),
+
+            .mul_add => try self.airMulAdd(inst),
+
+            .splat => try self.airSplat(inst),
+            .reduce, .reduce_optimized => try self.airReduce(inst),
             .shuffle => try self.airShuffle(inst),
 
             .ptr_add => try self.airPtrAdd(inst),
@@ -2111,7 +2208,8 @@ const DeclGen = struct {
             .bool_and => try self.airBinOpSimple(inst, .OpLogicalAnd),
             .bool_or => try self.airBinOpSimple(inst, .OpLogicalOr),
 
-            .shl => try self.airShift(inst, .OpShiftLeftLogical),
+            .shl, .shl_exact => try self.airShift(inst, .OpShiftLeftLogical, .OpShiftLeftLogical),
+            .shr, .shr_exact => try self.airShift(inst, .OpShiftRightLogical, .OpShiftRightArithmetic),
 
             .min => try self.airMinMax(inst, .lt),
             .max => try self.airMinMax(inst, .gt),
@@ -2121,6 +2219,7 @@ const DeclGen = struct {
             .int_from_ptr => try self.airIntFromPtr(inst),
             .float_from_int => try self.airFloatFromInt(inst),
             .int_from_float => try self.airIntFromFloat(inst),
+            .int_from_bool => try self.airIntFromBool(inst),
             .fpext, .fptrunc => try self.airFloatCast(inst),
             .not => try self.airNot(inst),
@@ -2137,6 +2236,8 @@ const DeclGen = struct {
             .ptr_elem_val => try self.airPtrElemVal(inst),
             .array_elem_val => try self.airArrayElemVal(inst),
 
+            .vector_store_elem => return self.airVectorStoreElem(inst),
+
             .set_union_tag => return self.airSetUnionTag(inst),
             .get_union_tag => try self.airGetUnionTag(inst),
             .union_init => try self.airUnionInit(inst),
@@ -2189,13 +2290,16 @@ const DeclGen = struct {
             .wrap_errunion_err => try self.airWrapErrUnionErr(inst),
             .wrap_errunion_payload => try self.airWrapErrUnionPayload(inst),
 
-            .is_null => try self.airIsNull(inst, .is_null),
-            .is_non_null => try self.airIsNull(inst, .is_non_null),
-            .is_err => try self.airIsErr(inst, .is_err),
-            .is_non_err => try self.airIsErr(inst, .is_non_err),
+            .is_null => try self.airIsNull(inst, false, .is_null),
+            .is_non_null => try self.airIsNull(inst, false, .is_non_null),
+            .is_null_ptr => try self.airIsNull(inst, true, .is_null),
+            .is_non_null_ptr => try self.airIsNull(inst, true, .is_non_null),
+            .is_err => try self.airIsErr(inst, .is_err),
+            .is_non_err => try self.airIsErr(inst, .is_non_err),
 
-            .optional_payload => try self.airUnwrapOptional(inst),
-            .wrap_optional => try self.airWrapOptional(inst),
+            .optional_payload => try self.airUnwrapOptional(inst),
+            .optional_payload_ptr => try self.airUnwrapOptionalPtr(inst),
+            .wrap_optional => try self.airWrapOptional(inst),
 
             .assembly => try self.airAssembly(inst),
@@ -2213,34 +2317,17 @@ const DeclGen = struct {
     }
 
     fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef {
-        const mod = self.module;
-
-        if (ty.isVector(mod)) {
-            const child_ty = ty.childType(mod);
-            const vector_len = ty.vectorLen(mod);
-
-            const constituents = try self.gpa.alloc(IdRef, vector_len);
-            defer self.gpa.free(constituents);
-
-            for (constituents, 0..) |*constituent, i| {
-                const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
-                const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
-                const result_id = try self.binOpSimple(child_ty, lhs_index_id, rhs_index_id, opcode);
-                constituent.* = try self.convertToIndirect(child_ty, result_id);
-            }
-
-            return try self.constructArray(ty, constituents);
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+        for (0..wip.results.len) |i| {
+            try self.func.body.emit(self.spv.gpa, opcode, .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = wip.allocId(i),
+                .operand_1 = try wip.elementAt(ty, lhs_id, i),
+                .operand_2 = try wip.elementAt(ty, rhs_id, i),
+            });
         }
-
-        const result_id = self.spv.allocId();
-        const result_type_id = try self.resolveTypeId(ty);
-        try self.func.body.emit(self.spv.gpa, opcode, .{
-            .id_result_type = result_type_id,
-            .id_result = result_id,
-            .operand_1 = lhs_id,
-            .operand_2 = rhs_id,
-        });
-        return result_id;
+        return try wip.finalize();
     }
 
     fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
@@ -2254,29 +2341,59 @@ const DeclGen = struct {
         return try self.binOpSimple(ty, lhs_id, rhs_id, opcode);
     }
 
-    fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
+    fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime unsigned: Opcode, comptime signed: Opcode) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
+        const mod = self.module;
 
         const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
 
-        const result_type_id = try self.resolveTypeId(self.typeOfIndex(inst));
-        // the shift and the base must be the same type in SPIR-V, but in Zig the shift is a smaller int.
-        const shift_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
-            .id_result_type = result_type_id,
-            .id_result = shift_id,
-            .unsigned_value = rhs_id,
-        });
+        const result_ty = self.typeOfIndex(inst);
+        const shift_ty = self.typeOf(bin_op.rhs);
+        const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
 
-        const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, opcode, .{
-            .id_result_type = result_type_id,
-            .id_result = result_id,
-            .base = lhs_id,
-            .shift = shift_id,
-        });
-        return result_id;
+        const info = self.arithmeticTypeInfo(result_ty);
+        switch (info.class) {
+            .composite_integer => return self.todo("shift ops for composite integers", .{}),
+            .integer, .strange_integer => {},
+            .float, .bool => unreachable,
+        }
+
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
+            const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
+            const rhs_elem_id = try wip.elementAt(shift_ty, rhs_id, i);
+
+            // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
+            // so just manually upcast it if required.
+            const shift_id = if (scalar_shift_ty_ref != wip.scalar_ty_ref) blk: {
+                const shift_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                    .id_result_type = wip.scalar_ty_id,
+                    .id_result = shift_id,
+                    .unsigned_value = rhs_elem_id,
+                });
+                break :blk shift_id;
+            } else rhs_elem_id;
+
+            const value_id = self.spv.allocId();
+            const args = .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = value_id,
+                .base = lhs_elem_id,
+                .shift = shift_id,
+            };
+
+            if (result_ty.isSignedInt(mod)) {
+                try self.func.body.emit(self.spv.gpa, signed, args);
+            } else {
+                try self.func.body.emit(self.spv.gpa, unsigned, args);
+            }
+
+            result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, info);
+        }
+        return try wip.finalize();
     }
 
     fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef {
@@ -2286,88 +2403,102 @@ const DeclGen = struct {
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
         const result_ty = self.typeOfIndex(inst);
-        const result_ty_ref = try self.resolveType(result_ty, .direct);
-
-        const info = try self.arithmeticTypeInfo(result_ty);
-        // TODO: Use fmin for OpenCL
-        const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
-        const selection_id = switch (info.class) {
-            .float => blk: {
-                // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
-                // but we want it to pick lhs. Therefore we also have to check if
-                // rhs is nan. We don't need to care about the result when both
-                // are nan.
-                const rhs_is_nan_id = self.spv.allocId();
-                const bool_ty_ref = try self.resolveType(Type.bool, .direct);
-                try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = rhs_is_nan_id,
-                    .x = rhs_id,
-                });
-                const float_cmp_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = float_cmp_id,
-                    .operand_1 = cmp_id,
-                    .operand_2 = rhs_is_nan_id,
-                });
-                break :blk float_cmp_id;
-            },
-            else => cmp_id,
-        };
-        const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-            .id_result_type = self.typeId(result_ty_ref),
-            .id_result = result_id,
-            .condition = selection_id,
-            .object_1 = lhs_id,
-            .object_2 = rhs_id,
-        });
-        return result_id;
-    }
+        return try self.minMax(result_ty, op, lhs_id, rhs_id);
+    }
+
+    fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
+        const info = self.arithmeticTypeInfo(result_ty);
+
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
+            const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
+            const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
+
+            // TODO: Use fmin for OpenCL
+            const cmp_id = try self.cmp(op, Type.bool, wip.scalar_ty, lhs_elem_id, rhs_elem_id);
+            const selection_id = switch (info.class) {
+                .float => blk: {
+                    // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
+                    // but we want it to pick lhs. Therefore we also have to check if
+                    // rhs is nan. We don't need to care about the result when both
+                    // are nan.
+                    const rhs_is_nan_id = self.spv.allocId();
+                    const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+                    try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = rhs_is_nan_id,
+                        .x = rhs_elem_id,
+                    });
+                    const float_cmp_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = float_cmp_id,
+                        .operand_1 = cmp_id,
+                        .operand_2 = rhs_is_nan_id,
+                    });
+                    break :blk float_cmp_id;
+                },
+                else => cmp_id,
+            };
 
-    /// This function canonicalizes a "strange" integer value:
-    /// For unsigned integers, the value is masked so that only the relevant bits can contain
-    /// non-zeros.
-    /// For signed integers, the value is also sign extended.
-    fn normalizeInt(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, info: ArithmeticTypeInfo) !IdRef {
-        assert(info.class != .composite_integer); // TODO
-        if (info.bits == info.backing_bits) {
-            return value_id;
+            result_id.* = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = result_id.*,
+                .condition = selection_id,
+                .object_1 = lhs_elem_id,
+                .object_2 = rhs_elem_id,
+            });
         }
-
-        switch (info.signedness) {
-            .unsigned => {
-                const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
-                const result_id = self.spv.allocId();
-                const mask_id = try self.constInt(ty_ref, mask_value);
-                try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
-                    .id_result_type = self.typeId(ty_ref),
-                    .id_result = result_id,
-                    .operand_1 = value_id,
-                    .operand_2 = mask_id,
-                });
-                return result_id;
-            },
-            .signed => {
-                // Shift left and right so that we can copy the sight bit that way.
-                const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
-                const left_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
-                    .id_result_type = self.typeId(ty_ref),
-                    .id_result = left_id,
-                    .base = value_id,
-                    .shift = shift_amt_id,
-                });
-                const right_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
-                    .id_result_type = self.typeId(ty_ref),
-                    .id_result = right_id,
-                    .base = left_id,
-                    .shift = shift_amt_id,
-                });
-                return right_id;
+        return wip.finalize();
+    }
+
+    /// This function normalizes values to a canonical representation
+    /// after some arithmetic operation. This mostly consists of wrapping
+    /// behavior for strange integers:
+    /// - Unsigned integers are bitwise masked with a mask that only passes
+    ///   the valid bits through.
+    /// - Signed integers are also sign extended if they are negative.
+    /// All other values are returned unmodified (this makes strange integer
+    /// wrapping easier to use in generic operations).
+    fn normalize(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, info: ArithmeticTypeInfo) !IdRef {
+        switch (info.class) {
+            .integer, .bool, .float => return value_id,
+            .composite_integer => unreachable, // TODO
+            .strange_integer => switch (info.signedness) {
+                .unsigned => {
+                    const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
+                    const result_id = self.spv.allocId();
+                    const mask_id = try self.constInt(ty_ref, mask_value);
+                    try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
+                        .id_result_type = self.typeId(ty_ref),
+                        .id_result = result_id,
+                        .operand_1 = value_id,
+                        .operand_2 = mask_id,
+                    });
+                    return result_id;
+                },
+                .signed => {
+                    // Shift left and right so that we can copy the sight bit that way.
+                    const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
+                    const left_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
+                        .id_result_type = self.typeId(ty_ref),
+                        .id_result = left_id,
+                        .base = value_id,
+                        .shift = shift_amt_id,
+                    });
+                    const right_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
+                        .id_result_type = self.typeId(ty_ref),
+                        .id_result = right_id,
+                        .base = left_id,
+                        .shift = shift_amt_id,
+                    });
+                    return right_id;
+                },
             },
         }
     }
@@ -2378,8 +2509,6 @@ const DeclGen = struct {
         comptime fop: Opcode,
         comptime sop: Opcode,
         comptime uop: Opcode,
-        /// true if this operation holds under modular arithmetic.
-        comptime modular: bool,
     ) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
@@ -2393,60 +2522,27 @@ const DeclGen = struct {
         assert(self.typeOf(bin_op.lhs).eql(ty, self.module));
         assert(self.typeOf(bin_op.rhs).eql(ty, self.module));
 
-        return try self.arithOp(ty, lhs_id, rhs_id, fop, sop, uop, modular);
+        return try self.arithOp(ty, lhs_id, rhs_id, fop, sop, uop);
     }
 
     fn arithOp(
         self: *DeclGen,
         ty: Type,
-        lhs_id_: IdRef,
-        rhs_id_: IdRef,
+        lhs_id: IdRef,
+        rhs_id: IdRef,
         comptime fop: Opcode,
         comptime sop: Opcode,
         comptime uop: Opcode,
-        /// true if this operation holds under modular arithmetic.
-        comptime modular: bool,
     ) !IdRef {
-        var rhs_id = rhs_id_;
-        var lhs_id = lhs_id_;
-
-        const mod = self.module;
-        const result_ty_ref = try self.resolveType(ty, .direct);
-
-        if (ty.isVector(mod)) {
-            const child_ty = ty.childType(mod);
-            const vector_len = ty.vectorLen(mod);
-            const constituents = try self.gpa.alloc(IdRef, vector_len);
-            defer self.gpa.free(constituents);
-
-            for (constituents, 0..) |*constituent, i| {
-                const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
-                const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
-                constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular);
-            }
-
-            return self.constructArray(ty, constituents);
-        }
-
         // Binary operations are generally applicable to both scalar and vector operations
         // in SPIR-V, but int and float versions of operations require different opcodes.
-        const info = try self.arithmeticTypeInfo(ty);
+        const info = self.arithmeticTypeInfo(ty);
         const opcode_index: usize = switch (info.class) {
             .composite_integer => {
                 return self.todo("binary operations for composite integers", .{});
             },
-            .strange_integer => blk: {
-                if (!modular) {
-                    lhs_id = try self.normalizeInt(result_ty_ref, lhs_id, info);
-                    rhs_id = try self.normalizeInt(result_ty_ref, rhs_id, info);
-                }
-                break :blk switch (info.signedness) {
-                    .signed => @as(usize, 1),
-                    .unsigned => @as(usize, 2),
-                };
-            },
-            .integer => switch (info.signedness) {
+            .integer, .strange_integer => switch (info.signedness) {
                 .signed => @as(usize, 1),
                 .unsigned => @as(usize, 2),
             },
@@ -2454,24 +2550,91 @@ const DeclGen = struct {
             .bool => unreachable,
         };
 
-        const result_id = self.spv.allocId();
-        const operands = .{
-            .id_result_type = self.typeId(result_ty_ref),
-            .id_result = result_id,
-            .operand_1 = lhs_id,
-            .operand_2 = rhs_id,
-        };
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
+            const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
+            const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
+
+            const value_id = self.spv.allocId();
+            const operands = .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = value_id,
+                .operand_1 = lhs_elem_id,
+                .operand_2 = rhs_elem_id,
+            };
 
-        switch (opcode_index) {
-            0 => try self.func.body.emit(self.spv.gpa, fop, operands),
-            1 => try self.func.body.emit(self.spv.gpa, sop, operands),
-            2 => try self.func.body.emit(self.spv.gpa, uop, operands),
-            else => unreachable,
+            switch (opcode_index) {
+                0 => try self.func.body.emit(self.spv.gpa, fop, operands),
+                1 => try self.func.body.emit(self.spv.gpa, sop, operands),
+                2 => try self.func.body.emit(self.spv.gpa, uop, operands),
+                else => unreachable,
+            }
+
+            // TODO: Trap on overflow? Probably going to be annoying.
+            // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
+            result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, info);
         }
 
-        // TODO: Trap on overflow? Probably going to be annoying.
-        // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
-        return result_id;
+        return try wip.finalize();
+    }
+
+    fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const mod = self.module;
+        const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+        const operand_id = try self.resolve(ty_op.operand);
+        // Note: operand_ty may be signed, while ty is always unsigned!
+        const operand_ty = self.typeOf(ty_op.operand);
+        const ty = self.typeOfIndex(inst);
+        const info = self.arithmeticTypeInfo(ty);
+        const operand_scalar_ty = operand_ty.scalarType(mod);
+        const operand_scalar_ty_ref = try self.resolveType(operand_scalar_ty, .direct);
+
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+
+        const zero_id = switch (info.class) {
+            .float => try self.constFloat(operand_scalar_ty_ref, 0),
+            .integer, .strange_integer => try self.constInt(operand_scalar_ty_ref, 0),
+            .composite_integer => unreachable, // TODO
+            .bool => unreachable,
+        };
+        for (wip.results, 0..) |*result_id, i| {
+            const elem_id = try wip.elementAt(operand_ty, operand_id, i);
+            // Idk why spir-v doesn't have a dedicated abs() instruction in the base
+            // instruction set. For now we're just going to negate and check to avoid
+            // importing the extinst.
+            // TODO: Make this a call to compiler rt / ext inst
+            const neg_id = self.spv.allocId();
+            const args = .{
+                .id_result_type = self.typeId(operand_scalar_ty_ref),
+                .id_result = neg_id,
+                .operand_1 = zero_id,
+                .operand_2 = elem_id,
+            };
+            switch (info.class) {
+                .float => try self.func.body.emit(self.spv.gpa, .OpFSub, args),
+                .integer, .strange_integer => try self.func.body.emit(self.spv.gpa, .OpISub, args),
+                .composite_integer => unreachable, // TODO
+                .bool => unreachable,
+            }
+            const neg_norm_id = try self.normalize(wip.scalar_ty_ref, neg_id, info);
+
+            const gt_zero_id = try self.cmp(.gt, Type.bool, operand_scalar_ty, elem_id, zero_id);
+            const abs_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                .id_result_type = self.typeId(operand_scalar_ty_ref),
+                .id_result = abs_id,
+                .condition = gt_zero_id,
+                .object_1 = elem_id,
+                .object_2 = neg_norm_id,
+            });
+            // For Shader, we may need to cast from signed to unsigned here.
+            result_id.* = try self.bitCast(wip.scalar_ty, operand_scalar_ty, abs_id);
+        }
+        return try wip.finalize();
     }
 
     fn airAddSubOverflow(
@@ -2488,140 +2651,344 @@ const DeclGen = struct {
         const lhs = try self.resolve(extra.lhs);
         const rhs = try self.resolve(extra.rhs);
 
-        const operand_ty = self.typeOf(extra.lhs);
         const result_ty = self.typeOfIndex(inst);
+        const operand_ty = self.typeOf(extra.lhs);
+        const ov_ty = result_ty.structFieldType(1, self.module);
+
+        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
 
-        const info = try self.arithmeticTypeInfo(operand_ty);
+        const info = self.arithmeticTypeInfo(operand_ty);
         switch (info.class) {
             .composite_integer => return self.todo("overflow ops for composite integers", .{}),
-            .strange_integer => return self.todo("overflow ops for strange integers", .{}),
-            .integer => {},
+            .strange_integer, .integer => {},
             .float, .bool => unreachable,
         }
 
-        // The operand type must be the same as the result type in SPIR-V, which
-        // is the same as in Zig.
-        const operand_ty_ref = try self.resolveType(operand_ty, .direct);
-        const operand_ty_id = self.typeId(operand_ty_ref);
-
-        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
-        const ov_ty = result_ty.structFieldType(1, self.module);
-        // Note: result is stored in a struct, so indirect representation.
-        const ov_ty_ref = try self.resolveType(ov_ty, .indirect);
-
-        // TODO: Operations other than addition.
-        const value_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, add, .{
-            .id_result_type = operand_ty_id,
-            .id_result = value_id,
-            .operand_1 = lhs,
-            .operand_2 = rhs,
-        });
-
-        const overflowed_id = switch (info.signedness) {
-            .unsigned => blk: {
-                // Overflow happened if the result is smaller than either of the operands. It doesn't matter which.
-                // For subtraction the conditions need to be swapped.
-                const overflowed_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, ucmp, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = overflowed_id,
-                    .operand_1 = value_id,
-                    .operand_2 = lhs,
-                });
-                break :blk overflowed_id;
-            },
-            .signed => blk: {
-                // lhs - rhs
-                // For addition, overflow happened if:
-                // - rhs is negative and value > lhs
-                // - rhs is positive and value < lhs
-                // This can be shortened to:
-                // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs)
-                // = (rhs < 0) == (value > lhs)
-                // = (rhs < 0) == (lhs < value)
-                // Note that signed overflow is also wrapping in spir-v.
-                // For subtraction, overflow happened if:
-                // - rhs is negative and value < lhs
-                // - rhs is positive and value > lhs
-                // This can be shortened to:
-                // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs)
-                // = (rhs < 0) == (value < lhs)
-                // = (rhs < 0) == (lhs > value)
-
-                const rhs_lt_zero_id = self.spv.allocId();
-                const zero_id = try self.constInt(operand_ty_ref, 0);
-                try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = rhs_lt_zero_id,
-                    .operand_1 = rhs,
-                    .operand_2 = zero_id,
-                });
-
-                const value_gt_lhs_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, scmp, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = value_gt_lhs_id,
-                    .operand_1 = lhs,
-                    .operand_2 = value_id,
-                });
-
-                const overflowed_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = overflowed_id,
-                    .operand_1 = rhs_lt_zero_id,
-                    .operand_2 = value_gt_lhs_id,
-                });
-                break :blk overflowed_id;
-            },
-        };
+        var wip_result = try self.elementWise(operand_ty);
+        defer wip_result.deinit();
+        var wip_ov = try self.elementWise(ov_ty);
+        defer wip_ov.deinit();
+        for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
+            const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+            const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i);
+
+            // Normalize both so that we can properly check for overflow
+            const value_id = self.spv.allocId();
+
+            try self.func.body.emit(self.spv.gpa, add, .{
+                .id_result_type = wip_result.scalar_ty_id,
+                .id_result = value_id,
+                .operand_1 = lhs_elem_id,
+                .operand_2 = rhs_elem_id,
+            });
+
+            // Normalize the result so that the comparisons go well
+            result_id.* = try self.normalize(wip_result.scalar_ty_ref, value_id, info);
+
+            const overflowed_id = switch (info.signedness) {
+                .unsigned => blk: {
+                    // Overflow happened if the result is smaller than either of the operands. It doesn't matter which.
+                    // For subtraction the conditions need to be swapped.
+                    const overflowed_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, ucmp, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = overflowed_id,
+                        .operand_1 = result_id.*,
+                        .operand_2 = lhs_elem_id,
+                    });
+                    break :blk overflowed_id;
+                },
+                .signed => blk: {
+                    // lhs - rhs
+                    // For addition, overflow happened if:
+                    // - rhs is negative and value > lhs
+                    // - rhs is positive and value < lhs
+                    // This can be shortened to:
+                    // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs)
+                    // = (rhs < 0) == (value > lhs)
+                    // = (rhs < 0) == (lhs < value)
+                    // Note that signed overflow is also wrapping in spir-v.
+                    // For subtraction, overflow happened if:
+                    // - rhs is negative and value < lhs
+                    // - rhs is positive and value > lhs
+                    // This can be shortened to:
+                    // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs)
+                    // = (rhs < 0) == (value < lhs)
+                    // = (rhs < 0) == (lhs > value)
+
+                    const rhs_lt_zero_id = self.spv.allocId();
+                    const zero_id = try self.constInt(wip_result.scalar_ty_ref, 0);
+                    try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = rhs_lt_zero_id,
+                        .operand_1 = rhs_elem_id,
+                        .operand_2 = zero_id,
+                    });
+
+                    const value_gt_lhs_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, scmp, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = value_gt_lhs_id,
+                        .operand_1 = lhs_elem_id,
+                        .operand_2 = result_id.*,
+                    });
+
+                    const overflowed_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = overflowed_id,
+                        .operand_1 = rhs_lt_zero_id,
+                        .operand_2 = value_gt_lhs_id,
+                    });
+                    break :blk overflowed_id;
+                },
+            };
+
+            ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+        }
 
-        // Construct the struct that Zig wants as result.
-        // The value should already be the correct type.
-        const ov_id = try self.intFromBool(ov_ty_ref, overflowed_id);
         return try self.constructStruct(
             result_ty,
             &.{ operand_ty, ov_ty },
-            &.{ value_id, ov_id },
+            &.{ try wip_result.finalize(), try wip_ov.finalize() },
         );
     }
 
+    fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const mod = self.module;
+        const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
+        const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+        const lhs = try self.resolve(extra.lhs);
+        const rhs = try self.resolve(extra.rhs);
+
+        const result_ty = self.typeOfIndex(inst);
+        const operand_ty = self.typeOf(extra.lhs);
+        const shift_ty = self.typeOf(extra.rhs);
+        const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
+        const ov_ty = result_ty.structFieldType(1, self.module);
+
+        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+
+        const info = self.arithmeticTypeInfo(operand_ty);
+        switch (info.class) {
+            .composite_integer => return self.todo("overflow shift for composite integers", .{}),
+            .integer, .strange_integer => {},
+            .float, .bool => unreachable,
+        }
+
+        var wip_result = try self.elementWise(operand_ty);
+        defer wip_result.deinit();
+        var wip_ov = try self.elementWise(ov_ty);
+        defer wip_ov.deinit();
+        for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
+            const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+            const rhs_elem_id = try wip_result.elementAt(shift_ty, rhs, i);
+
+            // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
+            // so just manually upcast it if required.
+            const shift_id = if (scalar_shift_ty_ref != wip_result.scalar_ty_ref) blk: {
+                const shift_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                    .id_result_type = wip_result.scalar_ty_id,
+                    .id_result = shift_id,
+                    .unsigned_value = rhs_elem_id,
+                });
+                break :blk shift_id;
+            } else rhs_elem_id;
+
+            const value_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
+                .id_result_type = wip_result.scalar_ty_id,
+                .id_result = value_id,
+                .base = lhs_elem_id,
+                .shift = shift_id,
+            });
+            result_id.* = try self.normalize(wip_result.scalar_ty_ref, value_id, info);
+
+            const right_shift_id = self.spv.allocId();
+            switch (info.signedness) {
+                .signed => {
+                    try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
+                        .id_result_type = wip_result.scalar_ty_id,
+                        .id_result = right_shift_id,
+                        .base = result_id.*,
+                        .shift = shift_id,
+                    });
+                },
+                .unsigned => {
+                    try self.func.body.emit(self.spv.gpa, .OpShiftRightLogical, .{
+                        .id_result_type = wip_result.scalar_ty_id,
+                        .id_result = right_shift_id,
+                        .base = result_id.*,
+                        .shift = shift_id,
+                    });
+                },
+            }
+
+            const overflowed_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+                .id_result_type = self.typeId(bool_ty_ref),
+                .id_result = overflowed_id,
+                .operand_1 = lhs_elem_id,
+                .operand_2 = right_shift_id,
+            });
+
+            ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+        }
+
+        return try self.constructStruct(
+            result_ty,
+            &.{ operand_ty, ov_ty },
+            &.{ try wip_result.finalize(), try wip_ov.finalize() },
+        );
+    }
+
+    fn airMulAdd(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
+        const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+
+        const mulend1 = try self.resolve(extra.lhs);
+        const mulend2 = try self.resolve(extra.rhs);
+        const addend = try self.resolve(pl_op.operand);
+
+        const ty = self.typeOfIndex(inst);
+
+        const info = self.arithmeticTypeInfo(ty);
+        assert(info.class == .float); // .mul_add is only emitted for floats
+
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+        for (0..wip.results.len) |i| {
+            const mul_result = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpFMul, .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = mul_result,
+                .operand_1 = try wip.elementAt(ty, mulend1, i),
+                .operand_2 = try wip.elementAt(ty, mulend2, i),
+            });
+
+            try self.func.body.emit(self.spv.gpa, .OpFAdd, .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = wip.allocId(i),
+                .operand_1 = mul_result,
+                .operand_2 = try wip.elementAt(ty, addend, i),
+            });
+        }
+        return try wip.finalize();
+    }
+
+    fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+        const operand_id = try self.resolve(ty_op.operand);
+        const result_ty = self.typeOfIndex(inst);
+
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
+        for (wip.results) |*result_id| {
+            result_id.* = operand_id;
+        }
+        return try wip.finalize();
+    }
+
+    fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const mod = self.module;
+        const reduce = self.air.instructions.items(.data)[@intFromEnum(inst)].reduce;
+        const operand = try self.resolve(reduce.operand);
+        const operand_ty = self.typeOf(reduce.operand);
+        const scalar_ty = operand_ty.scalarType(mod);
+        const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
+        const scalar_ty_id = self.typeId(scalar_ty_ref);
+
+        const info = self.arithmeticTypeInfo(operand_ty);
+
+        var result_id = try self.extractField(scalar_ty, operand, 0);
+        const len = operand_ty.vectorLen(mod);
+
+        switch (reduce.operation) {
+            .Min, .Max => |op| {
+                const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt;
+                for (1..len) |i| {
+                    const lhs = result_id;
+                    const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+                    result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs);
+                }
+
+                return result_id;
+            },
+            else => {},
+        }
+
+        const opcode: Opcode = switch (info.class) {
+            .bool => switch (reduce.operation) {
+                .And => .OpLogicalAnd,
+                .Or => .OpLogicalOr,
+                .Xor => .OpLogicalNotEqual,
+                else => unreachable,
+            },
+            .strange_integer, .integer => switch (reduce.operation) {
+                .And => .OpBitwiseAnd,
+                .Or => .OpBitwiseOr,
+                .Xor => .OpBitwiseXor,
+                .Add => .OpIAdd,
+                .Mul => .OpIMul,
+                else => unreachable,
+            },
+            .float => switch (reduce.operation) {
+                .Add => .OpFAdd,
+                .Mul => .OpFMul,
+                else => unreachable,
+            },
+            .composite_integer => unreachable, // TODO
+        };
+
+        for (1..len) |i| {
+            const lhs = result_id;
+            const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+            result_id = self.spv.allocId();
+
+            try self.func.body.emitRaw(self.spv.gpa, opcode, 4);
+            self.func.body.writeOperand(spec.IdResultType, scalar_ty_id);
+            self.func.body.writeOperand(spec.IdResult, result_id);
+            self.func.body.writeOperand(spec.IdResultType, lhs);
+            self.func.body.writeOperand(spec.IdResultType, rhs);
+        }
+
+        return result_id;
+    }
+
     fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         const mod = self.module;
         if (self.liveness.isUnused(inst)) return null;
-        const ty = self.typeOfIndex(inst);
         const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
         const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data;
         const a = try self.resolve(extra.a);
         const b = try self.resolve(extra.b);
         const mask = Value.fromInterned(extra.mask);
-        const mask_len = extra.mask_len;
-        const a_len = self.typeOf(extra.a).vectorLen(mod);
 
-        const result_id = self.spv.allocId();
-        const result_type_id = try self.resolveTypeId(ty);
+        const ty = self.typeOfIndex(inst);
 
-        // Similar to LLVM, SPIR-V uses indices larger than the length of the first vector
-        // to index into the second vector.
-        try self.func.body.emitRaw(self.spv.gpa, .OpVectorShuffle, 4 + mask_len);
-        self.func.body.writeOperand(spec.IdResultType, result_type_id);
-        self.func.body.writeOperand(spec.IdResult, result_id);
-        self.func.body.writeOperand(spec.IdRef, a);
-        self.func.body.writeOperand(spec.IdRef, b);
-
-        var i: usize = 0;
-        while (i < mask_len) : (i += 1) {
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
             const elem = try mask.elemValue(mod, i);
             if (elem.isUndef(mod)) {
-                self.func.body.writeOperand(spec.LiteralInteger, 0xFFFF_FFFF);
+                result_id.* = try self.spv.constUndef(wip.scalar_ty_ref);
+                continue;
+            }
+
+            const index = elem.toSignedInt(mod);
+            if (index >= 0) {
+                result_id.* = try self.extractField(wip.scalar_ty, a, @intCast(index));
             } else {
-                const int = elem.toSignedInt(mod);
-                const unsigned = if (int >= 0) @as(u32, @intCast(int)) else @as(u32, @intCast(~int + a_len));
-                self.func.body.writeOperand(spec.LiteralInteger, unsigned);
+                result_id.* = try self.extractField(wip.scalar_ty, b, @intCast(~index));
             }
         }
-        return result_id;
+        return try wip.finalize();
     }
 
     fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {
@@ -2828,26 +3195,21 @@ const DeclGen = struct {
                 return result_id;
             },
             .Vector => {
-                const child_ty = ty.childType(mod);
-                const vector_len = ty.vectorLen(mod);
-
-                const constituents = try self.gpa.alloc(IdRef, vector_len);
-                defer self.gpa.free(constituents);
-
-                for (constituents, 0..) |*constituent, i| {
-                    const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i));
-                    const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i));
-                    const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id);
-                    constituent.* = try self.convertToIndirect(Type.bool, result_id);
+                var wip = try self.elementWise(result_ty);
+                defer wip.deinit();
+                const scalar_ty = ty.scalarType(mod);
+                for (wip.results, 0..) |*result_id, i| {
+                    const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
+                    const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
+                    result_id.* = try self.cmp(op, Type.bool, scalar_ty, lhs_elem_id, rhs_elem_id);
                 }
-
-                return try self.constructArray(result_ty, constituents);
+                return wip.finalize();
             },
             else => unreachable,
         };
 
         const opcode: Opcode = opcode: {
-            const info = try self.arithmeticTypeInfo(op_ty);
+            const info = self.arithmeticTypeInfo(op_ty);
             const signedness = switch (info.class) {
                 .composite_integer => {
                     return self.todo("binary operations for composite integers", .{});
                 },
                 .bool => break :opcode switch (op) {
                     .eq => .OpLogicalEqual,
                     .neq => .OpLogicalNotEqual,
                     else => unreachable,
                 },
-                .strange_integer => sign: {
-                    const op_ty_ref = try self.resolveType(op_ty, .direct);
-                    // Mask operands before performing comparison.
-                    cmp_lhs_id = try self.normalizeInt(op_ty_ref, cmp_lhs_id, info);
-                    cmp_rhs_id = try self.normalizeInt(op_ty_ref, cmp_rhs_id, info);
-                    break :sign info.signedness;
-                },
-                .integer => info.signedness,
+                .integer, .strange_integer => info.signedness,
             };
 
             break :opcode switch (signedness) {
@@ -2942,50 +3297,64 @@ const DeclGen = struct {
         const mod = self.module;
         const src_ty_ref = try self.resolveType(src_ty, .direct);
         const dst_ty_ref = try self.resolveType(dst_ty, .direct);
-        if (src_ty_ref == dst_ty_ref) {
-            return src_id;
-        }
+        const src_key = self.spv.cache.lookup(src_ty_ref);
+        const dst_key = self.spv.cache.lookup(dst_ty_ref);
 
-        // TODO: Some more cases are missing here
-        // See fn bitCast in llvm.zig
+        const result_id = blk: {
+            if (src_ty_ref == dst_ty_ref) {
+                break :blk src_id;
+            }
 
-        if (src_ty.zigTypeTag(mod) == .Int and dst_ty.isPtrAtRuntime(mod)) {
-            const result_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpConvertUToPtr, .{
-                .id_result_type = self.typeId(dst_ty_ref),
-                .id_result = result_id,
-                .integer_value = src_id,
-            });
-            return result_id;
-        }
+            // TODO: Some more cases are missing here
+            // See fn bitCast in llvm.zig
 
-        // We can only use OpBitcast for specific conversions: between numerical types, and
-        // between pointers. If the resolved spir-v types fall into this category then emit OpBitcast,
-        // otherwise use a temporary and perform a pointer cast.
-        const src_key = self.spv.cache.lookup(src_ty_ref);
-        const dst_key = self.spv.cache.lookup(dst_ty_ref);
+            if (src_ty.zigTypeTag(mod) == .Int and dst_ty.isPtrAtRuntime(mod)) {
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpConvertUToPtr, .{
+                    .id_result_type = self.typeId(dst_ty_ref),
+                    .id_result = result_id,
+                    .integer_value = src_id,
+                });
+                break :blk result_id;
+            }
 
-        if ((src_key.isNumericalType() and dst_key.isNumericalType()) or (src_key == .ptr_type and dst_key == .ptr_type)) {
-            const result_id = self.spv.allocId();
+            // We can only use OpBitcast for specific conversions: between numerical types, and
+            // between pointers. If the resolved spir-v types fall into this category then emit OpBitcast,
+            // otherwise use a temporary and perform a pointer cast.
+            if ((src_key.isNumericalType() and dst_key.isNumericalType()) or (src_key == .ptr_type and dst_key == .ptr_type)) {
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+                    .id_result_type = self.typeId(dst_ty_ref),
+                    .id_result = result_id,
+                    .operand = src_id,
+                });
+
+                break :blk result_id;
+            }
+
+            const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
+
+            const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
+            try self.store(src_ty, tmp_id, src_id, .{});
+            const casted_ptr_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
-                .id_result_type = self.typeId(dst_ty_ref),
-                .id_result = result_id,
-                .operand = src_id,
+                .id_result_type = self.typeId(dst_ptr_ty_ref),
+                .id_result = casted_ptr_id,
+                .operand = tmp_id,
             });
-            return result_id;
-        }
+            break :blk try self.load(dst_ty, casted_ptr_id, .{});
+        };
 
-        const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
-
-        const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
-        try self.store(src_ty, tmp_id, src_id, .{});
-        const casted_ptr_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
-            .id_result_type = self.typeId(dst_ptr_ty_ref),
-            .id_result = casted_ptr_id,
-            .operand = tmp_id,
-        });
-        return try self.load(dst_ty, casted_ptr_id, .{});
+        // Because strange integers use sign-extended representation, we may need to normalize
+        // the result here.
+        // TODO: This detail could cause stuff like @as(*const i1, @ptrCast(&@as(u1, 1))) to break
+        // should we change the representation of strange integers?
+        if (dst_ty.zigTypeTag(mod) == .Int) {
+            const info = self.arithmeticTypeInfo(dst_ty);
+            return try self.normalize(dst_ty_ref, result_id, info);
+        }
+
+        return result_id;
     }
 
     fn airBitCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3004,34 +3373,43 @@ const DeclGen = struct {
         const operand_id = try self.resolve(ty_op.operand);
         const src_ty = self.typeOf(ty_op.operand);
         const dst_ty = self.typeOfIndex(inst);
-        const src_ty_ref = try self.resolveType(src_ty, .direct);
-        const dst_ty_ref = try self.resolveType(dst_ty, .direct);
-
-        const src_info = try self.arithmeticTypeInfo(src_ty);
-        const dst_info = try self.arithmeticTypeInfo(dst_ty);
 
-        // While intcast promises that the value already fits, the upper bits of a
-        // strange integer may contain garbage. Therefore, mask/sign extend it before.
-        const src_id = try self.normalizeInt(src_ty_ref, operand_id, src_info);
+        const src_info = self.arithmeticTypeInfo(src_ty);
+        const dst_info = self.arithmeticTypeInfo(dst_ty);
 
         if (src_info.backing_bits == dst_info.backing_bits) {
-            return src_id;
+            return operand_id;
         }
 
-        const result_id = self.spv.allocId();
-        switch (dst_info.signedness) {
-            .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{
-                .id_result_type = self.typeId(dst_ty_ref),
-                .id_result = result_id,
-                .signed_value = src_id,
-            }),
-            .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
-                .id_result_type = self.typeId(dst_ty_ref),
-                .id_result = result_id,
-                .unsigned_value = src_id,
-            }),
+        var wip = try self.elementWise(dst_ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
+            const elem_id = try wip.elementAt(src_ty, operand_id, i);
+            const value_id = self.spv.allocId();
+            switch (dst_info.signedness) {
+                .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{
+                    .id_result_type = wip.scalar_ty_id,
+                    .id_result = value_id,
+                    .signed_value = elem_id,
+                }),
+                .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                    .id_result_type = wip.scalar_ty_id,
+                    .id_result = value_id,
+                    .unsigned_value = elem_id,
+                }),
+            }
+
+            // Make sure to normalize the result if shrinking.
+            // Because strange ints are sign extended in their backing
+            // type, we don't need to normalize when growing the type. The
+            // representation is already the same.
+            if (dst_info.bits < src_info.bits) {
+                result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, dst_info);
+            } else {
+                result_id.* = value_id;
+            }
         }
 
-        return result_id;
+        return try wip.finalize();
     }
 
     fn intFromPtr(self: *DeclGen, operand_id: IdRef) !IdRef {
@@ -3059,7 +3437,7 @@ const DeclGen = struct {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_ty = self.typeOf(ty_op.operand);
         const operand_id = try self.resolve(ty_op.operand);
-        const operand_info = try self.arithmeticTypeInfo(operand_ty);
+        const operand_info = self.arithmeticTypeInfo(operand_ty);
         const dest_ty = self.typeOfIndex(inst);
         const dest_ty_id = try self.resolveTypeId(dest_ty);
@@ -3085,7 +3463,7 @@ const DeclGen = struct {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_id = try self.resolve(ty_op.operand);
         const dest_ty = self.typeOfIndex(inst);
-        const dest_info = try self.arithmeticTypeInfo(dest_ty);
+        const dest_info = self.arithmeticTypeInfo(dest_ty);
         const dest_ty_id = try self.resolveTypeId(dest_ty);
 
         const result_id = self.spv.allocId();
@@ -3104,6 +3482,22 @@ const DeclGen = struct {
         return result_id;
     }
 
+    fn airIntFromBool(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op;
+        const operand_id = try self.resolve(un_op);
+        const result_ty = self.typeOfIndex(inst);
+
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
+            const elem_id = try wip.elementAt(Type.bool, operand_id, i);
+            result_id.* = try self.intFromBool(wip.scalar_ty_ref, elem_id);
+        }
+        return try wip.finalize();
+    }
+
     fn airFloatCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
@@ -3126,31 +3520,31 @@ const DeclGen = struct {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_id = try self.resolve(ty_op.operand);
         const result_ty = self.typeOfIndex(inst);
-        const result_ty_id = try self.resolveTypeId(result_ty);
-        const info = try self.arithmeticTypeInfo(result_ty);
+        const info = self.arithmeticTypeInfo(result_ty);
 
-        const result_id = self.spv.allocId();
-        switch (info.class) {
-            .bool => {
-                try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{
-                    .id_result_type = result_ty_id,
-                    .id_result = result_id,
-                    .operand = operand_id,
-                });
-            },
-            .float => unreachable,
-            .composite_integer => unreachable, // TODO
-            .strange_integer, .integer => {
-                // Note: strange integer bits will be masked before operations that do not hold under modulo.
-                try self.func.body.emit(self.spv.gpa, .OpNot, .{
-                    .id_result_type = result_ty_id,
-                    .id_result = result_id,
-                    .operand = operand_id,
-                });
-            },
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
+
+        for (0..wip.results.len) |i| {
+            const args = .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = wip.allocId(i),
+                .operand = try wip.elementAt(result_ty, operand_id, i),
+            };
+            switch (info.class) {
+                .bool => {
+                    try self.func.body.emit(self.spv.gpa, .OpLogicalNot, args);
+                },
+                .float => unreachable,
+                .composite_integer => unreachable, // TODO
+                .strange_integer, .integer => {
+                    // Note: strange integer bits will be masked before operations that do not hold under modulo.
+                    try self.func.body.emit(self.spv.gpa, .OpNot, args);
+                },
+            }
         }
 
-        return result_id;
+        return try wip.finalize();
     }
 
     fn airArrayToSlice(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3213,7 +3607,6 @@ const DeclGen = struct {
         const elements: []const Air.Inst.Ref = @ptrCast(self.air.extra[ty_pl.payload..][0..len]);
 
         switch (result_ty.zigTypeTag(mod)) {
-            .Vector => unreachable, // TODO
             .Struct => {
                 if (mod.typeToPackedStruct(result_ty)) |struct_type| {
                     _ = struct_type;
@@ -3261,7 +3654,7 @@ const DeclGen = struct {
                     constituents[0..index],
                 );
             },
-            .Array => {
+            .Vector, .Array => {
                 const array_info = result_ty.arrayInfo(mod);
                 const n_elems: usize = @intCast(result_ty.arrayLenIncludingSentinel(mod));
                 const elem_ids = try self.gpa.alloc(IdRef, n_elems);
@@ -3433,6 +3826,28 @@ const DeclGen = struct {
         return try self.load(elem_ty, elem_ptr_id, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
     }
 
+    fn airVectorStoreElem(self: *DeclGen, inst: Air.Inst.Index) !void {
+        const mod = self.module;
+        const data = self.air.instructions.items(.data)[@intFromEnum(inst)].vector_store_elem;
+        const extra = self.air.extraData(Air.Bin, data.payload).data;
+
+        const vector_ptr_ty = self.typeOf(data.vector_ptr);
+        const vector_ty = vector_ptr_ty.childType(mod);
+        const scalar_ty = vector_ty.scalarType(mod);
+
+        const storage_class = spvStorageClass(vector_ptr_ty.ptrAddressSpace(mod));
+        const scalar_ptr_ty_ref = try self.ptrType(scalar_ty, storage_class);
+
+        const vector_ptr = try self.resolve(data.vector_ptr);
+        const index = try self.resolve(extra.lhs);
+        const operand = try self.resolve(extra.rhs);
+
+        const elem_ptr_id = try self.accessChainId(scalar_ptr_ty_ref, vector_ptr, &.{index});
+        try self.store(scalar_ty, elem_ptr_id, operand, .{
+            .is_volatile = vector_ptr_ty.isVolatilePtr(mod),
+        });
+    }
+
     fn airSetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !void {
         const mod = self.module;
         const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
@@ -4424,20 +4839,24 @@ const DeclGen = struct {
         return try self.constructStruct(err_union_ty, &types, &members);
     }
 
-    fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, pred: enum { is_null, is_non_null }) !?IdRef {
+    fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, is_pointer: bool, pred: enum { is_null, is_non_null }) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
 
         const mod = self.module;
         const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op;
         const operand_id = try self.resolve(un_op);
-        const optional_ty = self.typeOf(un_op);
-
+        const operand_ty = self.typeOf(un_op);
+        const optional_ty = if (is_pointer) operand_ty.childType(mod) else operand_ty;
         const payload_ty = optional_ty.optionalChild(mod);
 
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
 
         if (optional_ty.optionalReprIsPayload(mod)) {
             // Pointer payload represents nullability: pointer or slice.
+            const loaded_id = if (is_pointer)
+                try self.load(optional_ty, operand_id, .{})
+            else
+                operand_id;
 
             const ptr_ty = if (payload_ty.isSlice(mod))
                 payload_ty.slicePtrFieldType(mod)
             else
                 payload_ty;
 
             const ptr_id = if (payload_ty.isSlice(mod))
-                try self.extractField(ptr_ty, operand_id, 0)
+                try self.extractField(ptr_ty, loaded_id, 0)
             else
-                operand_id;
+                loaded_id;
 
             const payload_ty_ref = try self.resolveType(ptr_ty, .direct);
             const null_id = try self.spv.constNull(payload_ty_ref);
@@ -4458,13 +4877,26 @@ const DeclGen = struct {
             return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id);
         }
 
-        const is_non_null_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
-            try self.extractField(Type.bool, operand_id, 1)
-        else
-            // Optional representation is bool indicating whether the optional is set
-            // Optionals with no payload are represented as an (indirect) bool, so convert
-            // it back to the direct bool here.
-            try self.convertToDirect(Type.bool, operand_id);
+        const is_non_null_id = blk: {
+            if (is_pointer) {
+                if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+                    const storage_class = spvStorageClass(operand_ty.ptrAddressSpace(mod));
+                    const bool_ptr_ty = try self.ptrType(Type.bool, storage_class);
+                    const tag_ptr_id = try self.accessChain(bool_ptr_ty, operand_id, &.{1});
+                    break :blk try self.load(Type.bool, tag_ptr_id, .{});
+                }
+
+                break :blk try self.load(Type.bool, operand_id, .{});
+            }
+
+            break :blk if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
+                try self.extractField(Type.bool, operand_id, 1)
+            else
+                // Optional representation is bool indicating whether the optional is set
+                // Optionals with no payload are represented as an (indirect) bool, so convert
+                // it back to the direct bool here.
+                try self.convertToDirect(Type.bool, operand_id);
+        };
 
         return switch (pred) {
             .is_null => blk: {
@@ -4535,6 +4967,32 @@ const DeclGen = struct {
         return try self.extractField(payload_ty, operand_id, 0);
     }
 
+    fn airUnwrapOptionalPtr(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const mod = self.module;
+        const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+        const operand_id = try self.resolve(ty_op.operand);
+        const operand_ty = self.typeOf(ty_op.operand);
+        const optional_ty = operand_ty.childType(mod);
+        const payload_ty = optional_ty.optionalChild(mod);
+        const result_ty = self.typeOfIndex(inst);
+        const result_ty_ref = try self.resolveType(result_ty, .direct);
+
+        if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+            // There is no payload, but we still need to return a valid pointer.
+            // We can just return anything here, so just return a pointer to the operand.
+            return try self.bitCast(result_ty, operand_ty, operand_id);
+        }
+
+        if (optional_ty.optionalReprIsPayload(mod)) {
+            // They are the same value.
+            return try self.bitCast(result_ty, operand_ty, operand_id);
+        }
+
+        return try self.accessChain(result_ty_ref, operand_id, &.{0});
+    }
+
     fn airWrapOptional(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
```
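The rewritten helpers above follow a handful of small, self-contained patterns. The sketches below restate them in ordinary Zig rather than SPIR-V emission code; all names in them are illustrative, not the backend's API. First, the core `WipElementWise` idea: an operation is lowered one lane at a time and the scalar results are reassembled, so scalars and length-n vectors share a single code path. Here `scalarOp` stands in for the one SPIR-V instruction emitted per lane:

```zig
const std = @import("std");

// Lower a binary op lane by lane, like WipElementWise: one result slot per
// lane, one scalar operation per lane, then rebuild the vector (the
// backend's finalize()/constructArray step).
fn elementWise(
    comptime n: comptime_int,
    comptime T: type,
    comptime scalarOp: fn (T, T) T,
    lhs: @Vector(n, T),
    rhs: @Vector(n, T),
) @Vector(n, T) {
    var results: [n]T = undefined; // analogous to wip.results
    for (&results, 0..) |*r, i| {
        r.* = scalarOp(lhs[i], rhs[i]); // one "instruction" per lane
    }
    return results; // arrays coerce to vectors of the same length
}

fn wrappingAdd(a: u32, b: u32) u32 {
    return a +% b;
}

test "element-wise lowering" {
    const a: @Vector(4, u32) = .{ 1, 2, 3, 4 };
    const b: @Vector(4, u32) = .{ 10, 20, 30, 40 };
    const out = elementWise(4, u32, wrappingAdd, a, b);
    try std.testing.expect(@reduce(.And, out == @Vector(4, u32){ 11, 22, 33, 44 }));
}
```

A scalar input simply behaves like a one-lane vector, which is why `elementAt` asserts `index == 0` in the non-vector case.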
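Second, `normalizeInt` becomes `normalize`, and instead of masking operands before each non-modular operation it canonicalizes results after wrapping arithmetic. For a "strange" integer such as u7/i7 living in a wider backing word, that means a bitwise mask (unsigned) or a shift-left/arithmetic-shift-right pair that copies the sign bit (signed). A sketch of both on a 64-bit backing word, with hypothetical helper names:

```zig
const std = @import("std");

// Canonicalize a strange unsigned integer of `bits` width stored in a u64
// backing word, mirroring the emitted OpBitwiseAnd.
fn normalizeUnsigned(value: u64, bits: u7) u64 {
    const mask = if (bits == 64)
        @as(u64, 0xFFFF_FFFF_FFFF_FFFF)
    else
        (@as(u64, 1) << @as(u6, @intCast(bits))) - 1;
    return value & mask;
}

// Canonicalize a strange signed integer: shift the value's sign bit up to
// bit 63, then arithmetic-shift back down so it is copied into the upper
// bits (the OpShiftLeftLogical/OpShiftRightArithmetic pair).
fn normalizeSigned(value: i64, bits: u7) i64 {
    const shift: u6 = @intCast(64 - @as(u8, bits));
    return (value << shift) >> shift;
}

test "strange integer normalization" {
    // u7: 127 + 1 = 128 in the backing word; masking wraps it to 0.
    try std.testing.expectEqual(@as(u64, 0), normalizeUnsigned(127 + 1, 7));
    // i7: the bit pattern 0b100_0000 is -64 once sign-extended.
    try std.testing.expectEqual(@as(i64, -64), normalizeSigned(0b100_0000, 7));
}
```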
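`airMinMax` keeps its compare-plus-`OpSelect` shape, now per lane through the shared `minMax` helper. Because the ordered (OpFOrd) comparison is false whenever either side is NaN, a raw select would pick `rhs` when `rhs` is NaN; OR-ing in an `OpIsNan` check on `rhs` keeps `lhs` instead, and the result is unspecified only when both are NaN. The same decision in scalar Zig (a sketch, not the backend code):

```zig
const std = @import("std");

// Select-based min/max with the NaN fixup from minMax(): an ordered
// comparison returns false for NaN, so explicitly prefer lhs when rhs
// is NaN.
fn minMaxSelect(op: std.math.CompareOperator, lhs: f32, rhs: f32) f32 {
    const cmp = switch (op) {
        .lt => lhs < rhs, // min
        .gt => lhs > rhs, // max
        else => unreachable,
    };
    return if (cmp or std.math.isNan(rhs)) lhs else rhs;
}

test "min/max picks lhs when rhs is NaN" {
    const nan = std.math.nan(f32);
    try std.testing.expectEqual(@as(f32, 1.0), minMaxSelect(.lt, 1.0, nan));
    try std.testing.expectEqual(@as(f32, 0.5), minMaxSelect(.lt, 2.0, 0.5));
    try std.testing.expectEqual(@as(f32, 2.0), minMaxSelect(.gt, 2.0, 0.5));
}
```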
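The new `airAbs` deliberately avoids the extended instruction sets: it negates via `0 - x` (OpFSub/OpISub), re-normalizes the negation for strange integers, and selects on `x > 0`. The scalar equivalent (sketch, wrapping subtraction standing in for the normalize step):

```zig
const std = @import("std");

// abs() via negate-and-select, as in airAbs: compute 0 -% x, then pick
// between x and the negation based on x > 0.
fn absBySelect(x: i32) i32 {
    const neg = 0 -% x;
    return if (x > 0) x else neg; // OpSelect on the comparison result
}

test "abs via select" {
    try std.testing.expectEqual(@as(i32, 5), absBySelect(-5));
    try std.testing.expectEqual(@as(i32, 5), absBySelect(5));
    try std.testing.expectEqual(@as(i32, 0), absBySelect(0));
}
```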
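`airAddSubOverflow` now also covers strange integers: the wrapped result is normalized and then compared against an operand, since SPIR-V exposes no carry or overflow flag here. Unsigned overflow of an addition is simply `result < lhs` (swapped to `>` for subtraction); the signed case reduces, as the comment block in the diff derives, to `(rhs < 0) == (lhs < result)`. The same checks on plain 64-bit integers (sketch):

```zig
const std = @import("std");

// Unsigned add overflow: the wrapped sum is smaller than an operand
// (OpIAdd + OpULessThan). Subtraction uses the swapped comparison.
fn addOverflowsUnsigned(lhs: u64, rhs: u64) bool {
    return lhs +% rhs < lhs;
}

// Signed add overflow, condensed exactly as in the diff's derivation:
// (rhs < 0) == (lhs < result), with the wrapping result from OpIAdd.
fn addOverflowsSigned(lhs: i64, rhs: i64) bool {
    const result = lhs +% rhs;
    return (rhs < 0) == (lhs < result);
}

test "overflow by comparing the wrapped result" {
    try std.testing.expect(addOverflowsUnsigned(std.math.maxInt(u64), 1));
    try std.testing.expect(!addOverflowsUnsigned(1, 2));
    try std.testing.expect(addOverflowsSigned(std.math.maxInt(i64), 1));
    try std.testing.expect(addOverflowsSigned(std.math.minInt(i64), -1));
    try std.testing.expect(!addOverflowsSigned(-1, 1));
}
```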
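The new `airShlOverflow` detects overflow round-trip style: shift left, shift back down (logical for unsigned, arithmetic for signed), and flag overflow with `OpINotEqual` when the original value is not restored. A sketch for the unsigned case:

```zig
const std = @import("std");

// Overflowing left-shift detection: bits pushed out of the word make the
// round trip lossy. Signed values would shift back with an arithmetic
// shift instead.
fn shlOverflows(lhs: u64, shift: u6) bool {
    const shifted = lhs << shift;
    return (shifted >> shift) != lhs; // OpINotEqual against the original
}

test "shl overflow round trip" {
    try std.testing.expect(shlOverflows(0x8000_0000_0000_0000, 1));
    try std.testing.expect(!shlOverflows(1, 62));
}
```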
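`airReduce` lowers `@reduce` as a simple linear chain: lane 0 seeds the accumulator and each remaining lane costs one binary instruction (min/max go through the `minMax` helper instead of a single opcode). The equivalent fold in Zig (sketch):

```zig
const std = @import("std");

// Linear-chain reduction as emitted by airReduce: seed with lane 0, then
// fold one operation per further lane.
fn reduceAdd(comptime n: comptime_int, v: @Vector(n, u32)) u32 {
    var acc: u32 = v[0];
    for (1..n) |i| acc +%= v[i];
    return acc;
}

test "linear reduction" {
    const v: @Vector(4, u32) = .{ 1, 2, 3, 4 };
    try std.testing.expectEqual(@as(u32, 10), reduceAdd(4, v));
}
```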
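Finally, `airShuffle` no longer emits `OpVectorShuffle`; it extracts lanes directly, following the mask encoding visible in the diff: a non-negative mask element selects a lane of `a`, a negative element `e` selects lane `~e` of `b` (so -1 is `b[0]`), and undef elements become constant undef. The index decoding as a standalone function (sketch):

```zig
const std = @import("std");

// Decode one shuffle mask element: >= 0 selects from a, negative selects
// lane ~e from b, matching the extractField calls in airShuffle.
fn shuffleLane(a: []const u32, b: []const u32, mask_elem: i32) u32 {
    return if (mask_elem >= 0)
        a[@intCast(mask_elem)]
    else
        b[@intCast(~mask_elem)];
}

test "shuffle mask decoding" {
    const a = [_]u32{ 10, 11 };
    const b = [_]u32{ 20, 21 };
    try std.testing.expectEqual(@as(u32, 11), shuffleLane(&a, &b, 1));
    try std.testing.expectEqual(@as(u32, 20), shuffleLane(&a, &b, -1));
    try std.testing.expectEqual(@as(u32, 21), shuffleLane(&a, &b, -2));
}
```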
