commit    7634a115c50ef66edbdd5644c4ba310eb31e6343
tree      b8be56f0db16691e2939e87bac1222ba2c9fd4a8 /src/codegen/spirv.zig
parent    aebf20cc9a0469a778d6276d3797525660746e91
parent    25111061504a652bfed45b26252349f363b109af
author    Robin Voetter <robin@voetter.nl>  2024-02-05 09:24:49 +0100
committer GitHub <noreply@github.com>       2024-02-05 09:24:49 +0100
Merge pull request #18580 from Snektron/spirv-more-vectors
spirv: more vector operations
Diffstat (limited to 'src/codegen/spirv.zig')
-rw-r--r--  src/codegen/spirv.zig  1324
1 file changed, 891 insertions(+), 433 deletions(-)
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index 6c058308df..a499f3d8ed 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -373,8 +373,9 @@ const DeclGen = struct {
/// For `composite_integer` this is 0 (TODO)
backing_bits: u16,
- /// Whether the type is a vector.
- is_vector: bool,
+ /// Null if this type is a scalar, or the length
+ /// of the vector otherwise.
+ vector_len: ?u32,
/// Whether the inner type is signed. Only relevant for integers.
signedness: std.builtin.Signedness,
@@ -597,32 +598,37 @@ const DeclGen = struct {
return self.backingIntBits(ty) == null;
}
- fn arithmeticTypeInfo(self: *DeclGen, ty: Type) !ArithmeticTypeInfo {
+ fn arithmeticTypeInfo(self: *DeclGen, ty: Type) ArithmeticTypeInfo {
const mod = self.module;
const target = self.getTarget();
- return switch (ty.zigTypeTag(mod)) {
+ var scalar_ty = ty.scalarType(mod);
+ if (scalar_ty.zigTypeTag(mod) == .Enum) {
+ scalar_ty = scalar_ty.intTagType(mod);
+ }
+ const vector_len = if (ty.isVector(mod)) ty.vectorLen(mod) else null;
+ return switch (scalar_ty.zigTypeTag(mod)) {
.Bool => ArithmeticTypeInfo{
.bits = 1, // Doesn't matter for this class.
.backing_bits = self.backingIntBits(1).?,
- .is_vector = false,
+ .vector_len = vector_len,
.signedness = .unsigned, // Technically, but doesn't matter for this class.
.class = .bool,
},
.Float => ArithmeticTypeInfo{
- .bits = ty.floatBits(target),
- .backing_bits = ty.floatBits(target), // TODO: F80?
- .is_vector = false,
+ .bits = scalar_ty.floatBits(target),
+ .backing_bits = scalar_ty.floatBits(target), // TODO: F80?
+ .vector_len = vector_len,
.signedness = .signed, // Technically, but doesn't matter for this class.
.class = .float,
},
.Int => blk: {
- const int_info = ty.intInfo(mod);
+ const int_info = scalar_ty.intInfo(mod);
// TODO: Maybe it's useful to also return this value.
const maybe_backing_bits = self.backingIntBits(int_info.bits);
break :blk ArithmeticTypeInfo{
.bits = int_info.bits,
.backing_bits = maybe_backing_bits orelse 0,
- .is_vector = false,
+ .vector_len = vector_len,
.signedness = int_info.signedness,
.class = if (maybe_backing_bits) |backing_bits|
if (backing_bits == int_info.bits)
@@ -633,22 +639,9 @@ const DeclGen = struct {
.composite_integer,
};
},
- .Enum => return self.arithmeticTypeInfo(ty.intTagType(mod)),
- // As of yet, there is no vector support in the self-hosted compiler.
- .Vector => blk: {
- const child_type = ty.childType(mod);
- const child_ty_info = try self.arithmeticTypeInfo(child_type);
- break :blk ArithmeticTypeInfo{
- .bits = child_ty_info.bits,
- .backing_bits = child_ty_info.backing_bits,
- .is_vector = true,
- .signedness = child_ty_info.signedness,
- .class = child_ty_info.class,
- };
- },
- // TODO: For which types is this the case?
- // else => self.todo("implement arithmeticTypeInfo for {}", .{ty.fmt(self.module)}),
- else => unreachable,
+ .Enum => unreachable,
+ .Vector => unreachable,
+ else => unreachable, // Unhandled arithmetic type
};
}
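
With this change the function reasons per scalar: a type such as @Vector(4, u3) now reports bits = 3, backing_bits = 8, vector_len = 4, and class = .strange_integer. A minimal sketch of the width rounding done by backingIntBits, assuming a hypothetical backing-width set of {8, 16, 32, 64} (the real set depends on the target's enabled capabilities):

    const std = @import("std");

    // Round a Zig integer width up to the nearest supported backing width.
    fn backingBits(bits: u16) ?u16 {
        for ([_]u16{ 8, 16, 32, 64 }) |width| {
            if (bits <= width) return width;
        }
        return null; // too wide: lowered as a composite integer
    }

    test "strange integers get a wider backing width" {
        try std.testing.expectEqual(@as(?u16, 8), backingBits(3)); // u3: .strange_integer
        try std.testing.expectEqual(@as(?u16, 32), backingBits(32)); // u32: .integer
        try std.testing.expectEqual(@as(?u16, null), backingBits(96)); // .composite_integer
    }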
@@ -685,6 +678,18 @@ const DeclGen = struct {
}
}
+ /// Emits a float constant
+ fn constFloat(self: *DeclGen, ty_ref: CacheRef, value: f128) !IdRef {
+ const ty = self.spv.cache.lookup(ty_ref).float_type;
+ return switch (ty.bits) {
+ 16 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float16 = @floatCast(value) } } }),
+ 32 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float32 = @floatCast(value) } } }),
+ 64 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float64 = @floatCast(value) } } }),
+ 80, 128 => unreachable, // TODO
+ else => unreachable,
+ };
+ }
+
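
constFloat narrows the comptime-known f128 to the type's actual width before interning it. A self-contained sketch of that dispatch, with f80/f128 left out just as they are TODO above:

    const std = @import("std");

    // Map a bit width to a runtime float type, then narrow the constant.
    fn Float(comptime bits: u16) type {
        return switch (bits) {
            16 => f16,
            32 => f32,
            64 => f64,
            else => @compileError("TODO: f80/f128"),
        };
    }

    fn narrowConst(comptime bits: u16, value: f128) Float(bits) {
        return @floatCast(value);
    }

    test "constants are narrowed to the target width" {
        try std.testing.expectEqual(@as(f32, 1.5), narrowConst(32, 1.5));
        try std.testing.expectEqual(@as(f16, 0.25), narrowConst(16, 0.25));
    }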
/// Construct a struct at runtime.
/// ty must be a struct type.
/// Constituents should be in `indirect` representation (as the elements of a struct should be).
@@ -1760,6 +1765,92 @@ const DeclGen = struct {
return union_layout;
}
+ /// This structure is used as helper for element-wise operations. It is intended
+ /// to be used with both vectors and single elements.
+ const WipElementWise = struct {
+ dg: *DeclGen,
+ result_ty: Type,
+ /// Always in direct representation.
+ result_ty_ref: CacheRef,
+ scalar_ty: Type,
+ /// Always in direct representation.
+ scalar_ty_ref: CacheRef,
+ scalar_ty_id: IdRef,
+ /// True if the input is actually a vector type.
+ is_vector: bool,
+ /// The element-wise operation should fill these results before calling finalize().
+ /// These should all be in **direct** representation! `finalize()` will convert
+ /// them to indirect if required.
+ results: []IdRef,
+
+ fn deinit(wip: *WipElementWise) void {
+ wip.dg.gpa.free(wip.results);
+ }
+
+ /// Utility function to extract the element at a particular index in an
+ /// input vector. This type is expected to be a vector if `wip.is_vector`, and
+ /// a scalar otherwise.
+ fn elementAt(wip: WipElementWise, ty: Type, value: IdRef, index: usize) !IdRef {
+ const mod = wip.dg.module;
+ if (wip.is_vector) {
+ assert(ty.isVector(mod));
+ return try wip.dg.extractField(ty.childType(mod), value, @intCast(index));
+ } else {
+ assert(!ty.isVector(mod));
+ assert(index == 0);
+ return value;
+ }
+ }
+
+ /// Turns the results of this WipElementWise into a result. This can either
+ /// be a vector or single element, depending on `result_ty`.
+ /// After calling this function, this WIP is no longer usable.
+ /// Results is in `direct` representation.
+ fn finalize(wip: *WipElementWise) !IdRef {
+ if (wip.is_vector) {
+ // Convert all the constituents to indirect, as required for the array.
+ for (wip.results) |*result| {
+ result.* = try wip.dg.convertToIndirect(wip.scalar_ty, result.*);
+ }
+ return try wip.dg.constructArray(wip.result_ty, wip.results);
+ } else {
+ return wip.results[0];
+ }
+ }
+
+ /// Allocate a result id at a particular index, and return it.
+ fn allocId(wip: *WipElementWise, index: usize) IdRef {
+ assert(wip.is_vector or index == 0);
+ wip.results[index] = wip.dg.spv.allocId();
+ return wip.results[index];
+ }
+ };
+
+ /// Create a new element-wise operation.
+ fn elementWise(self: *DeclGen, result_ty: Type) !WipElementWise {
+ const mod = self.module;
+ // For now, this operation also reasons in terms of `.direct` representation.
+ const result_ty_ref = try self.resolveType(result_ty, .direct);
+ const is_vector = result_ty.isVector(mod);
+ const num_results = if (is_vector) result_ty.vectorLen(mod) else 1;
+ const results = try self.gpa.alloc(IdRef, num_results);
+ for (results) |*result| result.* = undefined;
+
+ const scalar_ty = result_ty.scalarType(mod);
+ const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
+
+ return .{
+ .dg = self,
+ .result_ty = result_ty,
+ .result_ty_ref = result_ty_ref,
+ .scalar_ty = scalar_ty,
+ .scalar_ty_ref = scalar_ty_ref,
+ .scalar_ty_id = self.typeId(scalar_ty_ref),
+ .is_vector = is_vector,
+ .results = results,
+ };
+ }
+
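
In the representation used here, vectors are arrays, so every vector operation is unrolled: elementAt extracts one scalar, the loop body fills results[i] in direct form, and finalize converts to indirect form and reassembles (or passes the single scalar through). The same shape in self-contained Zig, with plain arrays standing in for SPIR-V values:

    const std = @import("std");

    // Element-wise lowering: allocate results, fill one scalar per lane,
    // then reassemble, mirroring elementWise()/allocId()/finalize().
    fn elementWiseAdd(comptime n: usize, lhs: [n]u32, rhs: [n]u32) [n]u32 {
        var results: [n]u32 = undefined; // wip.results
        for (&results, lhs, rhs) |*result, a, b| {
            result.* = a +% b; // one scalar op per element
        }
        return results; // wip.finalize()
    }

    test "vector ops are unrolled per element" {
        const out = elementWiseAdd(3, .{ 1, 2, 3 }, .{ 4, 5, 6 });
        try std.testing.expectEqual([3]u32{ 5, 7, 9 }, out);
    }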
/// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure.
/// In order to be able to run tests, we "temporarily" lower test kernels into separate entry-
/// points. The test executor will then be able to invoke these to run the tests.
@@ -2081,25 +2172,31 @@ const DeclGen = struct {
const air_tags = self.air.instructions.items(.tag);
const maybe_result_id: ?IdRef = switch (air_tags[@intFromEnum(inst)]) {
// zig fmt: off
- .add, .add_wrap => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd, true),
- .sub, .sub_wrap => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub, true),
- .mul, .mul_wrap => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul, true),
+ .add, .add_wrap, .add_optimized => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd),
+ .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
+ .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
+
+ .abs => try self.airAbs(inst),
.div_float,
.div_float_optimized,
// TODO: Check that this is the right operation.
.div_trunc,
.div_trunc_optimized,
- => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv, false),
+ => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv),
// TODO: Check if this is the right operation
- // TODO: Make airArithOp for rem not emit a mask for the LHS.
.rem,
.rem_optimized,
- => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem, false),
+ => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem),
.add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan),
.sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
+ .shl_with_overflow => try self.airShlOverflow(inst),
+ .mul_add => try self.airMulAdd(inst),
+
+ .splat => try self.airSplat(inst),
+ .reduce, .reduce_optimized => try self.airReduce(inst),
.shuffle => try self.airShuffle(inst),
.ptr_add => try self.airPtrAdd(inst),
@@ -2111,7 +2208,8 @@ const DeclGen = struct {
.bool_and => try self.airBinOpSimple(inst, .OpLogicalAnd),
.bool_or => try self.airBinOpSimple(inst, .OpLogicalOr),
- .shl => try self.airShift(inst, .OpShiftLeftLogical),
+ .shl, .shl_exact => try self.airShift(inst, .OpShiftLeftLogical, .OpShiftLeftLogical),
+ .shr, .shr_exact => try self.airShift(inst, .OpShiftRightLogical, .OpShiftRightArithmetic),
.min => try self.airMinMax(inst, .lt),
.max => try self.airMinMax(inst, .gt),
@@ -2121,6 +2219,7 @@ const DeclGen = struct {
.int_from_ptr => try self.airIntFromPtr(inst),
.float_from_int => try self.airFloatFromInt(inst),
.int_from_float => try self.airIntFromFloat(inst),
+ .int_from_bool => try self.airIntFromBool(inst),
.fpext, .fptrunc => try self.airFloatCast(inst),
.not => try self.airNot(inst),
@@ -2137,6 +2236,8 @@ const DeclGen = struct {
.ptr_elem_val => try self.airPtrElemVal(inst),
.array_elem_val => try self.airArrayElemVal(inst),
+ .vector_store_elem => return self.airVectorStoreElem(inst),
+
.set_union_tag => return self.airSetUnionTag(inst),
.get_union_tag => try self.airGetUnionTag(inst),
.union_init => try self.airUnionInit(inst),
@@ -2189,13 +2290,16 @@ const DeclGen = struct {
.wrap_errunion_err => try self.airWrapErrUnionErr(inst),
.wrap_errunion_payload => try self.airWrapErrUnionPayload(inst),
- .is_null => try self.airIsNull(inst, .is_null),
- .is_non_null => try self.airIsNull(inst, .is_non_null),
- .is_err => try self.airIsErr(inst, .is_err),
- .is_non_err => try self.airIsErr(inst, .is_non_err),
+ .is_null => try self.airIsNull(inst, false, .is_null),
+ .is_non_null => try self.airIsNull(inst, false, .is_non_null),
+ .is_null_ptr => try self.airIsNull(inst, true, .is_null),
+ .is_non_null_ptr => try self.airIsNull(inst, true, .is_non_null),
+ .is_err => try self.airIsErr(inst, .is_err),
+ .is_non_err => try self.airIsErr(inst, .is_non_err),
- .optional_payload => try self.airUnwrapOptional(inst),
- .wrap_optional => try self.airWrapOptional(inst),
+ .optional_payload => try self.airUnwrapOptional(inst),
+ .optional_payload_ptr => try self.airUnwrapOptionalPtr(inst),
+ .wrap_optional => try self.airWrapOptional(inst),
.assembly => try self.airAssembly(inst),
@@ -2213,34 +2317,17 @@ const DeclGen = struct {
}
fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef {
- const mod = self.module;
-
- if (ty.isVector(mod)) {
- const child_ty = ty.childType(mod);
- const vector_len = ty.vectorLen(mod);
-
- const constituents = try self.gpa.alloc(IdRef, vector_len);
- defer self.gpa.free(constituents);
-
- for (constituents, 0..) |*constituent, i| {
- const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
- const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
- const result_id = try self.binOpSimple(child_ty, lhs_index_id, rhs_index_id, opcode);
- constituent.* = try self.convertToIndirect(child_ty, result_id);
- }
-
- return try self.constructArray(ty, constituents);
+ var wip = try self.elementWise(ty);
+ defer wip.deinit();
+ for (0..wip.results.len) |i| {
+ try self.func.body.emit(self.spv.gpa, opcode, .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = wip.allocId(i),
+ .operand_1 = try wip.elementAt(ty, lhs_id, i),
+ .operand_2 = try wip.elementAt(ty, rhs_id, i),
+ });
}
-
- const result_id = self.spv.allocId();
- const result_type_id = try self.resolveTypeId(ty);
- try self.func.body.emit(self.spv.gpa, opcode, .{
- .id_result_type = result_type_id,
- .id_result = result_id,
- .operand_1 = lhs_id,
- .operand_2 = rhs_id,
- });
- return result_id;
+ return try wip.finalize();
}
fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
@@ -2254,29 +2341,59 @@ const DeclGen = struct {
return try self.binOpSimple(ty, lhs_id, rhs_id, opcode);
}
- fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
+ fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime unsigned: Opcode, comptime signed: Opcode) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
+ const mod = self.module;
const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
const lhs_id = try self.resolve(bin_op.lhs);
const rhs_id = try self.resolve(bin_op.rhs);
- const result_type_id = try self.resolveTypeId(self.typeOfIndex(inst));
- // the shift and the base must be the same type in SPIR-V, but in Zig the shift is a smaller int.
- const shift_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
- .id_result_type = result_type_id,
- .id_result = shift_id,
- .unsigned_value = rhs_id,
- });
+ const result_ty = self.typeOfIndex(inst);
+ const shift_ty = self.typeOf(bin_op.rhs);
+ const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, opcode, .{
- .id_result_type = result_type_id,
- .id_result = result_id,
- .base = lhs_id,
- .shift = shift_id,
- });
- return result_id;
+ const info = self.arithmeticTypeInfo(result_ty);
+ switch (info.class) {
+ .composite_integer => return self.todo("shift ops for composite integers", .{}),
+ .integer, .strange_integer => {},
+ .float, .bool => unreachable,
+ }
+
+ var wip = try self.elementWise(result_ty);
+ defer wip.deinit();
+ for (wip.results, 0..) |*result_id, i| {
+ const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
+ const rhs_elem_id = try wip.elementAt(shift_ty, rhs_id, i);
+
+ // Sometimes Zig doesn't make both of the arguments the same type here. SPIR-V expects that,
+ // so just manually upcast it if required.
+ const shift_id = if (scalar_shift_ty_ref != wip.scalar_ty_ref) blk: {
+ const shift_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = shift_id,
+ .unsigned_value = rhs_elem_id,
+ });
+ break :blk shift_id;
+ } else rhs_elem_id;
+
+ const value_id = self.spv.allocId();
+ const args = .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = value_id,
+ .base = lhs_elem_id,
+ .shift = shift_id,
+ };
+
+ if (result_ty.isSignedInt(mod)) {
+ try self.func.body.emit(self.spv.gpa, signed, args);
+ } else {
+ try self.func.body.emit(self.spv.gpa, unsigned, args);
+ }
+
+ result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, info);
+ }
+ return try wip.finalize();
}
fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef {
@@ -2286,88 +2403,102 @@ const DeclGen = struct {
const lhs_id = try self.resolve(bin_op.lhs);
const rhs_id = try self.resolve(bin_op.rhs);
const result_ty = self.typeOfIndex(inst);
- const result_ty_ref = try self.resolveType(result_ty, .direct);
-
- const info = try self.arithmeticTypeInfo(result_ty);
- // TODO: Use fmin for OpenCL
- const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
- const selection_id = switch (info.class) {
- .float => blk: {
- // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
- // but we want it to pick lhs. Therefore we also have to check if
- // rhs is nan. We don't need to care about the result when both
- // are nan.
- const rhs_is_nan_id = self.spv.allocId();
- const bool_ty_ref = try self.resolveType(Type.bool, .direct);
- try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
- .id_result_type = self.typeId(bool_ty_ref),
- .id_result = rhs_is_nan_id,
- .x = rhs_id,
- });
- const float_cmp_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
- .id_result_type = self.typeId(bool_ty_ref),
- .id_result = float_cmp_id,
- .operand_1 = cmp_id,
- .operand_2 = rhs_is_nan_id,
- });
- break :blk float_cmp_id;
- },
- else => cmp_id,
- };
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpSelect, .{
- .id_result_type = self.typeId(result_ty_ref),
- .id_result = result_id,
- .condition = selection_id,
- .object_1 = lhs_id,
- .object_2 = rhs_id,
- });
- return result_id;
- }
+ return try self.minMax(result_ty, op, lhs_id, rhs_id);
+ }
+
+ fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
+ const info = self.arithmeticTypeInfo(result_ty);
+
+ var wip = try self.elementWise(result_ty);
+ defer wip.deinit();
+ for (wip.results, 0..) |*result_id, i| {
+ const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
+ const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
+
+ // TODO: Use fmin for OpenCL
+ const cmp_id = try self.cmp(op, Type.bool, wip.scalar_ty, lhs_elem_id, rhs_elem_id);
+ const selection_id = switch (info.class) {
+ .float => blk: {
+ // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
+ // but we want it to pick lhs. Therefore we also have to check if
+ // rhs is nan. We don't need to care about the result when both
+ // are nan.
+ const rhs_is_nan_id = self.spv.allocId();
+ const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+ try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = rhs_is_nan_id,
+ .x = rhs_elem_id,
+ });
+ const float_cmp_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = float_cmp_id,
+ .operand_1 = cmp_id,
+ .operand_2 = rhs_is_nan_id,
+ });
+ break :blk float_cmp_id;
+ },
+ else => cmp_id,
+ };
- /// This function canonicalizes a "strange" integer value:
- /// For unsigned integers, the value is masked so that only the relevant bits can contain
- /// non-zeros.
- /// For signed integers, the value is also sign extended.
- fn normalizeInt(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, info: ArithmeticTypeInfo) !IdRef {
- assert(info.class != .composite_integer); // TODO
- if (info.bits == info.backing_bits) {
- return value_id;
+ result_id.* = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = result_id.*,
+ .condition = selection_id,
+ .object_1 = lhs_elem_id,
+ .object_2 = rhs_elem_id,
+ });
}
-
- switch (info.signedness) {
- .unsigned => {
- const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
- const result_id = self.spv.allocId();
- const mask_id = try self.constInt(ty_ref, mask_value);
- try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
- .id_result_type = self.typeId(ty_ref),
- .id_result = result_id,
- .operand_1 = value_id,
- .operand_2 = mask_id,
- });
- return result_id;
- },
- .signed => {
- // Shift left and right so that we can copy the sign bit that way.
- const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
- const left_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
- .id_result_type = self.typeId(ty_ref),
- .id_result = left_id,
- .base = value_id,
- .shift = shift_amt_id,
- });
- const right_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
- .id_result_type = self.typeId(ty_ref),
- .id_result = right_id,
- .base = left_id,
- .shift = shift_amt_id,
- });
- return right_id;
+ return wip.finalize();
+ }
+
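
Because cmp lowers to the ordered OpFOrd* comparisons, both 0 < nan and 0 > nan are false, so OpIsNan(rhs) is OR-ed in to make lhs win whenever rhs is nan, matching @min/@max semantics. The selection condition as a self-contained check:

    const std = @import("std");

    // The OpSelect condition built above: ordered compare OR rhs-is-nan.
    fn pickLhs(lhs: f32, rhs: f32, op: std.math.CompareOperator) bool {
        const cmp = switch (op) {
            .lt => lhs < rhs, // OpFOrdLessThan
            .gt => lhs > rhs, // OpFOrdGreaterThan
            else => unreachable,
        };
        return cmp or std.math.isNan(rhs); // OpLogicalOr with OpIsNan
    }

    test "nan on the rhs selects lhs" {
        const nan = std.math.nan(f32);
        try std.testing.expect(pickLhs(0.0, nan, .lt)); // @min(0, nan) == 0
        try std.testing.expect(!pickLhs(3.0, 2.0, .lt)); // @min(3, 2) == 2
    }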
+ /// This function normalizes values to a canonical representation
+ /// after some arithmetic operation. This mostly consists of wrapping
+ /// behavior for strange integers:
+ /// - Unsigned integers are bitwise masked with a mask that only passes
+ /// the valid bits through.
+ /// - Signed integers are also sign extended if they are negative.
+ /// All other values are returned unmodified (this makes strange integer
+ /// wrapping easier to use in generic operations).
+ fn normalize(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, info: ArithmeticTypeInfo) !IdRef {
+ switch (info.class) {
+ .integer, .bool, .float => return value_id,
+ .composite_integer => unreachable, // TODO
+ .strange_integer => switch (info.signedness) {
+ .unsigned => {
+ const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
+ const result_id = self.spv.allocId();
+ const mask_id = try self.constInt(ty_ref, mask_value);
+ try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
+ .id_result_type = self.typeId(ty_ref),
+ .id_result = result_id,
+ .operand_1 = value_id,
+ .operand_2 = mask_id,
+ });
+ return result_id;
+ },
+ .signed => {
+ // Shift left and right so that we can copy the sign bit that way.
+ const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
+ const left_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
+ .id_result_type = self.typeId(ty_ref),
+ .id_result = left_id,
+ .base = value_id,
+ .shift = shift_amt_id,
+ });
+ const right_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
+ .id_result_type = self.typeId(ty_ref),
+ .id_result = right_id,
+ .base = left_id,
+ .shift = shift_amt_id,
+ });
+ return right_id;
+ },
},
}
}
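
Concretely, for a 3-bit value held in an 8-bit backing word: the unsigned path keeps the low bits with one OpBitwiseAnd, while the signed path shifts the value up so the small type's sign bit lands in the backing type's sign bit, then shifts back arithmetically. The same arithmetic on plain integers:

    const std = @import("std");

    // normalize() for a "strange" 3-bit width in an 8-bit backing word.
    fn normalizeUnsigned(backing: u8, bits: u3) u8 {
        const mask = (@as(u8, 1) << bits) - 1;
        return backing & mask; // OpBitwiseAnd
    }

    fn normalizeSigned(backing: i8, bits: u3) i8 {
        const shift: u3 = @intCast(8 - @as(u4, bits));
        const up: u8 = @as(u8, @bitCast(backing)) << shift; // OpShiftLeftLogical
        return @as(i8, @bitCast(up)) >> shift; // OpShiftRightArithmetic
    }

    test "wrapping a 3-bit value stored in 8 bits" {
        try std.testing.expectEqual(@as(u8, 5), normalizeUnsigned(0b1110_0101, 3));
        try std.testing.expectEqual(@as(i8, -3), normalizeSigned(0b0000_0101, 3)); // i3 0b101 == -3
    }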
@@ -2378,8 +2509,6 @@ const DeclGen = struct {
comptime fop: Opcode,
comptime sop: Opcode,
comptime uop: Opcode,
- /// true if this operation holds under modular arithmetic.
- comptime modular: bool,
) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
@@ -2393,60 +2522,27 @@ const DeclGen = struct {
assert(self.typeOf(bin_op.lhs).eql(ty, self.module));
assert(self.typeOf(bin_op.rhs).eql(ty, self.module));
- return try self.arithOp(ty, lhs_id, rhs_id, fop, sop, uop, modular);
+ return try self.arithOp(ty, lhs_id, rhs_id, fop, sop, uop);
}
fn arithOp(
self: *DeclGen,
ty: Type,
- lhs_id_: IdRef,
- rhs_id_: IdRef,
+ lhs_id: IdRef,
+ rhs_id: IdRef,
comptime fop: Opcode,
comptime sop: Opcode,
comptime uop: Opcode,
- /// true if this operation holds under modular arithmetic.
- comptime modular: bool,
) !IdRef {
- var rhs_id = rhs_id_;
- var lhs_id = lhs_id_;
-
- const mod = self.module;
- const result_ty_ref = try self.resolveType(ty, .direct);
-
- if (ty.isVector(mod)) {
- const child_ty = ty.childType(mod);
- const vector_len = ty.vectorLen(mod);
- const constituents = try self.gpa.alloc(IdRef, vector_len);
- defer self.gpa.free(constituents);
-
- for (constituents, 0..) |*constituent, i| {
- const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
- const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
- constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular);
- }
-
- return self.constructArray(ty, constituents);
- }
-
// Binary operations are generally applicable to both scalar and vector operations
// in SPIR-V, but int and float versions of operations require different opcodes.
- const info = try self.arithmeticTypeInfo(ty);
+ const info = self.arithmeticTypeInfo(ty);
const opcode_index: usize = switch (info.class) {
.composite_integer => {
return self.todo("binary operations for composite integers", .{});
},
- .strange_integer => blk: {
- if (!modular) {
- lhs_id = try self.normalizeInt(result_ty_ref, lhs_id, info);
- rhs_id = try self.normalizeInt(result_ty_ref, rhs_id, info);
- }
- break :blk switch (info.signedness) {
- .signed => @as(usize, 1),
- .unsigned => @as(usize, 2),
- };
- },
- .integer => switch (info.signedness) {
+ .integer, .strange_integer => switch (info.signedness) {
.signed => @as(usize, 1),
.unsigned => @as(usize, 2),
},
@@ -2454,24 +2550,91 @@ const DeclGen = struct {
.bool => unreachable,
};
- const result_id = self.spv.allocId();
- const operands = .{
- .id_result_type = self.typeId(result_ty_ref),
- .id_result = result_id,
- .operand_1 = lhs_id,
- .operand_2 = rhs_id,
- };
+ var wip = try self.elementWise(ty);
+ defer wip.deinit();
+ for (wip.results, 0..) |*result_id, i| {
+ const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
+ const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
+
+ const value_id = self.spv.allocId();
+ const operands = .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = value_id,
+ .operand_1 = lhs_elem_id,
+ .operand_2 = rhs_elem_id,
+ };
- switch (opcode_index) {
- 0 => try self.func.body.emit(self.spv.gpa, fop, operands),
- 1 => try self.func.body.emit(self.spv.gpa, sop, operands),
- 2 => try self.func.body.emit(self.spv.gpa, uop, operands),
- else => unreachable,
+ switch (opcode_index) {
+ 0 => try self.func.body.emit(self.spv.gpa, fop, operands),
+ 1 => try self.func.body.emit(self.spv.gpa, sop, operands),
+ 2 => try self.func.body.emit(self.spv.gpa, uop, operands),
+ else => unreachable,
+ }
+
+ // TODO: Trap on overflow? Probably going to be annoying.
+ // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
+ result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, info);
}
- // TODO: Trap on overflow? Probably going to be annoying.
- // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
- return result_id;
+ return try wip.finalize();
+ }
+
+ fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+
+ const mod = self.module;
+ const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+ const operand_id = try self.resolve(ty_op.operand);
+ // Note: operand_ty may be signed, while ty is always unsigned!
+ const operand_ty = self.typeOf(ty_op.operand);
+ const ty = self.typeOfIndex(inst);
+ const info = self.arithmeticTypeInfo(ty);
+ const operand_scalar_ty = operand_ty.scalarType(mod);
+ const operand_scalar_ty_ref = try self.resolveType(operand_scalar_ty, .direct);
+
+ var wip = try self.elementWise(ty);
+ defer wip.deinit();
+
+ const zero_id = switch (info.class) {
+ .float => try self.constFloat(operand_scalar_ty_ref, 0),
+ .integer, .strange_integer => try self.constInt(operand_scalar_ty_ref, 0),
+ .composite_integer => unreachable, // TODO
+ .bool => unreachable,
+ };
+ for (wip.results, 0..) |*result_id, i| {
+ const elem_id = try wip.elementAt(operand_ty, operand_id, i);
+ // SPIR-V doesn't have a dedicated abs() instruction in the base
+ // instruction set. For now we negate and select to avoid
+ // importing the extended instruction set.
+ // TODO: Make this a call to compiler-rt / an ext inst.
+ const neg_id = self.spv.allocId();
+ const args = .{
+ .id_result_type = self.typeId(operand_scalar_ty_ref),
+ .id_result = neg_id,
+ .operand_1 = zero_id,
+ .operand_2 = elem_id,
+ };
+ switch (info.class) {
+ .float => try self.func.body.emit(self.spv.gpa, .OpFSub, args),
+ .integer, .strange_integer => try self.func.body.emit(self.spv.gpa, .OpISub, args),
+ .composite_integer => unreachable, // TODO
+ .bool => unreachable,
+ }
+ const neg_norm_id = try self.normalize(wip.scalar_ty_ref, neg_id, info);
+
+ const gt_zero_id = try self.cmp(.gt, Type.bool, operand_scalar_ty, elem_id, zero_id);
+ const abs_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+ .id_result_type = self.typeId(operand_scalar_ty_ref),
+ .id_result = abs_id,
+ .condition = gt_zero_id,
+ .object_1 = elem_id,
+ .object_2 = neg_norm_id,
+ });
+ // For Shader, we may need to cast from signed to unsigned here.
+ result_id.* = try self.bitCast(wip.scalar_ty, operand_scalar_ty, abs_id);
+ }
+ return try wip.finalize();
}
fn airAddSubOverflow(
@@ -2488,140 +2651,344 @@ const DeclGen = struct {
const lhs = try self.resolve(extra.lhs);
const rhs = try self.resolve(extra.rhs);
- const operand_ty = self.typeOf(extra.lhs);
const result_ty = self.typeOfIndex(inst);
+ const operand_ty = self.typeOf(extra.lhs);
+ const ov_ty = result_ty.structFieldType(1, self.module);
+
+ const bool_ty_ref = try self.resolveType(Type.bool, .direct);
- const info = try self.arithmeticTypeInfo(operand_ty);
+ const info = self.arithmeticTypeInfo(operand_ty);
switch (info.class) {
.composite_integer => return self.todo("overflow ops for composite integers", .{}),
- .strange_integer => return self.todo("overflow ops for strange integers", .{}),
- .integer => {},
+ .strange_integer, .integer => {},
.float, .bool => unreachable,
}
- // The operand type must be the same as the result type in SPIR-V, which
- // is the same as in Zig.
- const operand_ty_ref = try self.resolveType(operand_ty, .direct);
- const operand_ty_id = self.typeId(operand_ty_ref);
+ var wip_result = try self.elementWise(operand_ty);
+ defer wip_result.deinit();
+ var wip_ov = try self.elementWise(ov_ty);
+ defer wip_ov.deinit();
+ for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
+ const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+ const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i);
+
+ // Normalize both so that we can properly check for overflow
+ const value_id = self.spv.allocId();
+
+ try self.func.body.emit(self.spv.gpa, add, .{
+ .id_result_type = wip_result.scalar_ty_id,
+ .id_result = value_id,
+ .operand_1 = lhs_elem_id,
+ .operand_2 = rhs_elem_id,
+ });
- const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+ // Normalize the result so that the comparisons go well
+ result_id.* = try self.normalize(wip_result.scalar_ty_ref, value_id, info);
+
+ const overflowed_id = switch (info.signedness) {
+ .unsigned => blk: {
+ // Overflow happened if the result is smaller than either of the operands. It doesn't matter which.
+ // For subtraction the conditions need to be swapped.
+ const overflowed_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, ucmp, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = overflowed_id,
+ .operand_1 = result_id.*,
+ .operand_2 = lhs_elem_id,
+ });
+ break :blk overflowed_id;
+ },
+ .signed => blk: {
+ // lhs - rhs
+ // For addition, overflow happened if:
+ // - rhs is negative and value > lhs
+ // - rhs is positive and value < lhs
+ // This can be shortened to:
+ // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs)
+ // = (rhs < 0) == (value > lhs)
+ // = (rhs < 0) == (lhs < value)
+ // Note that signed overflow is also wrapping in spir-v.
+ // For subtraction, overflow happened if:
+ // - rhs is negative and value < lhs
+ // - rhs is positive and value > lhs
+ // This can be shortened to:
+ // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs)
+ // = (rhs < 0) == (value < lhs)
+ // = (rhs < 0) == (lhs > value)
+
+ const rhs_lt_zero_id = self.spv.allocId();
+ const zero_id = try self.constInt(wip_result.scalar_ty_ref, 0);
+ try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = rhs_lt_zero_id,
+ .operand_1 = rhs_elem_id,
+ .operand_2 = zero_id,
+ });
+
+ const value_gt_lhs_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, scmp, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = value_gt_lhs_id,
+ .operand_1 = lhs_elem_id,
+ .operand_2 = result_id.*,
+ });
+
+ const overflowed_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = overflowed_id,
+ .operand_1 = rhs_lt_zero_id,
+ .operand_2 = value_gt_lhs_id,
+ });
+ break :blk overflowed_id;
+ },
+ };
+
+ ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+ }
+
+ return try self.constructStruct(
+ result_ty,
+ &.{ operand_ty, ov_ty },
+ &.{ try wip_result.finalize(), try wip_ov.finalize() },
+ );
+ }
+
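
The unsigned branch leans on wrapping arithmetic: after OpIAdd, overflow occurred iff the wrapped result compares below lhs, and for OpISub the comparison flips. The same checks on plain integers:

    const std = @import("std");

    // Unsigned overflow tests emitted above: compare the wrapped result
    // against lhs (OpULessThan for add, OpUGreaterThan for sub).
    fn addOverflows(lhs: u8, rhs: u8) bool {
        return lhs +% rhs < lhs;
    }

    fn subOverflows(lhs: u8, rhs: u8) bool {
        return lhs -% rhs > lhs;
    }

    test "wrapping comparisons detect unsigned overflow" {
        try std.testing.expect(addOverflows(200, 100)); // 300 wraps to 44
        try std.testing.expect(!addOverflows(200, 55)); // 255 fits
        try std.testing.expect(subOverflows(5, 6)); // -1 wraps to 255
    }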
+ fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+ const mod = self.module;
+ const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
+ const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+ const lhs = try self.resolve(extra.lhs);
+ const rhs = try self.resolve(extra.rhs);
+
+ const result_ty = self.typeOfIndex(inst);
+ const operand_ty = self.typeOf(extra.lhs);
+ const shift_ty = self.typeOf(extra.rhs);
+ const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
const ov_ty = result_ty.structFieldType(1, self.module);
- // Note: result is stored in a struct, so indirect representation.
- const ov_ty_ref = try self.resolveType(ov_ty, .indirect);
-
- // TODO: Operations other than addition.
- const value_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, add, .{
- .id_result_type = operand_ty_id,
- .id_result = value_id,
- .operand_1 = lhs,
- .operand_2 = rhs,
- });
- const overflowed_id = switch (info.signedness) {
- .unsigned => blk: {
- // Overflow happened if the result is smaller than either of the operands. It doesn't matter which.
- // For subtraction the conditions need to be swapped.
- const overflowed_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, ucmp, .{
- .id_result_type = self.typeId(bool_ty_ref),
- .id_result = overflowed_id,
- .operand_1 = value_id,
- .operand_2 = lhs,
- });
- break :blk overflowed_id;
- },
- .signed => blk: {
- // lhs - rhs
- // For addition, overflow happened if:
- // - rhs is negative and value > lhs
- // - rhs is positive and value < lhs
- // This can be shortened to:
- // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs)
- // = (rhs < 0) == (value > lhs)
- // = (rhs < 0) == (lhs < value)
- // Note that signed overflow is also wrapping in spir-v.
- // For subtraction, overflow happened if:
- // - rhs is negative and value < lhs
- // - rhs is positive and value > lhs
- // This can be shortened to:
- // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs)
- // = (rhs < 0) == (value < lhs)
- // = (rhs < 0) == (lhs > value)
-
- const rhs_lt_zero_id = self.spv.allocId();
- const zero_id = try self.constInt(operand_ty_ref, 0);
- try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
- .id_result_type = self.typeId(bool_ty_ref),
- .id_result = rhs_lt_zero_id,
- .operand_1 = rhs,
- .operand_2 = zero_id,
- });
+ const bool_ty_ref = try self.resolveType(Type.bool, .direct);
- const value_gt_lhs_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, scmp, .{
- .id_result_type = self.typeId(bool_ty_ref),
- .id_result = value_gt_lhs_id,
- .operand_1 = lhs,
- .operand_2 = value_id,
- });
+ const info = self.arithmeticTypeInfo(operand_ty);
+ switch (info.class) {
+ .composite_integer => return self.todo("overflow shift for composite integers", .{}),
+ .integer, .strange_integer => {},
+ .float, .bool => unreachable,
+ }
- const overflowed_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{
- .id_result_type = self.typeId(bool_ty_ref),
- .id_result = overflowed_id,
- .operand_1 = rhs_lt_zero_id,
- .operand_2 = value_gt_lhs_id,
+ var wip_result = try self.elementWise(operand_ty);
+ defer wip_result.deinit();
+ var wip_ov = try self.elementWise(ov_ty);
+ defer wip_ov.deinit();
+ for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
+ const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+ const rhs_elem_id = try wip_result.elementAt(shift_ty, rhs, i);
+
+ // Sometimes Zig doesn't make both of the arguments the same type here. SPIR-V expects that,
+ // so just manually upcast it if required.
+ const shift_id = if (scalar_shift_ty_ref != wip_result.scalar_ty_ref) blk: {
+ const shift_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+ .id_result_type = wip_result.scalar_ty_id,
+ .id_result = shift_id,
+ .unsigned_value = rhs_elem_id,
});
- break :blk overflowed_id;
- },
- };
+ break :blk shift_id;
+ } else rhs_elem_id;
+
+ const value_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
+ .id_result_type = wip_result.scalar_ty_id,
+ .id_result = value_id,
+ .base = lhs_elem_id,
+ .shift = shift_id,
+ });
+ result_id.* = try self.normalize(wip_result.scalar_ty_ref, value_id, info);
+
+ const right_shift_id = self.spv.allocId();
+ switch (info.signedness) {
+ .signed => {
+ try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
+ .id_result_type = wip_result.scalar_ty_id,
+ .id_result = right_shift_id,
+ .base = result_id.*,
+ .shift = shift_id,
+ });
+ },
+ .unsigned => {
+ try self.func.body.emit(self.spv.gpa, .OpShiftRightLogical, .{
+ .id_result_type = wip_result.scalar_ty_id,
+ .id_result = right_shift_id,
+ .base = result_id.*,
+ .shift = shift_id,
+ });
+ },
+ }
+
+ const overflowed_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = overflowed_id,
+ .operand_1 = lhs_elem_id,
+ .operand_2 = right_shift_id,
+ });
+
+ ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+ }
- // Construct the struct that Zig wants as result.
- // The value should already be the correct type.
- const ov_id = try self.intFromBool(ov_ty_ref, overflowed_id);
return try self.constructStruct(
result_ty,
&.{ operand_ty, ov_ty },
- &.{ value_id, ov_id },
+ &.{ try wip_result.finalize(), try wip_ov.finalize() },
);
}
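
The overflow bit comes from a round trip: shift left, shift back with the signedness-appropriate right shift, and compare against the original; any bit lost to the left shift makes the two differ. On a plain u8:

    const std = @import("std");

    // shl_with_overflow check: (lhs << shift) >> shift != lhs.
    fn shlOverflows(lhs: u8, shift: u3) bool {
        const shifted = lhs << shift; // OpShiftLeftLogical (truncates)
        return (shifted >> shift) != lhs; // OpShiftRight* + OpINotEqual
    }

    test "bits lost to the left shift flag overflow" {
        try std.testing.expect(shlOverflows(0b1100_0000, 2));
        try std.testing.expect(!shlOverflows(0b0011_0000, 2));
    }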
+ fn airMulAdd(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+
+ const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
+ const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+
+ const mulend1 = try self.resolve(extra.lhs);
+ const mulend2 = try self.resolve(extra.rhs);
+ const addend = try self.resolve(pl_op.operand);
+
+ const ty = self.typeOfIndex(inst);
+
+ const info = self.arithmeticTypeInfo(ty);
+ assert(info.class == .float); // .mul_add is only emitted for floats
+
+ var wip = try self.elementWise(ty);
+ defer wip.deinit();
+ for (0..wip.results.len) |i| {
+ const mul_result = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpFMul, .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = mul_result,
+ .operand_1 = try wip.elementAt(ty, mulend1, i),
+ .operand_2 = try wip.elementAt(ty, mulend2, i),
+ });
+
+ try self.func.body.emit(self.spv.gpa, .OpFAdd, .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = wip.allocId(i),
+ .operand_1 = mul_result,
+ .operand_2 = try wip.elementAt(ty, addend, i),
+ });
+ }
+ return try wip.finalize();
+ }
+
+ fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+ const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+ const operand_id = try self.resolve(ty_op.operand);
+ const result_ty = self.typeOfIndex(inst);
+ var wip = try self.elementWise(result_ty);
+ defer wip.deinit();
+ for (wip.results) |*result_id| {
+ result_id.* = operand_id;
+ }
+ return try wip.finalize();
+ }
+
+ fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+ const mod = self.module;
+ const reduce = self.air.instructions.items(.data)[@intFromEnum(inst)].reduce;
+ const operand = try self.resolve(reduce.operand);
+ const operand_ty = self.typeOf(reduce.operand);
+ const scalar_ty = operand_ty.scalarType(mod);
+ const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
+ const scalar_ty_id = self.typeId(scalar_ty_ref);
+
+ const info = self.arithmeticTypeInfo(operand_ty);
+
+ var result_id = try self.extractField(scalar_ty, operand, 0);
+ const len = operand_ty.vectorLen(mod);
+
+ switch (reduce.operation) {
+ .Min, .Max => |op| {
+ const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt;
+ for (1..len) |i| {
+ const lhs = result_id;
+ const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+ result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs);
+ }
+
+ return result_id;
+ },
+ else => {},
+ }
+
+ const opcode: Opcode = switch (info.class) {
+ .bool => switch (reduce.operation) {
+ .And => .OpLogicalAnd,
+ .Or => .OpLogicalOr,
+ .Xor => .OpLogicalNotEqual,
+ else => unreachable,
+ },
+ .strange_integer, .integer => switch (reduce.operation) {
+ .And => .OpBitwiseAnd,
+ .Or => .OpBitwiseOr,
+ .Xor => .OpBitwiseXor,
+ .Add => .OpIAdd,
+ .Mul => .OpIMul,
+ else => unreachable,
+ },
+ .float => switch (reduce.operation) {
+ .Add => .OpFAdd,
+ .Mul => .OpFMul,
+ else => unreachable,
+ },
+ .composite_integer => unreachable, // TODO
+ };
+
+ for (1..len) |i| {
+ const lhs = result_id;
+ const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+ result_id = self.spv.allocId();
+
+ try self.func.body.emitRaw(self.spv.gpa, opcode, 4);
+ self.func.body.writeOperand(spec.IdResultType, scalar_ty_id);
+ self.func.body.writeOperand(spec.IdResult, result_id);
+ self.func.body.writeOperand(spec.IdRef, lhs);
+ self.func.body.writeOperand(spec.IdRef, rhs);
+ }
+
+ return result_id;
+ }
+
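
With no vector reduction instruction available in this flavour, @reduce is unrolled into a left fold: the accumulator is seeded with element 0 and one scalar op is emitted per remaining element (Min/Max go through the minMax helper instead). The fold shape, self-contained:

    const std = @import("std");

    // @reduce(.Add, v) lowered as a chain of scalar OpIAdds.
    fn reduceAdd(v: []const u32) u32 {
        var acc = v[0]; // extractField(scalar_ty, operand, 0)
        for (v[1..]) |elem| acc +%= elem;
        return acc;
    }

    test "reduce is a left fold over the elements" {
        try std.testing.expectEqual(@as(u32, 10), reduceAdd(&.{ 1, 2, 3, 4 }));
    }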
fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
const mod = self.module;
if (self.liveness.isUnused(inst)) return null;
- const ty = self.typeOfIndex(inst);
const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data;
const a = try self.resolve(extra.a);
const b = try self.resolve(extra.b);
const mask = Value.fromInterned(extra.mask);
- const mask_len = extra.mask_len;
- const a_len = self.typeOf(extra.a).vectorLen(mod);
- const result_id = self.spv.allocId();
- const result_type_id = try self.resolveTypeId(ty);
- // Similar to LLVM, SPIR-V uses indices larger than the length of the first vector
- // to index into the second vector.
- try self.func.body.emitRaw(self.spv.gpa, .OpVectorShuffle, 4 + mask_len);
- self.func.body.writeOperand(spec.IdResultType, result_type_id);
- self.func.body.writeOperand(spec.IdResult, result_id);
- self.func.body.writeOperand(spec.IdRef, a);
- self.func.body.writeOperand(spec.IdRef, b);
+ const ty = self.typeOfIndex(inst);
- var i: usize = 0;
- while (i < mask_len) : (i += 1) {
+ var wip = try self.elementWise(ty);
+ defer wip.deinit();
+ for (wip.results, 0..) |*result_id, i| {
const elem = try mask.elemValue(mod, i);
if (elem.isUndef(mod)) {
- self.func.body.writeOperand(spec.LiteralInteger, 0xFFFF_FFFF);
+ result_id.* = try self.spv.constUndef(wip.scalar_ty_ref);
+ continue;
+ }
+
+ const index = elem.toSignedInt(mod);
+ if (index >= 0) {
+ result_id.* = try self.extractField(wip.scalar_ty, a, @intCast(index));
} else {
- const int = elem.toSignedInt(mod);
- const unsigned = if (int >= 0) @as(u32, @intCast(int)) else @as(u32, @intCast(~int + a_len));
- self.func.body.writeOperand(spec.LiteralInteger, unsigned);
+ result_id.* = try self.extractField(wip.scalar_ty, b, @intCast(~index));
}
}
- return result_id;
+ return try wip.finalize();
}
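
Since the array representation rules out OpVectorShuffle, each result element is now extracted directly using AIR's mask encoding: a non-negative element i selects a[i], and a negative element m selects b[~m], so -1 means b[0]. As a self-contained helper:

    const std = @import("std");

    // AIR shuffle mask: non-negative picks from `a`; negative m picks b[~m].
    fn shuffleElem(a: []const u32, b: []const u32, mask_elem: i32) u32 {
        if (mask_elem >= 0) return a[@intCast(mask_elem)];
        return b[@intCast(~mask_elem)]; // ~(-1) == 0, ~(-2) == 1, ...
    }

    test "negative mask elements index the second vector" {
        const a = [_]u32{ 10, 20 };
        const b = [_]u32{ 30, 40 };
        try std.testing.expectEqual(@as(u32, 10), shuffleElem(&a, &b, 0));
        try std.testing.expectEqual(@as(u32, 40), shuffleElem(&a, &b, -2));
    }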
fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {
@@ -2828,26 +3195,21 @@ const DeclGen = struct {
return result_id;
},
.Vector => {
- const child_ty = ty.childType(mod);
- const vector_len = ty.vectorLen(mod);
-
- const constituents = try self.gpa.alloc(IdRef, vector_len);
- defer self.gpa.free(constituents);
-
- for (constituents, 0..) |*constituent, i| {
- const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i));
- const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i));
- const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id);
- constituent.* = try self.convertToIndirect(Type.bool, result_id);
+ var wip = try self.elementWise(result_ty);
+ defer wip.deinit();
+ const scalar_ty = ty.scalarType(mod);
+ for (wip.results, 0..) |*result_id, i| {
+ const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
+ const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
+ result_id.* = try self.cmp(op, Type.bool, scalar_ty, lhs_elem_id, rhs_elem_id);
}
-
- return try self.constructArray(result_ty, constituents);
+ return wip.finalize();
},
else => unreachable,
};
const opcode: Opcode = opcode: {
- const info = try self.arithmeticTypeInfo(op_ty);
+ const info = self.arithmeticTypeInfo(op_ty);
const signedness = switch (info.class) {
.composite_integer => {
return self.todo("binary operations for composite integers", .{});
@@ -2865,14 +3227,7 @@ const DeclGen = struct {
.neq => .OpLogicalNotEqual,
else => unreachable,
},
- .strange_integer => sign: {
- const op_ty_ref = try self.resolveType(op_ty, .direct);
- // Mask operands before performing comparison.
- cmp_lhs_id = try self.normalizeInt(op_ty_ref, cmp_lhs_id, info);
- cmp_rhs_id = try self.normalizeInt(op_ty_ref, cmp_rhs_id, info);
- break :sign info.signedness;
- },
- .integer => info.signedness,
+ .integer, .strange_integer => info.signedness,
};
break :opcode switch (signedness) {
@@ -2942,50 +3297,64 @@ const DeclGen = struct {
const mod = self.module;
const src_ty_ref = try self.resolveType(src_ty, .direct);
const dst_ty_ref = try self.resolveType(dst_ty, .direct);
- if (src_ty_ref == dst_ty_ref) {
- return src_id;
- }
+ const src_key = self.spv.cache.lookup(src_ty_ref);
+ const dst_key = self.spv.cache.lookup(dst_ty_ref);
- // TODO: Some more cases are missing here
- // See fn bitCast in llvm.zig
+ const result_id = blk: {
+ if (src_ty_ref == dst_ty_ref) {
+ break :blk src_id;
+ }
- if (src_ty.zigTypeTag(mod) == .Int and dst_ty.isPtrAtRuntime(mod)) {
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpConvertUToPtr, .{
- .id_result_type = self.typeId(dst_ty_ref),
- .id_result = result_id,
- .integer_value = src_id,
- });
- return result_id;
- }
+ // TODO: Some more cases are missing here
+ // See fn bitCast in llvm.zig
- // We can only use OpBitcast for specific conversions: between numerical types, and
- // between pointers. If the resolved spir-v types fall into this category then emit OpBitcast,
- // otherwise use a temporary and perform a pointer cast.
- const src_key = self.spv.cache.lookup(src_ty_ref);
- const dst_key = self.spv.cache.lookup(dst_ty_ref);
+ if (src_ty.zigTypeTag(mod) == .Int and dst_ty.isPtrAtRuntime(mod)) {
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpConvertUToPtr, .{
+ .id_result_type = self.typeId(dst_ty_ref),
+ .id_result = result_id,
+ .integer_value = src_id,
+ });
+ break :blk result_id;
+ }
- if ((src_key.isNumericalType() and dst_key.isNumericalType()) or (src_key == .ptr_type and dst_key == .ptr_type)) {
- const result_id = self.spv.allocId();
+ // We can only use OpBitcast for specific conversions: between numerical types, and
+ // between pointers. If the resolved spir-v types fall into this category then emit OpBitcast,
+ // otherwise use a temporary and perform a pointer cast.
+ if ((src_key.isNumericalType() and dst_key.isNumericalType()) or (src_key == .ptr_type and dst_key == .ptr_type)) {
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+ .id_result_type = self.typeId(dst_ty_ref),
+ .id_result = result_id,
+ .operand = src_id,
+ });
+
+ break :blk result_id;
+ }
+
+ const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
+
+ const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
+ try self.store(src_ty, tmp_id, src_id, .{});
+ const casted_ptr_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
- .id_result_type = self.typeId(dst_ty_ref),
- .id_result = result_id,
- .operand = src_id,
+ .id_result_type = self.typeId(dst_ptr_ty_ref),
+ .id_result = casted_ptr_id,
+ .operand = tmp_id,
});
- return result_id;
- }
+ break :blk try self.load(dst_ty, casted_ptr_id, .{});
+ };
- const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
+ // Because strange integers use sign-extended representation, we may need to normalize
+ // the result here.
+ // TODO: This detail could cause stuff like @as(*const i1, @ptrCast(&@as(u1, 1)))
+ // to break. Should we change the representation of strange integers?
+ if (dst_ty.zigTypeTag(mod) == .Int) {
+ const info = self.arithmeticTypeInfo(dst_ty);
+ return try self.normalize(dst_ty_ref, result_id, info);
+ }
- const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
- try self.store(src_ty, tmp_id, src_id, .{});
- const casted_ptr_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
- .id_result_type = self.typeId(dst_ptr_ty_ref),
- .id_result = casted_ptr_id,
- .operand = tmp_id,
- });
- return try self.load(dst_ty, casted_ptr_id, .{});
+ return result_id;
}
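
When neither side admits a direct OpBitcast, the fallback reinterprets through memory: spill the source to a Function-storage temporary, OpBitcast the pointer, and load it back. The same trick in ordinary Zig, assuming matching size and alignment:

    const std = @import("std");

    // Reinterpret a value through memory: store, cast the pointer, load.
    fn bitCastViaMemory(comptime Dst: type, src: anytype) Dst {
        const tmp = src; // alloc(src_ty) + store
        const ptr: *const Dst = @ptrCast(&tmp); // OpBitcast on the pointer
        return ptr.*; // load(dst_ty)
    }

    test "u32 bits reinterpreted as f32" {
        try std.testing.expectEqual(@as(f32, 1.0), bitCastViaMemory(f32, @as(u32, 0x3f800000)));
    }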
fn airBitCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3004,34 +3373,43 @@ const DeclGen = struct {
const operand_id = try self.resolve(ty_op.operand);
const src_ty = self.typeOf(ty_op.operand);
const dst_ty = self.typeOfIndex(inst);
- const src_ty_ref = try self.resolveType(src_ty, .direct);
- const dst_ty_ref = try self.resolveType(dst_ty, .direct);
-
- const src_info = try self.arithmeticTypeInfo(src_ty);
- const dst_info = try self.arithmeticTypeInfo(dst_ty);
- // While intcast promises that the value already fits, the upper bits of a
- // strange integer may contain garbage. Therefore, mask/sign extend it before.
- const src_id = try self.normalizeInt(src_ty_ref, operand_id, src_info);
+ const src_info = self.arithmeticTypeInfo(src_ty);
+ const dst_info = self.arithmeticTypeInfo(dst_ty);
if (src_info.backing_bits == dst_info.backing_bits) {
- return src_id;
+ return operand_id;
}
- const result_id = self.spv.allocId();
- switch (dst_info.signedness) {
- .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{
- .id_result_type = self.typeId(dst_ty_ref),
- .id_result = result_id,
- .signed_value = src_id,
- }),
- .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
- .id_result_type = self.typeId(dst_ty_ref),
- .id_result = result_id,
- .unsigned_value = src_id,
- }),
+ var wip = try self.elementWise(dst_ty);
+ defer wip.deinit();
+ for (wip.results, 0..) |*result_id, i| {
+ const elem_id = try wip.elementAt(src_ty, operand_id, i);
+ const value_id = self.spv.allocId();
+ switch (dst_info.signedness) {
+ .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = value_id,
+ .signed_value = elem_id,
+ }),
+ .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = value_id,
+ .unsigned_value = elem_id,
+ }),
+ }
+
+ // Make sure to normalize the result if shrinking.
+ // Because strange ints are sign extended in their backing
+ // type, we don't need to normalize when growing the type. The
+ // representation is already the same.
+ if (dst_info.bits < src_info.bits) {
+ result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, dst_info);
+ } else {
+ result_id.* = value_id;
+ }
}
- return result_id;
+ return try wip.finalize();
}
fn intFromPtr(self: *DeclGen, operand_id: IdRef) !IdRef {
@@ -3059,7 +3437,7 @@ const DeclGen = struct {
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const operand_ty = self.typeOf(ty_op.operand);
const operand_id = try self.resolve(ty_op.operand);
- const operand_info = try self.arithmeticTypeInfo(operand_ty);
+ const operand_info = self.arithmeticTypeInfo(operand_ty);
const dest_ty = self.typeOfIndex(inst);
const dest_ty_id = try self.resolveTypeId(dest_ty);
@@ -3085,7 +3463,7 @@ const DeclGen = struct {
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const operand_id = try self.resolve(ty_op.operand);
const dest_ty = self.typeOfIndex(inst);
- const dest_info = try self.arithmeticTypeInfo(dest_ty);
+ const dest_info = self.arithmeticTypeInfo(dest_ty);
const dest_ty_id = try self.resolveTypeId(dest_ty);
const result_id = self.spv.allocId();
@@ -3104,6 +3482,22 @@ const DeclGen = struct {
return result_id;
}
+ fn airIntFromBool(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+
+ const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op;
+ const operand_id = try self.resolve(un_op);
+ const result_ty = self.typeOfIndex(inst);
+
+ var wip = try self.elementWise(result_ty);
+ defer wip.deinit();
+ for (wip.results, 0..) |*result_id, i| {
+ const elem_id = try wip.elementAt(Type.bool, operand_id, i);
+ result_id.* = try self.intFromBool(wip.scalar_ty_ref, elem_id);
+ }
+ return try wip.finalize();
+ }
+
fn airFloatCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
@@ -3126,31 +3520,31 @@ const DeclGen = struct {
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const operand_id = try self.resolve(ty_op.operand);
const result_ty = self.typeOfIndex(inst);
- const result_ty_id = try self.resolveTypeId(result_ty);
- const info = try self.arithmeticTypeInfo(result_ty);
+ const info = self.arithmeticTypeInfo(result_ty);
- const result_id = self.spv.allocId();
- switch (info.class) {
- .bool => {
- try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{
- .id_result_type = result_ty_id,
- .id_result = result_id,
- .operand = operand_id,
- });
- },
- .float => unreachable,
- .composite_integer => unreachable, // TODO
- .strange_integer, .integer => {
- // Note: strange integer bits will be masked before operations that do not hold under modulo.
- try self.func.body.emit(self.spv.gpa, .OpNot, .{
- .id_result_type = result_ty_id,
- .id_result = result_id,
- .operand = operand_id,
- });
- },
+ var wip = try self.elementWise(result_ty);
+ defer wip.deinit();
+
+ for (0..wip.results.len) |i| {
+ const args = .{
+ .id_result_type = wip.scalar_ty_id,
+ .id_result = wip.allocId(i),
+ .operand = try wip.elementAt(result_ty, operand_id, i),
+ };
+ switch (info.class) {
+ .bool => {
+ try self.func.body.emit(self.spv.gpa, .OpLogicalNot, args);
+ },
+ .float => unreachable,
+ .composite_integer => unreachable, // TODO
+ .strange_integer, .integer => {
+ // Note: strange integer bits will be masked before operations that do not hold under modulo.
+ try self.func.body.emit(self.spv.gpa, .OpNot, args);
+ },
+ }
}
- return result_id;
+ return try wip.finalize();
}
fn airArrayToSlice(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3213,7 +3607,6 @@ const DeclGen = struct {
const elements: []const Air.Inst.Ref = @ptrCast(self.air.extra[ty_pl.payload..][0..len]);
switch (result_ty.zigTypeTag(mod)) {
- .Vector => unreachable, // TODO
.Struct => {
if (mod.typeToPackedStruct(result_ty)) |struct_type| {
_ = struct_type;
@@ -3261,7 +3654,7 @@ const DeclGen = struct {
constituents[0..index],
);
},
- .Array => {
+ .Vector, .Array => {
const array_info = result_ty.arrayInfo(mod);
const n_elems: usize = @intCast(result_ty.arrayLenIncludingSentinel(mod));
const elem_ids = try self.gpa.alloc(IdRef, n_elems);
@@ -3433,6 +3826,28 @@ const DeclGen = struct {
return try self.load(elem_ty, elem_ptr_id, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
}
+ fn airVectorStoreElem(self: *DeclGen, inst: Air.Inst.Index) !void {
+ const mod = self.module;
+ const data = self.air.instructions.items(.data)[@intFromEnum(inst)].vector_store_elem;
+ const extra = self.air.extraData(Air.Bin, data.payload).data;
+
+ const vector_ptr_ty = self.typeOf(data.vector_ptr);
+ const vector_ty = vector_ptr_ty.childType(mod);
+ const scalar_ty = vector_ty.scalarType(mod);
+
+ const storage_class = spvStorageClass(vector_ptr_ty.ptrAddressSpace(mod));
+ const scalar_ptr_ty_ref = try self.ptrType(scalar_ty, storage_class);
+
+ const vector_ptr = try self.resolve(data.vector_ptr);
+ const index = try self.resolve(extra.lhs);
+ const operand = try self.resolve(extra.rhs);
+
+ const elem_ptr_id = try self.accessChainId(scalar_ptr_ty_ref, vector_ptr, &.{index});
+ try self.store(scalar_ty, elem_ptr_id, operand, .{
+ .is_volatile = vector_ptr_ty.isVolatilePtr(mod),
+ });
+ }
+
fn airSetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !void {
const mod = self.module;
const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
@@ -4424,20 +4839,24 @@ const DeclGen = struct {
return try self.constructStruct(err_union_ty, &types, &members);
}
- fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, pred: enum { is_null, is_non_null }) !?IdRef {
+ fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, is_pointer: bool, pred: enum { is_null, is_non_null }) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
const mod = self.module;
const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op;
const operand_id = try self.resolve(un_op);
- const optional_ty = self.typeOf(un_op);
-
+ const operand_ty = self.typeOf(un_op);
+ const optional_ty = if (is_pointer) operand_ty.childType(mod) else operand_ty;
const payload_ty = optional_ty.optionalChild(mod);
const bool_ty_ref = try self.resolveType(Type.bool, .direct);
if (optional_ty.optionalReprIsPayload(mod)) {
// Pointer payload represents nullability: pointer or slice.
+ const loaded_id = if (is_pointer)
+ try self.load(optional_ty, operand_id, .{})
+ else
+ operand_id;
const ptr_ty = if (payload_ty.isSlice(mod))
payload_ty.slicePtrFieldType(mod)
@@ -4445,9 +4864,9 @@ const DeclGen = struct {
payload_ty;
const ptr_id = if (payload_ty.isSlice(mod))
- try self.extractField(ptr_ty, operand_id, 0)
+ try self.extractField(ptr_ty, loaded_id, 0)
else
- operand_id;
+ loaded_id;
const payload_ty_ref = try self.resolveType(ptr_ty, .direct);
const null_id = try self.spv.constNull(payload_ty_ref);
@@ -4458,13 +4877,26 @@ const DeclGen = struct {
return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id);
}
- const is_non_null_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
- try self.extractField(Type.bool, operand_id, 1)
- else
- // Optional representation is bool indicating whether the optional is set
- // Optionals with no payload are represented as an (indirect) bool, so convert
- // it back to the direct bool here.
- try self.convertToDirect(Type.bool, operand_id);
+ const is_non_null_id = blk: {
+ if (is_pointer) {
+ if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+ const storage_class = spvStorageClass(operand_ty.ptrAddressSpace(mod));
+ const bool_ptr_ty = try self.ptrType(Type.bool, storage_class);
+ const tag_ptr_id = try self.accessChain(bool_ptr_ty, operand_id, &.{1});
+ break :blk try self.load(Type.bool, tag_ptr_id, .{});
+ }
+
+ break :blk try self.load(Type.bool, operand_id, .{});
+ }
+
+ break :blk if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
+ try self.extractField(Type.bool, operand_id, 1)
+ else
+ // Optional representation is bool indicating whether the optional is set
+ // Optionals with no payload are represented as an (indirect) bool, so convert
+ // it back to the direct bool here.
+ try self.convertToDirect(Type.bool, operand_id);
+ };
return switch (pred) {
.is_null => blk: {
@@ -4535,6 +4967,32 @@ const DeclGen = struct {
return try self.extractField(payload_ty, operand_id, 0);
}
+ fn airUnwrapOptionalPtr(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+
+ const mod = self.module;
+ const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+ const operand_id = try self.resolve(ty_op.operand);
+ const operand_ty = self.typeOf(ty_op.operand);
+ const optional_ty = operand_ty.childType(mod);
+ const payload_ty = optional_ty.optionalChild(mod);
+ const result_ty = self.typeOfIndex(inst);
+ const result_ty_ref = try self.resolveType(result_ty, .direct);
+
+ if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+ // There is no payload, but we still need to return a valid pointer.
+ // Any valid pointer will do here, so just return a pointer to the operand.
+ return try self.bitCast(result_ty, operand_ty, operand_id);
+ }
+
+ if (optional_ty.optionalReprIsPayload(mod)) {
+ // They are the same value.
+ return try self.bitCast(result_ty, operand_ty, operand_id);
+ }
+
+ return try self.accessChain(result_ty_ref, operand_id, &.{0});
+ }
+
fn airWrapOptional(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
if (self.liveness.isUnused(inst)) return null;