aboutsummaryrefslogtreecommitdiff
path: root/src/codegen/spirv.zig
diff options
context:
space:
mode:
authorRobin Voetter <robin@voetter.nl>2023-10-22 15:35:00 +0200
committerGitHub <noreply@github.com>2023-10-22 15:35:00 +0200
commitb822e841cda0adabe3fec260ff51c18508f7ee32 (patch)
treefe1bdf51d000cddcf4b42a20f3c69111c16651d8 /src/codegen/spirv.zig
parent0c99ba1eab63865592bb084feb271cd4e4b0357e (diff)
parent6281ad91dfc0d799bfabced68009dfb4971545d7 (diff)
downloadzig-b822e841cda0adabe3fec260ff51c18508f7ee32.tar.gz
zig-b822e841cda0adabe3fec260ff51c18508f7ee32.zip
Merge pull request #17657 from Snektron/spirv-recursive-ptrs
spirv: recursive pointers
Diffstat (limited to 'src/codegen/spirv.zig')
-rw-r--r--src/codegen/spirv.zig601
1 files changed, 292 insertions, 309 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index e7e6ebdef3..0b2fc2d037 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -209,6 +209,10 @@ const DeclGen = struct {
/// See Object.type_map
type_map: *TypeMap,
+ /// Child types of pointers that are currently in progress of being resolved. If a pointer
+ /// is already in this map, its recursive.
+ wip_pointers: std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, CacheRef) = .{},
+
/// We need to keep track of result ids for block labels, as well as the 'incoming'
/// blocks for a block.
blocks: BlockMap = .{},
@@ -295,6 +299,7 @@ const DeclGen = struct {
pub fn deinit(self: *DeclGen) void {
self.args.deinit(self.gpa);
self.inst_results.deinit(self.gpa);
+ self.wip_pointers.deinit(self.gpa);
self.blocks.deinit(self.gpa);
self.func.deinit(self.gpa);
self.base_line_stack.deinit(self.gpa);
@@ -358,8 +363,7 @@ const DeclGen = struct {
const mod = self.module;
const ty = mod.intern_pool.typeOf(val).toType();
- const ty_ref = try self.resolveType(ty, .indirect);
- const ptr_ty_ref = try self.spv.ptrType(ty_ref, storage_class);
+ const ptr_ty_ref = try self.ptrType(ty, storage_class);
const var_id = self.spv.declPtr(spv_decl_index).result_id;
@@ -582,66 +586,41 @@ const DeclGen = struct {
}
/// Construct a struct at runtime.
- /// result_ty_ref must be a struct type.
+ /// ty must be a struct type.
/// Constituents should be in `indirect` representation (as the elements of a struct should be).
/// Result is in `direct` representation.
- fn constructStruct(self: *DeclGen, result_ty_ref: CacheRef, constituents: []const IdRef) !IdRef {
+ fn constructStruct(self: *DeclGen, ty: Type, types: []const Type, constituents: []const IdRef) !IdRef {
+ assert(types.len == constituents.len);
// The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
// operands are not constant.
// See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
// For now, just initialize the struct by setting the fields manually...
// TODO: Make this OpCompositeConstruct when we can
- const ptr_ty_ref = try self.spv.ptrType(result_ty_ref, .Function);
- const ptr_composite_id = self.spv.allocId();
- try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
- .id_result_type = self.typeId(ptr_ty_ref),
- .id_result = ptr_composite_id,
- .storage_class = .Function,
- });
-
- const spv_composite_ty = self.spv.cache.lookup(result_ty_ref).struct_type;
- const member_types = spv_composite_ty.member_types;
-
- for (constituents, member_types, 0..) |constitent_id, member_ty_ref, index| {
- const ptr_member_ty_ref = try self.spv.ptrType(member_ty_ref, .Function);
+ const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
+ for (constituents, types, 0..) |constitent_id, member_ty, index| {
+ const ptr_member_ty_ref = try self.ptrType(member_ty, .Function);
const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
try self.func.body.emit(self.spv.gpa, .OpStore, .{
.pointer = ptr_id,
.object = constitent_id,
});
}
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpLoad, .{
- .id_result_type = self.typeId(result_ty_ref),
- .id_result = result_id,
- .pointer = ptr_composite_id,
- });
- return result_id;
+ return try self.load(ty, ptr_composite_id, .{});
}
/// Construct an array at runtime.
- /// result_ty_ref must be an array type.
+ /// ty must be an array type.
/// Constituents should be in `indirect` representation (as the elements of an array should be).
/// Result is in `direct` representation.
- fn constructArray(self: *DeclGen, result_ty_ref: CacheRef, constituents: []const IdRef) !IdRef {
+ fn constructArray(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
// The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
// operands are not constant.
// See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
// For now, just initialize the struct by setting the fields manually...
// TODO: Make this OpCompositeConstruct when we can
- // TODO: Make this Function storage type
- const ptr_ty_ref = try self.spv.ptrType(result_ty_ref, .Function);
- const ptr_composite_id = self.spv.allocId();
- try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
- .id_result_type = self.typeId(ptr_ty_ref),
- .id_result = ptr_composite_id,
- .storage_class = .Function,
- });
-
- const spv_composite_ty = self.spv.cache.lookup(result_ty_ref).array_type;
- const elem_ty_ref = spv_composite_ty.element_type;
- const ptr_elem_ty_ref = try self.spv.ptrType(elem_ty_ref, .Function);
-
+ const mod = self.module;
+ const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
+ const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function);
for (constituents, 0..) |constitent_id, index| {
const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
try self.func.body.emit(self.spv.gpa, .OpStore, .{
@@ -649,13 +628,8 @@ const DeclGen = struct {
.object = constitent_id,
});
}
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpLoad, .{
- .id_result_type = self.typeId(result_ty_ref),
- .id_result = result_id,
- .pointer = ptr_composite_id,
- });
- return result_id;
+
+ return try self.load(ty, ptr_composite_id, .{});
}
/// This function generates a load for a constant in direct (ie, non-memory) representation.
@@ -766,15 +740,18 @@ const DeclGen = struct {
}.toValue();
var constituents: [2]IdRef = undefined;
+ var types: [2]Type = undefined;
if (eu_layout.error_first) {
constituents[0] = try self.constant(err_ty, err_val, .indirect);
constituents[1] = try self.constant(payload_ty, payload_val, .indirect);
+ types = .{ err_ty, payload_ty };
} else {
constituents[0] = try self.constant(payload_ty, payload_val, .indirect);
constituents[1] = try self.constant(err_ty, err_val, .indirect);
+ types = .{ payload_ty, err_ty };
}
- return try self.constructStruct(result_ty_ref, &constituents);
+ return try self.constructStruct(ty, &types, &constituents);
},
.enum_tag => {
const int_val = try val.intFromEnum(ty, mod);
@@ -792,7 +769,11 @@ const DeclGen = struct {
}
const len_id = try self.constant(Type.usize, ptr.len.toValue(), .indirect);
- return try self.constructStruct(result_ty_ref, &.{ ptr_id, len_id });
+ return try self.constructStruct(
+ ty,
+ &.{ ptr_ty, Type.usize },
+ &.{ ptr_id, len_id },
+ );
},
.opt => {
const payload_ty = ty.optionalChild(mod);
@@ -819,7 +800,11 @@ const DeclGen = struct {
else
try self.spv.constUndef(try self.resolveType(payload_ty, .indirect));
- return try self.constructStruct(result_ty_ref, &.{ payload_id, has_pl_id });
+ return try self.constructStruct(
+ ty,
+ &.{ payload_ty, Type.bool },
+ &.{ payload_id, has_pl_id },
+ );
},
.aggregate => |aggregate| switch (ip.indexToKey(ty.ip_index)) {
inline .array_type, .vector_type => |array_type, tag| {
@@ -857,7 +842,7 @@ const DeclGen = struct {
else => {},
}
- return try self.constructArray(result_ty_ref, constituents);
+ return try self.constructArray(ty, constituents);
},
.struct_type => {
const struct_type = mod.typeToStruct(ty).?;
@@ -865,6 +850,9 @@ const DeclGen = struct {
return self.todo("packed struct constants", .{});
}
+ var types = std.ArrayList(Type).init(self.gpa);
+ defer types.deinit();
+
var constituents = std.ArrayList(IdRef).init(self.gpa);
defer constituents.deinit();
@@ -880,22 +868,23 @@ const DeclGen = struct {
const field_val = try val.fieldValue(mod, field_index);
const field_id = try self.constant(field_ty, field_val, .indirect);
+ try types.append(field_ty);
try constituents.append(field_id);
}
- return try self.constructStruct(result_ty_ref, constituents.items);
+ return try self.constructStruct(ty, types.items, constituents.items);
},
.anon_struct_type => unreachable, // TODO
else => unreachable,
},
.un => |un| {
const active_field = ty.unionTagFieldIndex(un.tag.toValue(), mod).?;
- const layout = self.unionLayout(ty, active_field);
- const payload = if (layout.active_field_size != 0)
- try self.constant(layout.active_field_ty, un.val.toValue(), .indirect)
+ const union_obj = mod.typeToUnion(ty).?;
+ const field_ty = union_obj.field_types.get(ip)[active_field].toType();
+ const payload = if (field_ty.hasRuntimeBitsIgnoreComptime(mod))
+ try self.constant(field_ty, un.val.toValue(), .direct)
else
null;
-
return try self.unionInit(ty, active_field, payload);
},
.memoized_call => unreachable,
@@ -934,8 +923,7 @@ const DeclGen = struct {
// TODO: Can we consolidate this in ptrElemPtr?
const elem_ty = parent_ptr_ty.elemType2(mod); // use elemType() so that we get T for *[N]T.
- const elem_ty_ref = try self.resolveType(elem_ty, .direct);
- const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, spvStorageClass(parent_ptr_ty.ptrAddressSpace(mod)));
+ const elem_ptr_ty_ref = try self.ptrType(elem_ty, spvStorageClass(parent_ptr_ty.ptrAddressSpace(mod)));
if (elem_ptr_ty_ref == result_ty_ref) {
return elem_ptr_id;
@@ -997,8 +985,7 @@ const DeclGen = struct {
};
const decl_id = try self.resolveAnonDecl(decl_val, actual_storage_class);
- const decl_ty_ref = try self.resolveType(decl_ty, .indirect);
- const decl_ptr_ty_ref = try self.spv.ptrType(decl_ty_ref, final_storage_class);
+ const decl_ptr_ty_ref = try self.ptrType(decl_ty, final_storage_class);
const ptr_id = switch (final_storage_class) {
.Generic => blk: {
@@ -1054,8 +1041,7 @@ const DeclGen = struct {
const final_storage_class = spvStorageClass(decl.@"addrspace");
- const decl_ty_ref = try self.resolveType(decl.ty, .indirect);
- const decl_ptr_ty_ref = try self.spv.ptrType(decl_ty_ref, final_storage_class);
+ const decl_ptr_ty_ref = try self.ptrType(decl.ty, final_storage_class);
const ptr_id = switch (final_storage_class) {
.Generic => blk: {
@@ -1123,29 +1109,52 @@ const DeclGen = struct {
return try self.intType(.unsigned, self.getTarget().ptrBitWidth());
}
- /// Generate a union type, optionally with a known field. If the tag alignment is greater
- /// than that of the payload, a regular union (non-packed, with both tag and payload), will
- /// be generated as follows:
- /// If the active field is known:
+ fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !CacheRef {
+ const key = .{ child_ty.toIntern(), storage_class };
+ const entry = try self.wip_pointers.getOrPut(self.gpa, key);
+ if (entry.found_existing) {
+ const fwd_ref = entry.value_ptr.*;
+ try self.spv.cache.recursive_ptrs.put(self.spv.gpa, fwd_ref, {});
+ return fwd_ref;
+ }
+
+ const fwd_ref = try self.spv.resolve(.{ .fwd_ptr_type = .{
+ .zig_child_type = child_ty.toIntern(),
+ .storage_class = storage_class,
+ } });
+ entry.value_ptr.* = fwd_ref;
+
+ const child_ty_ref = try self.resolveType(child_ty, .indirect);
+ _ = try self.spv.resolve(.{ .ptr_type = .{
+ .storage_class = storage_class,
+ .child_type = child_ty_ref,
+ .fwd = fwd_ref,
+ } });
+
+ assert(self.wip_pointers.remove(key));
+
+ return fwd_ref;
+ }
+
+ /// Generate a union type. Union types are always generated with the
+ /// most aligned field active. If the tag alignment is greater
+ /// than that of the payload, a regular union (non-packed, with both tag and
+ /// payload), will be generated as follows:
/// struct {
/// tag: TagType,
- /// payload: ActivePayloadType,
- /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8,
+ /// payload: MostAlignedFieldType,
+ /// payload_padding: [payload_size - @sizeOf(MostAlignedFieldType)]u8,
/// padding: [padding_size]u8,
/// }
/// If the payload alignment is greater than that of the tag:
/// struct {
- /// payload: ActivePayloadType,
- /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8,
+ /// payload: MostAlignedFieldType,
+ /// payload_padding: [payload_size - @sizeOf(MostAlignedFieldType)]u8,
/// tag: TagType,
/// padding: [padding_size]u8,
/// }
- /// If the active payload is unknown, it will default back to the most aligned field. This is
- /// to make sure that the overal struct has the correct alignment in spir-v.
/// If any of the fields' size is 0, it will be omitted.
- /// NOTE: When the active field is set to something other than the most aligned field, the
- /// resulting struct will be *underaligned*.
- fn resolveUnionType(self: *DeclGen, ty: Type, maybe_active_field: ?usize) !CacheRef {
+ fn resolveUnionType(self: *DeclGen, ty: Type) !CacheRef {
const mod = self.module;
const ip = &mod.intern_pool;
const union_obj = mod.typeToUnion(ty).?;
@@ -1154,17 +1163,13 @@ const DeclGen = struct {
return self.todo("packed union types", .{});
}
- const layout = self.unionLayout(ty, maybe_active_field);
-
- if (layout.payload_size == 0) {
+ const layout = self.unionLayout(ty);
+ if (!layout.has_payload) {
// No payload, so represent this as just the tag type.
return try self.resolveType(union_obj.enum_tag_ty.toType(), .indirect);
}
- // TODO: We need to add the active field to the key, somehow.
- if (maybe_active_field == null) {
- if (self.type_map.get(ty.toIntern())) |info| return info.ty_ref;
- }
+ if (self.type_map.get(ty.toIntern())) |info| return info.ty_ref;
var member_types: [4]CacheRef = undefined;
var member_names: [4]CacheString = undefined;
@@ -1177,10 +1182,10 @@ const DeclGen = struct {
member_names[layout.tag_index] = try self.spv.resolveString("(tag)");
}
- if (layout.active_field_size != 0) {
- const active_payload_ty_ref = try self.resolveType(layout.active_field_ty, .indirect);
- member_types[layout.active_field_index] = active_payload_ty_ref;
- member_names[layout.active_field_index] = try self.spv.resolveString("(payload)");
+ if (layout.payload_size != 0) {
+ const payload_ty_ref = try self.resolveType(layout.payload_ty, .indirect);
+ member_types[layout.payload_index] = payload_ty_ref;
+ member_names[layout.payload_index] = try self.spv.resolveString("(payload)");
}
if (layout.payload_padding_size != 0) {
@@ -1201,9 +1206,7 @@ const DeclGen = struct {
.member_names = member_names[0..layout.total_fields],
} });
- if (maybe_active_field == null) {
- try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref });
- }
+ try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref });
return ty_ref;
}
@@ -1351,12 +1354,12 @@ const DeclGen = struct {
.Pointer => {
const ptr_info = ty.ptrInfo(mod);
+ // Note: Don't cache this pointer type, it would mess up the recursive pointer functionality
+ // in ptrType()!
+
const storage_class = spvStorageClass(ptr_info.flags.address_space);
- const child_ty_ref = try self.resolveType(ptr_info.child.toType(), .indirect);
- const ptr_ty_ref = try self.spv.resolve(.{ .ptr_type = .{
- .storage_class = storage_class,
- .child_type = child_ty_ref,
- } });
+ const ptr_ty_ref = try self.ptrType(ptr_info.child.toType(), storage_class);
+
if (ptr_info.flags.size != .Slice) {
return ptr_ty_ref;
}
@@ -1471,7 +1474,7 @@ const DeclGen = struct {
try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref });
return ty_ref;
},
- .Union => return try self.resolveUnionType(ty, null),
+ .Union => return try self.resolveUnionType(ty),
.ErrorSet => return try self.intType(.unsigned, 16),
.ErrorUnion => {
const payload_ty = ty.errorUnionPayload(mod);
@@ -1585,14 +1588,16 @@ const DeclGen = struct {
}
const UnionLayout = struct {
- active_field: u32,
- active_field_ty: Type,
- payload_size: u32,
-
+ /// If false, this union is represented
+ /// by only an integer of the tag type.
+ has_payload: bool,
tag_size: u32,
tag_index: u32,
- active_field_size: u32,
- active_field_index: u32,
+ /// Note: This is the size of the payload type itself, NOT the size of the ENTIRE payload.
+ /// Use `has_payload` instead!!
+ payload_ty: Type,
+ payload_size: u32,
+ payload_index: u32,
payload_padding_size: u32,
payload_padding_index: u32,
padding_size: u32,
@@ -1600,23 +1605,19 @@ const DeclGen = struct {
total_fields: u32,
};
- fn unionLayout(self: *DeclGen, ty: Type, maybe_active_field: ?usize) UnionLayout {
+ fn unionLayout(self: *DeclGen, ty: Type) UnionLayout {
const mod = self.module;
const ip = &mod.intern_pool;
const layout = ty.unionGetLayout(self.module);
const union_obj = mod.typeToUnion(ty).?;
- const active_field = maybe_active_field orelse layout.most_aligned_field;
- const active_field_ty = union_obj.field_types.get(ip)[active_field].toType();
-
var union_layout = UnionLayout{
- .active_field = @intCast(active_field),
- .active_field_ty = active_field_ty,
- .payload_size = @intCast(layout.payload_size),
+ .has_payload = layout.payload_size != 0,
.tag_size = @intCast(layout.tag_size),
.tag_index = undefined,
- .active_field_size = undefined,
- .active_field_index = undefined,
+ .payload_ty = undefined,
+ .payload_size = undefined,
+ .payload_index = undefined,
.payload_padding_size = undefined,
.payload_padding_index = undefined,
.padding_size = @intCast(layout.padding),
@@ -1624,11 +1625,16 @@ const DeclGen = struct {
.total_fields = undefined,
};
- union_layout.active_field_size = if (active_field_ty.hasRuntimeBitsIgnoreComptime(mod))
- @intCast(active_field_ty.abiSize(mod))
- else
- 0;
- union_layout.payload_padding_size = @intCast(layout.payload_size - union_layout.active_field_size);
+ if (union_layout.has_payload) {
+ const most_aligned_field = layout.most_aligned_field;
+ const most_aligned_field_ty = union_obj.field_types.get(ip)[most_aligned_field].toType();
+ union_layout.payload_ty = most_aligned_field_ty;
+ union_layout.payload_size = @intCast(most_aligned_field_ty.abiSize(mod));
+ } else {
+ union_layout.payload_size = 0;
+ }
+
+ union_layout.payload_padding_size = @intCast(layout.payload_size - union_layout.payload_size);
const tag_first = layout.tag_align.compare(.gte, layout.payload_align);
var field_index: u32 = 0;
@@ -1638,8 +1644,8 @@ const DeclGen = struct {
field_index += 1;
}
- if (union_layout.active_field_size != 0) {
- union_layout.active_field_index = field_index;
+ if (union_layout.payload_size != 0) {
+ union_layout.payload_index = field_index;
field_index += 1;
}
@@ -1683,7 +1689,7 @@ const DeclGen = struct {
/// the name of an error in the text executor.
fn generateTestEntryPoint(self: *DeclGen, name: []const u8, spv_test_decl_index: SpvModule.Decl.Index) !void {
const anyerror_ty_ref = try self.resolveType(Type.anyerror, .direct);
- const ptr_anyerror_ty_ref = try self.spv.ptrType(anyerror_ty_ref, .CrossWorkgroup);
+ const ptr_anyerror_ty_ref = try self.ptrType(Type.anyerror, .CrossWorkgroup);
const void_ty_ref = try self.resolveType(Type.void, .direct);
const kernel_proto_ty_ref = try self.spv.resolve(.{ .function_type = .{
@@ -1718,6 +1724,7 @@ const DeclGen = struct {
.id_result = error_id,
.function = test_id,
});
+ // Note: Convert to direct not required.
try section.emit(self.spv.gpa, .OpStore, .{
.pointer = p_error_id,
.object = error_id,
@@ -1822,8 +1829,7 @@ const DeclGen = struct {
else => final_storage_class,
};
- const ty_ref = try self.resolveType(decl.ty, .indirect);
- const ptr_ty_ref = try self.spv.ptrType(ty_ref, actual_storage_class);
+ const ptr_ty_ref = try self.ptrType(decl.ty, actual_storage_class);
const begin = self.spv.beginGlobal();
try self.spv.globals.section.emit(self.spv.gpa, .OpVariable, .{
@@ -1928,11 +1934,15 @@ const DeclGen = struct {
return try self.convertToDirect(result_ty, result_id);
}
- fn load(self: *DeclGen, value_ty: Type, ptr_id: IdRef, is_volatile: bool) !IdRef {
+ const MemoryOptions = struct {
+ is_volatile: bool = false,
+ };
+
+ fn load(self: *DeclGen, value_ty: Type, ptr_id: IdRef, options: MemoryOptions) !IdRef {
const indirect_value_ty_ref = try self.resolveType(value_ty, .indirect);
const result_id = self.spv.allocId();
const access = spec.MemoryAccess.Extended{
- .Volatile = is_volatile,
+ .Volatile = options.is_volatile,
};
try self.func.body.emit(self.spv.gpa, .OpLoad, .{
.id_result_type = self.typeId(indirect_value_ty_ref),
@@ -1943,10 +1953,10 @@ const DeclGen = struct {
return try self.convertToDirect(value_ty, result_id);
}
- fn store(self: *DeclGen, value_ty: Type, ptr_id: IdRef, value_id: IdRef, is_volatile: bool) !void {
+ fn store(self: *DeclGen, value_ty: Type, ptr_id: IdRef, value_id: IdRef, options: MemoryOptions) !void {
const indirect_value_id = try self.convertToIndirect(value_ty, value_id);
const access = spec.MemoryAccess.Extended{
- .Volatile = is_volatile,
+ .Volatile = options.is_volatile,
};
try self.func.body.emit(self.spv.gpa, .OpStore, .{
.pointer = ptr_id,
@@ -2118,9 +2128,7 @@ const DeclGen = struct {
constituent.* = try self.convertToIndirect(child_ty, result_id);
}
- const result_ty = try self.resolveType(child_ty, .indirect);
- const result_ty_ref = try self.spv.arrayType(vector_len, result_ty);
- return try self.constructArray(result_ty_ref, constituents);
+ return try self.constructArray(ty, constituents);
}
const result_id = self.spv.allocId();
@@ -2181,7 +2189,7 @@ const DeclGen = struct {
const info = try self.arithmeticTypeInfo(result_ty);
// TODO: Use fmin for OpenCL
- const cmp_id = try self.cmp(op, result_ty, lhs_id, rhs_id);
+ const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
const selection_id = switch (info.class) {
.float => blk: {
// cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
@@ -2316,7 +2324,7 @@ const DeclGen = struct {
constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular);
}
- return self.constructArray(result_ty_ref, constituents);
+ return self.constructArray(ty, constituents);
}
// Binary operations are generally applicable to both scalar and vector operations
@@ -2472,11 +2480,11 @@ const DeclGen = struct {
// Construct the struct that Zig wants as result.
// The value should already be the correct type.
const ov_id = try self.intFromBool(ov_ty_ref, overflowed_id);
- const result_ty_ref = try self.resolveType(result_ty, .direct);
- return try self.constructStruct(result_ty_ref, &.{
- value_id,
- ov_id,
- });
+ return try self.constructStruct(
+ result_ty,
+ &.{ operand_ty, ov_ty },
+ &.{ value_id, ov_id },
+ );
}
fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2634,6 +2642,7 @@ const DeclGen = struct {
fn cmp(
self: *DeclGen,
op: std.math.CompareOperator,
+ result_ty: Type,
ty: Type,
lhs_id: IdRef,
rhs_id: IdRef,
@@ -2674,7 +2683,7 @@ const DeclGen = struct {
if (ty.optionalReprIsPayload(mod)) {
assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod));
assert(!payload_ty.isSlice(mod));
- return self.cmp(op, payload_ty, lhs_id, rhs_id);
+ return self.cmp(op, Type.bool, payload_ty, lhs_id, rhs_id);
}
const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
@@ -2687,7 +2696,7 @@ const DeclGen = struct {
else
try self.convertToDirect(Type.bool, rhs_id);
- const valid_cmp_id = try self.cmp(op, Type.bool, lhs_valid_id, rhs_valid_id);
+ const valid_cmp_id = try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
return valid_cmp_id;
}
@@ -2698,7 +2707,7 @@ const DeclGen = struct {
const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0);
const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0);
- const pl_cmp_id = try self.cmp(op, payload_ty, lhs_pl_id, rhs_pl_id);
+ const pl_cmp_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
// op == .eq => lhs_valid == rhs_valid && lhs_pl == rhs_pl
// op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl
@@ -2720,7 +2729,6 @@ const DeclGen = struct {
.Vector => {
const child_ty = ty.childType(mod);
const vector_len = ty.vectorLen(mod);
- const bool_ty_ref_indirect = try self.resolveType(Type.bool, .indirect);
var constituents = try self.gpa.alloc(IdRef, vector_len);
defer self.gpa.free(constituents);
@@ -2728,12 +2736,11 @@ const DeclGen = struct {
for (constituents, 0..) |*constituent, i| {
const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i));
const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i));
- const result_id = try self.cmp(op, child_ty, lhs_index_id, rhs_index_id);
+ const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id);
constituent.* = try self.convertToIndirect(Type.bool, result_id);
}
- const result_ty_ref = try self.spv.arrayType(vector_len, bool_ty_ref_indirect);
- return try self.constructArray(result_ty_ref, constituents);
+ return try self.constructArray(result_ty, constituents);
},
else => unreachable,
};
@@ -2806,8 +2813,9 @@ const DeclGen = struct {
const lhs_id = try self.resolve(bin_op.lhs);
const rhs_id = try self.resolve(bin_op.rhs);
const ty = self.typeOf(bin_op.lhs);
+ const result_ty = self.typeOfIndex(inst);
- return try self.cmp(op, ty, lhs_id, rhs_id);
+ return try self.cmp(op, result_ty, ty, lhs_id, rhs_id);
}
fn airVectorCmp(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2819,8 +2827,9 @@ const DeclGen = struct {
const rhs_id = try self.resolve(vec_cmp.rhs);
const op = vec_cmp.compareOperator();
const ty = self.typeOf(vec_cmp.lhs);
+ const result_ty = self.typeOfIndex(inst);
- return try self.cmp(op, ty, lhs_id, rhs_id);
+ return try self.cmp(op, result_ty, ty, lhs_id, rhs_id);
}
fn bitCast(
@@ -2865,23 +2874,17 @@ const DeclGen = struct {
return result_id;
}
- const src_ptr_ty_ref = try self.spv.ptrType(src_ty_ref, .Function);
- const dst_ptr_ty_ref = try self.spv.ptrType(dst_ty_ref, .Function);
+ const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
- const tmp_id = self.spv.allocId();
- try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
- .id_result_type = self.typeId(src_ptr_ty_ref),
- .id_result = tmp_id,
- .storage_class = .Function,
- });
- try self.store(src_ty, tmp_id, src_id, false);
+ const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
+ try self.store(src_ty, tmp_id, src_id, .{});
const casted_ptr_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
.id_result_type = self.typeId(dst_ptr_ty_ref),
.id_result = casted_ptr_id,
.operand = tmp_id,
});
- return try self.load(dst_ty, casted_ptr_id, false);
+ return try self.load(dst_ty, casted_ptr_id, .{});
}
fn airBitCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3060,7 +3063,6 @@ const DeclGen = struct {
const elem_ptr_ty = slice_ty.slicePtrFieldType(mod);
const elem_ptr_ty_ref = try self.resolveType(elem_ptr_ty, .direct);
- const slice_ty_ref = try self.resolveType(slice_ty, .direct);
const size_ty_ref = try self.sizeType();
const array_ptr_id = try self.resolve(ty_op.operand);
@@ -3073,7 +3075,11 @@ const DeclGen = struct {
// Convert the pointer-to-array to a pointer to the first element.
try self.accessChain(elem_ptr_ty_ref, array_ptr_id, &.{0});
- return try self.constructStruct(slice_ty_ref, &.{ elem_ptr_id, len_id });
+ return try self.constructStruct(
+ slice_ty,
+ &.{ elem_ptr_ty, Type.usize },
+ &.{ elem_ptr_id, len_id },
+ );
}
fn airSlice(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3083,13 +3089,16 @@ const DeclGen = struct {
const bin_op = self.air.extraData(Air.Bin, ty_pl.payload).data;
const ptr_id = try self.resolve(bin_op.lhs);
const len_id = try self.resolve(bin_op.rhs);
+ const ptr_ty = self.typeOf(bin_op.lhs);
const slice_ty = self.typeOfIndex(inst);
- const slice_ty_ref = try self.resolveType(slice_ty, .direct);
- return try self.constructStruct(slice_ty_ref, &.{
- ptr_id, // Note: Type should not need to be converted to direct.
- len_id, // Note: Type should not need to be converted to direct.
- });
+ // Note: Types should not need to be converted to direct, these types
+ // dont need to be converted.
+ return try self.constructStruct(
+ slice_ty,
+ &.{ ptr_ty, Type.usize },
+ &.{ ptr_id, len_id },
+ );
}
fn airAggregateInit(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3099,7 +3108,6 @@ const DeclGen = struct {
const ip = &mod.intern_pool;
const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
const result_ty = self.typeOfIndex(inst);
- const result_ty_ref = try self.resolveType(result_ty, .direct);
const len: usize = @intCast(result_ty.arrayLen(mod));
const elements: []const Air.Inst.Ref = @ptrCast(self.air.extra[ty_pl.payload..][0..len]);
@@ -3111,6 +3119,8 @@ const DeclGen = struct {
unreachable; // TODO
}
+ const types = try self.gpa.alloc(Type, elements.len);
+ defer self.gpa.free(types);
const constituents = try self.gpa.alloc(IdRef, elements.len);
defer self.gpa.free(constituents);
var index: usize = 0;
@@ -3122,6 +3132,7 @@ const DeclGen = struct {
assert(field_ty.toType().hasRuntimeBits(mod));
const id = try self.resolve(element);
+ types[index] = field_ty.toType();
constituents[index] = try self.convertToIndirect(field_ty.toType(), id);
index += 1;
}
@@ -3135,6 +3146,7 @@ const DeclGen = struct {
assert(field_ty.hasRuntimeBitsIgnoreComptime(mod));
const id = try self.resolve(element);
+ types[index] = field_ty;
constituents[index] = try self.convertToIndirect(field_ty, id);
index += 1;
}
@@ -3142,7 +3154,11 @@ const DeclGen = struct {
else => unreachable,
}
- return try self.constructStruct(result_ty_ref, constituents[0..index]);
+ return try self.constructStruct(
+ result_ty,
+ types[0..index],
+ constituents[0..index],
+ );
},
.Array => {
const array_info = result_ty.arrayInfo(mod);
@@ -3159,7 +3175,7 @@ const DeclGen = struct {
elem_ids[n_elems - 1] = try self.constant(array_info.elem_type, sentinel_val, .indirect);
}
- return try self.constructArray(result_ty_ref, elem_ids);
+ return try self.constructArray(result_ty, elem_ids);
},
else => unreachable,
}
@@ -3244,15 +3260,14 @@ const DeclGen = struct {
const slice_ptr = try self.extractField(ptr_ty, slice_id, 0);
const elem_ptr = try self.ptrAccessChain(ptr_ty_ref, slice_ptr, index_id, &.{});
- return try self.load(slice_ty.childType(mod), elem_ptr, slice_ty.isVolatilePtr(mod));
+ return try self.load(slice_ty.childType(mod), elem_ptr, .{ .is_volatile = slice_ty.isVolatilePtr(mod) });
}
fn ptrElemPtr(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, index_id: IdRef) !IdRef {
const mod = self.module;
// Construct new pointer type for the resulting pointer
const elem_ty = ptr_ty.elemType2(mod); // use elemType() so that we get T for *[N]T.
- const elem_ty_ref = try self.resolveType(elem_ty, .direct);
- const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, spvStorageClass(ptr_ty.ptrAddressSpace(mod)));
+ const elem_ptr_ty_ref = try self.ptrType(elem_ty, spvStorageClass(ptr_ty.ptrAddressSpace(mod)));
if (ptr_ty.isSinglePointer(mod)) {
// Pointer-to-array. In this case, the resulting pointer is not of the same type
// as the ptr_ty (we want a *T, not a *[N]T), and hence we need to use accessChain.
@@ -3288,9 +3303,7 @@ const DeclGen = struct {
const mod = self.module;
const bin_op = self.air.instructions.items(.data)[inst].bin_op;
const array_ty = self.typeOf(bin_op.lhs);
- const array_ty_ref = try self.resolveType(array_ty, .direct);
const elem_ty = array_ty.childType(mod);
- const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
const array_id = try self.resolve(bin_op.lhs);
const index_id = try self.resolve(bin_op.rhs);
@@ -3298,22 +3311,12 @@ const DeclGen = struct {
// For now, just generate a temporary and use that.
// TODO: This backend probably also should use isByRef from llvm...
- const array_ptr_ty_ref = try self.spv.ptrType(array_ty_ref, .Function);
- const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, .Function);
-
- const tmp_id = self.spv.allocId();
- try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
- .id_result_type = self.typeId(array_ptr_ty_ref),
- .id_result = tmp_id,
- .storage_class = .Function,
- });
- try self.func.body.emit(self.spv.gpa, .OpStore, .{
- .pointer = tmp_id,
- .object = array_id,
- });
+ const elem_ptr_ty_ref = try self.ptrType(elem_ty, .Function);
+ const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function });
+ try self.store(array_ty, tmp_id, array_id, .{});
const elem_ptr_id = try self.accessChainId(elem_ptr_ty_ref, tmp_id, &.{index_id});
- return try self.load(elem_ty, elem_ptr_id, false);
+ return try self.load(elem_ty, elem_ptr_id, .{});
}
fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3326,7 +3329,7 @@ const DeclGen = struct {
const ptr_id = try self.resolve(bin_op.lhs);
const index_id = try self.resolve(bin_op.rhs);
const elem_ptr_id = try self.ptrElemPtr(ptr_ty, ptr_id, index_id);
- return try self.load(elem_ty, elem_ptr_id, ptr_ty.isVolatilePtr(mod));
+ return try self.load(elem_ty, elem_ptr_id, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
}
fn airSetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -3334,22 +3337,21 @@ const DeclGen = struct {
const bin_op = self.air.instructions.items(.data)[inst].bin_op;
const un_ptr_ty = self.typeOf(bin_op.lhs);
const un_ty = un_ptr_ty.childType(mod);
- const layout = self.unionLayout(un_ty, null);
+ const layout = self.unionLayout(un_ty);
if (layout.tag_size == 0) return;
const tag_ty = un_ty.unionTagTypeSafety(mod).?;
- const tag_ty_ref = try self.resolveType(tag_ty, .indirect);
- const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod)));
+ const tag_ptr_ty_ref = try self.ptrType(tag_ty, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod)));
const union_ptr_id = try self.resolve(bin_op.lhs);
const new_tag_id = try self.resolve(bin_op.rhs);
- if (layout.payload_size == 0) {
- try self.store(tag_ty, union_ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
+ if (!layout.has_payload) {
+ try self.store(tag_ty, union_ptr_id, new_tag_id, .{ .is_volatile = un_ptr_ty.isVolatilePtr(mod) });
} else {
const ptr_id = try self.accessChain(tag_ptr_ty_ref, union_ptr_id, &.{layout.tag_index});
- try self.store(tag_ty, ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
+ try self.store(tag_ty, ptr_id, new_tag_id, .{ .is_volatile = un_ptr_ty.isVolatilePtr(mod) });
}
}
@@ -3360,11 +3362,11 @@ const DeclGen = struct {
const un_ty = self.typeOf(ty_op.operand);
const mod = self.module;
- const layout = self.unionLayout(un_ty, null);
+ const layout = self.unionLayout(un_ty);
if (layout.tag_size == 0) return null;
const union_handle = try self.resolve(ty_op.operand);
- if (layout.payload_size == 0) return union_handle;
+ if (!layout.has_payload) return union_handle;
const tag_ty = un_ty.unionTagTypeSafety(mod).?;
return try self.extractField(tag_ty, union_handle, layout.tag_index);
@@ -3377,8 +3379,8 @@ const DeclGen = struct {
payload: ?IdRef,
) !IdRef {
// To initialize a union, generate a temporary variable with the
- // type that has the right field active, then pointer-cast and store
- // the active field, and finally load and return the entire union.
+ // union type, then get the field pointer and pointer-cast it to the
+ // right type to store it. Finally load the entire union.
const mod = self.module;
const ip = &mod.intern_pool;
@@ -3389,7 +3391,7 @@ const DeclGen = struct {
}
const maybe_tag_ty = ty.unionTagTypeSafety(mod);
- const layout = self.unionLayout(ty, active_field);
+ const layout = self.unionLayout(ty);
const tag_int = if (layout.tag_size != 0) blk: {
const tag_ty = maybe_tag_ty.?;
@@ -3400,42 +3402,34 @@ const DeclGen = struct {
break :blk tag_int_val.toUnsignedInt(mod);
} else 0;
- if (layout.payload_size == 0) {
+ if (!layout.has_payload) {
const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct);
return try self.constInt(tag_ty_ref, tag_int);
}
- const un_active_ty_ref = try self.resolveUnionType(ty, active_field);
- const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function);
- const un_general_ty_ref = try self.resolveType(ty, .direct);
- const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function);
-
- const tmp_id = self.spv.allocId();
- try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
- .id_result_type = self.typeId(un_active_ptr_ty_ref),
- .id_result = tmp_id,
- .storage_class = .Function,
- });
+ const tmp_id = try self.alloc(ty, .{ .storage_class = .Function });
if (layout.tag_size != 0) {
const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct);
- const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function);
+ const tag_ptr_ty_ref = try self.ptrType(maybe_tag_ty.?, .Function);
const ptr_id = try self.accessChain(tag_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.tag_index))});
const tag_id = try self.constInt(tag_ty_ref, tag_int);
- try self.func.body.emit(self.spv.gpa, .OpStore, .{
- .pointer = ptr_id,
- .object = tag_id,
- });
+ try self.store(maybe_tag_ty.?, ptr_id, tag_id, .{});
}
- if (layout.active_field_size != 0) {
- const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect);
- const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function);
- const ptr_id = try self.accessChain(active_field_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.active_field_index))});
- try self.func.body.emit(self.spv.gpa, .OpStore, .{
- .pointer = ptr_id,
- .object = payload.?,
+ const payload_ty = union_ty.field_types.get(ip)[active_field].toType();
+ if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+ const pl_ptr_ty_ref = try self.ptrType(layout.payload_ty, .Function);
+ const pl_ptr_id = try self.accessChain(pl_ptr_ty_ref, tmp_id, &.{layout.payload_index});
+ const active_pl_ptr_ty_ref = try self.ptrType(payload_ty, .Function);
+ const active_pl_ptr_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+ .id_result_type = self.typeId(active_pl_ptr_ty_ref),
+ .id_result = active_pl_ptr_id,
+ .operand = pl_ptr_id,
});
+
+ try self.store(payload_ty, active_pl_ptr_id, payload.?, .{});
} else {
assert(payload == null);
}
@@ -3443,34 +3437,21 @@ const DeclGen = struct {
// Just leave the padding fields uninitialized...
// TODO: Or should we initialize them with undef explicitly?
- // Now cast the pointer and load it as the 'generic' union type.
-
- const casted_var_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
- .id_result_type = self.typeId(un_general_ptr_ty_ref),
- .id_result = casted_var_id,
- .operand = tmp_id,
- });
-
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpLoad, .{
- .id_result_type = self.typeId(un_general_ty_ref),
- .id_result = result_id,
- .pointer = casted_var_id,
- });
-
- return result_id;
+ return try self.load(ty, tmp_id, .{});
}
fn airUnionInit(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
if (self.liveness.isUnused(inst)) return null;
+ const mod = self.module;
+ const ip = &mod.intern_pool;
const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
const extra = self.air.extraData(Air.UnionInit, ty_pl.payload).data;
const ty = self.typeOfIndex(inst);
- const layout = self.unionLayout(ty, extra.field_index);
- const payload = if (layout.active_field_size != 0)
+ const union_obj = mod.typeToUnion(ty).?;
+ const field_ty = union_obj.field_types.get(ip)[extra.field_index].toType();
+ const payload = if (field_ty.hasRuntimeBitsIgnoreComptime(mod))
try self.resolve(extra.init)
else
null;
@@ -3499,30 +3480,24 @@ const DeclGen = struct {
.Union => switch (object_ty.containerLayout(mod)) {
.Packed => unreachable, // TODO
else => {
- // Store, pointer-cast, load
- const un_general_ty_ref = try self.resolveType(object_ty, .indirect);
- const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function);
- const un_active_ty_ref = try self.resolveUnionType(object_ty, field_index);
- const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function);
- const field_ty_ref = try self.resolveType(field_ty, .indirect);
- const field_ptr_ty_ref = try self.spv.ptrType(field_ty_ref, .Function);
-
- const tmp_id = self.spv.allocId();
- try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
- .id_result_type = self.typeId(un_general_ptr_ty_ref),
- .id_result = tmp_id,
- .storage_class = .Function,
- });
- try self.store(object_ty, tmp_id, object_id, false);
- const casted_tmp_id = self.spv.allocId();
+ // Store, ptr-elem-ptr, pointer-cast, load
+ const layout = self.unionLayout(object_ty);
+ assert(layout.has_payload);
+
+ const tmp_id = try self.alloc(object_ty, .{ .storage_class = .Function });
+ try self.store(object_ty, tmp_id, object_id, .{});
+
+ const pl_ptr_ty_ref = try self.ptrType(layout.payload_ty, .Function);
+ const pl_ptr_id = try self.accessChain(pl_ptr_ty_ref, tmp_id, &.{layout.payload_index});
+
+ const active_pl_ptr_ty_ref = try self.ptrType(field_ty, .Function);
+ const active_pl_ptr_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
- .id_result_type = self.typeId(un_active_ptr_ty_ref),
- .id_result = casted_tmp_id,
- .operand = tmp_id,
+ .id_result_type = self.typeId(active_pl_ptr_ty_ref),
+ .id_result = active_pl_ptr_id,
+ .operand = pl_ptr_id,
});
- const layout = self.unionLayout(object_ty, field_index);
- const field_ptr_id = try self.accessChain(field_ptr_ty_ref, casted_tmp_id, &.{layout.active_field_index});
- return try self.load(field_ty, field_ptr_id, false);
+ return try self.load(field_ty, active_pl_ptr_id, .{});
},
},
else => unreachable,
@@ -3581,18 +3556,24 @@ const DeclGen = struct {
.Union => switch (object_ty.containerLayout(mod)) {
.Packed => unreachable, // TODO
else => {
+ const layout = self.unionLayout(object_ty);
+ if (!layout.has_payload) {
+ // Asked to get a pointer to a zero-sized field. Just lower this
+ // to undefined, there is no reason to make it be a valid pointer.
+ return try self.spv.constUndef(result_ty_ref);
+ }
+
const storage_class = spvStorageClass(object_ptr_ty.ptrAddressSpace(mod));
- const un_active_ty_ref = try self.resolveUnionType(object_ty, field_index);
- const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, storage_class);
+ const pl_ptr_ty_ref = try self.ptrType(layout.payload_ty, storage_class);
+ const pl_ptr_id = try self.accessChain(pl_ptr_ty_ref, object_ptr, &.{layout.payload_index});
- const casted_id = self.spv.allocId();
+ const active_pl_ptr_id = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
- .id_result_type = self.typeId(un_active_ptr_ty_ref),
- .id_result = casted_id,
- .operand = object_ptr,
+ .id_result_type = self.typeId(result_ty_ref),
+ .id_result = active_pl_ptr_id,
+ .operand = pl_ptr_id,
});
- const layout = self.unionLayout(object_ty, field_index);
- return try self.accessChain(result_ty_ref, casted_id, &.{layout.active_field_index});
+ return active_pl_ptr_id;
},
},
else => unreachable,
@@ -3608,23 +3589,13 @@ const DeclGen = struct {
return try self.structFieldPtr(result_ptr_ty, struct_ptr_ty, struct_ptr, field_index);
}
- /// We cannot use an OpVariable directly in an OpSpecConstantOp, but we can
- /// after we insert a dummy AccessChain...
- /// TODO: Get rid of this
- fn makePointerConstant(
- self: *DeclGen,
- section: *SpvSection,
- ptr_ty_ref: CacheRef,
- ptr_id: IdRef,
- ) !IdRef {
- const result_id = self.spv.allocId();
- try section.emitSpecConstantOp(self.spv.gpa, .OpInBoundsAccessChain, .{
- .id_result_type = self.typeId(ptr_ty_ref),
- .id_result = result_id,
- .base = ptr_id,
- });
- return result_id;
- }
+ const AllocOptions = struct {
+ initializer: ?IdRef = null,
+ /// The final storage class of the pointer. This may be either `.Generic` or `.Function`.
+ /// In either case, the local is allocated in the `.Function` storage class, and optionally
+ /// cast back to `.Generic`.
+ storage_class: StorageClass = .Generic,
+ };
// Allocate a function-local variable, with possible initializer.
// This function returns a pointer to a variable of type `ty_ref`,
@@ -3632,30 +3603,36 @@ const DeclGen = struct {
// placed in the Function address space.
fn alloc(
self: *DeclGen,
- ty_ref: CacheRef,
- initializer: ?IdRef,
+ ty: Type,
+ options: AllocOptions,
) !IdRef {
- const fn_ptr_ty_ref = try self.spv.ptrType(ty_ref, .Function);
- const general_ptr_ty_ref = try self.spv.ptrType(ty_ref, .Generic);
+ const ptr_fn_ty_ref = try self.ptrType(ty, .Function);
// SPIR-V requires that OpVariable declarations for locals go into the first block, so we are just going to
// directly generate them into func.prologue instead of the body.
const var_id = self.spv.allocId();
try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
- .id_result_type = self.typeId(fn_ptr_ty_ref),
+ .id_result_type = self.typeId(ptr_fn_ty_ref),
.id_result = var_id,
.storage_class = .Function,
- .initializer = initializer,
+ .initializer = options.initializer,
});
- // Convert to a generic pointer
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpPtrCastToGeneric, .{
- .id_result_type = self.typeId(general_ptr_ty_ref),
- .id_result = result_id,
- .pointer = var_id,
- });
- return result_id;
+ switch (options.storage_class) {
+ .Generic => {
+ const ptr_gn_ty_ref = try self.ptrType(ty, .Generic);
+ // Convert to a generic pointer
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpPtrCastToGeneric, .{
+ .id_result_type = self.typeId(ptr_gn_ty_ref),
+ .id_result = result_id,
+ .pointer = var_id,
+ });
+ return result_id;
+ },
+ .Function => return var_id,
+ else => unreachable,
+ }
}
fn airAlloc(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3664,8 +3641,7 @@ const DeclGen = struct {
const ptr_ty = self.typeOfIndex(inst);
assert(ptr_ty.ptrAddressSpace(mod) == .generic);
const child_ty = ptr_ty.childType(mod);
- const child_ty_ref = try self.resolveType(child_ty, .indirect);
- return try self.alloc(child_ty_ref, null);
+ return try self.alloc(child_ty, .{});
}
fn airArg(self: *DeclGen) IdRef {
@@ -3780,7 +3756,7 @@ const DeclGen = struct {
const operand = try self.resolve(ty_op.operand);
if (!ptr_ty.isVolatilePtr(mod) and self.liveness.isUnused(inst)) return null;
- return try self.load(elem_ty, operand, ptr_ty.isVolatilePtr(mod));
+ return try self.load(elem_ty, operand, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
}
fn airStore(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -3790,7 +3766,7 @@ const DeclGen = struct {
const ptr = try self.resolve(bin_op.lhs);
const value = try self.resolve(bin_op.rhs);
- try self.store(elem_ty, ptr, value, ptr_ty.isVolatilePtr(self.module));
+ try self.store(elem_ty, ptr, value, .{ .is_volatile = ptr_ty.isVolatilePtr(self.module) });
}
fn airLoop(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -3854,7 +3830,7 @@ const DeclGen = struct {
}
const ptr = try self.resolve(un_op);
- const value = try self.load(ret_ty, ptr, ptr_ty.isVolatilePtr(mod));
+ const value = try self.load(ret_ty, ptr, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
try self.func.body.emit(self.spv.gpa, .OpReturnValue, .{
.value = value,
});
@@ -3980,8 +3956,11 @@ const DeclGen = struct {
members[eu_layout.errorFieldIndex()] = operand_id;
members[eu_layout.payloadFieldIndex()] = try self.spv.constUndef(payload_ty_ref);
- const err_union_ty_ref = try self.resolveType(err_union_ty, .direct);
- return try self.constructStruct(err_union_ty_ref, &members);
+ var types: [2]Type = undefined;
+ types[eu_layout.errorFieldIndex()] = Type.anyerror;
+ types[eu_layout.payloadFieldIndex()] = payload_ty;
+
+ return try self.constructStruct(err_union_ty, &types, &members);
}
fn airWrapErrUnionPayload(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -4002,8 +3981,11 @@ const DeclGen = struct {
members[eu_layout.errorFieldIndex()] = try self.constInt(err_ty_ref, 0);
members[eu_layout.payloadFieldIndex()] = try self.convertToIndirect(payload_ty, operand_id);
- const err_union_ty_ref = try self.resolveType(err_union_ty, .direct);
- return try self.constructStruct(err_union_ty_ref, &members);
+ var types: [2]Type = undefined;
+ types[eu_layout.errorFieldIndex()] = Type.anyerror;
+ types[eu_layout.payloadFieldIndex()] = payload_ty;
+
+ return try self.constructStruct(err_union_ty, &types, &members);
}
fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, pred: enum { is_null, is_non_null }) !?IdRef {
@@ -4037,7 +4019,7 @@ const DeclGen = struct {
.is_null => .eq,
.is_non_null => .neq,
};
- return try self.cmp(op, ptr_ty, ptr_id, null_id);
+ return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id);
}
const is_non_null_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
@@ -4135,10 +4117,10 @@ const DeclGen = struct {
return operand_id;
}
- const optional_ty_ref = try self.resolveType(optional_ty, .direct);
const payload_id = try self.convertToIndirect(payload_ty, operand_id);
const members = [_]IdRef{ payload_id, try self.constBool(true, .indirect) };
- return try self.constructStruct(optional_ty_ref, &members);
+ const types = [_]Type{ payload_ty, Type.bool };
+ return try self.constructStruct(optional_ty, &types, &members);
}
fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -4420,6 +4402,7 @@ const DeclGen = struct {
}
// TODO: Multiple results
+ // TODO: Check that the output type from assembly is the same as the type actually expected by Zig.
}
return null;