diff options
Diffstat (limited to 'src/codegen/spirv.zig')
| -rw-r--r-- | src/codegen/spirv.zig | 60 |
1 files changed, 39 insertions, 21 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 372a11c912..a119d93901 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -408,7 +408,7 @@ pub const DeclGen = struct { switch (repr) { .indirect => { const int_ty_ref = try self.intType(.unsigned, 1); - return self.spv.constInt(int_ty_ref, @intFromBool(value)); + return self.constInt(int_ty_ref, @intFromBool(value)); }, .direct => { const bool_ty_ref = try self.resolveType(Type.bool, .direct); @@ -417,6 +417,25 @@ pub const DeclGen = struct { } } + /// Emits an integer constant. + /// This function, unlike SpvModule.constInt, takes care to bitcast + /// the value to an unsigned int first for Kernels. + fn constInt(self: *DeclGen, ty_ref: CacheRef, value: anytype) !IdRef { + if (value < 0) { + const ty = self.spv.cache.lookup(ty_ref).int_type; + // Manually truncate the value so that the resulting value + // fits within the unsigned type. + const bits: u64 = @bitCast(@as(i64, @intCast(value))); + const truncated_bits = if (ty.bits == 64) + bits + else + bits & (@as(u64, 1) << @intCast(ty.bits)) - 1; + return try self.spv.constInt(ty_ref, truncated_bits); + } else { + return try self.spv.constInt(ty_ref, value); + } + } + /// Construct a struct at runtime. /// result_ty_ref must be a struct type. fn constructStruct(self: *DeclGen, result_ty_ref: CacheRef, constituents: []const IdRef) !IdRef { @@ -434,7 +453,7 @@ pub const DeclGen = struct { const member_types = spv_composite_ty.member_types; for (constituents, member_types, 0..) |constitent_id, member_ty_ref, index| { - const index_id = try self.spv.constInt(index_ty_ref, index); + const index_id = try self.constInt(index_ty_ref, index); const ptr_member_ty_ref = try self.spv.ptrType(member_ty_ref, .Generic); const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{index_id}); try self.func.body.emit(self.spv.gpa, .OpStore, .{ @@ -469,7 +488,7 @@ pub const DeclGen = struct { const ptr_elem_ty_ref = try self.spv.ptrType(elem_ty_ref, .Generic); for (constituents, 0..) |constitent_id, index| { - const index_id = try self.spv.constInt(index_ty_ref, index); + const index_id = try self.constInt(index_ty_ref, index); const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{index_id}); try self.func.body.emit(self.spv.gpa, .OpStore, .{ .pointer = ptr_id, @@ -580,17 +599,14 @@ pub const DeclGen = struct { .generic_poison, => unreachable, // non-runtime values - .false, .true => switch (repr) { - .direct => return try self.spv.constBool(result_ty_ref, val.toBool()), - .indirect => return try self.spv.constInt(result_ty_ref, @intFromBool(val.toBool())), - }, + .false, .true => return try self.constBool(val.toBool(), repr), }, .int => { if (ty.isSignedInt(mod)) { - return try self.spv.constInt(result_ty_ref, val.toSignedInt(mod)); + return try self.constInt(result_ty_ref, val.toSignedInt(mod)); } else { - return try self.spv.constInt(result_ty_ref, val.toUnsignedInt(mod)); + return try self.constInt(result_ty_ref, val.toUnsignedInt(mod)); } }, .float => return switch (ty.floatBits(target)) { @@ -602,7 +618,7 @@ pub const DeclGen = struct { }, .err => |err| { const value = try mod.getErrorValue(err.name); - return try self.spv.constInt(result_ty_ref, value); + return try self.constInt(result_ty_ref, value); }, .error_union => |error_union| { // TODO: Error unions may be constructed with constant instructions if the payload type @@ -716,7 +732,7 @@ pub const DeclGen = struct { // TODO: This is really space inefficient, perhaps there is a better // way to do it? for (bytes, 0..) |byte, i| { - constituents[i] = try self.spv.constInt(elem_ty_ref, byte); + constituents[i] = try self.constInt(elem_ty_ref, byte); } }, .elems => |elems| { @@ -794,7 +810,7 @@ pub const DeclGen = struct { const index_ty_ref = try self.intType(.unsigned, 32); if (layout.tag_size != 0) { - const index_id = try self.spv.constInt(index_ty_ref, @as(u32, @intCast(layout.tag_index))); + const index_id = try self.constInt(index_ty_ref, @as(u32, @intCast(layout.tag_index))); const tag_ty = ty.unionTagTypeSafety(mod).?; const tag_ty_ref = try self.resolveType(tag_ty, .indirect); const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function); @@ -807,7 +823,7 @@ pub const DeclGen = struct { } if (layout.active_field_size != 0) { - const index_id = try self.spv.constInt(index_ty_ref, @as(u32, @intCast(layout.active_field_index))); + const index_id = try self.constInt(index_ty_ref, @as(u32, @intCast(layout.active_field_index))); const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect); const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function); const ptr_id = try self.accessChain(active_field_ptr_ty_ref, var_id, &.{index_id}); @@ -870,7 +886,9 @@ pub const DeclGen = struct { // An array of largestSupportedIntBits. return self.todo("Implement {s} composite int type of {} bits", .{ @tagName(signedness), bits }); }; - return self.spv.intType(signedness, backing_bits); + // Kernel only supports unsigned ints. + // TODO: Only do this with Kernels + return self.spv.intType(.unsigned, backing_bits); } /// Create an integer type that represents 'usize'. @@ -1568,8 +1586,8 @@ pub const DeclGen = struct { } fn intFromBool(self: *DeclGen, result_ty_ref: CacheRef, condition_id: IdRef) !IdRef { - const zero_id = try self.spv.constInt(result_ty_ref, 0); - const one_id = try self.spv.constInt(result_ty_ref, 1); + const zero_id = try self.constInt(result_ty_ref, 0); + const one_id = try self.constInt(result_ty_ref, 1); const result_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpSelect, .{ .id_result_type = self.typeId(result_ty_ref), @@ -1589,7 +1607,7 @@ pub const DeclGen = struct { .Bool => blk: { const direct_bool_ty_ref = try self.resolveType(ty, .direct); const indirect_bool_ty_ref = try self.resolveType(ty, .indirect); - const zero_id = try self.spv.constInt(indirect_bool_ty_ref, 0); + const zero_id = try self.constInt(indirect_bool_ty_ref, 0); const result_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ .id_result_type = self.typeId(direct_bool_ty_ref), @@ -1832,7 +1850,7 @@ pub const DeclGen = struct { fn maskStrangeInt(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, bits: u16) !IdRef { const mask_value = if (bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(bits))) - 1; const result_id = self.spv.allocId(); - const mask_id = try self.spv.constInt(ty_ref, mask_value); + const mask_id = try self.constInt(ty_ref, mask_value); try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id, @@ -1971,7 +1989,7 @@ pub const DeclGen = struct { // Note that signed overflow is also wrapping in spir-v. const rhs_lt_zero_id = self.spv.allocId(); - const zero_id = try self.spv.constInt(operand_ty_ref, 0); + const zero_id = try self.constInt(operand_ty_ref, 0); try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{ .id_result_type = self.typeId(bool_ty_ref), .id_result = rhs_lt_zero_id, @@ -2540,7 +2558,7 @@ pub const DeclGen = struct { .Packed => unreachable, // TODO else => { const field_index_ty_ref = try self.intType(.unsigned, 32); - const field_index_id = try self.spv.constInt(field_index_ty_ref, field_index); + const field_index_id = try self.constInt(field_index_ty_ref, field_index); const result_ty_ref = try self.resolveType(result_ptr_ty, .direct); return try self.accessChain(result_ty_ref, object_ptr, &.{field_index_id}); }, @@ -2822,7 +2840,7 @@ pub const DeclGen = struct { else err_union_id; - const zero_id = try self.spv.constInt(err_ty_ref, 0); + const zero_id = try self.constInt(err_ty_ref, 0); const is_err_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ .id_result_type = self.typeId(bool_ty_ref), |
