diff options
Diffstat (limited to 'src/codegen/spirv.zig')
| -rw-r--r-- | src/codegen/spirv.zig | 197 |
1 files changed, 76 insertions, 121 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index ded73d6afd..451d348c48 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -358,8 +358,7 @@ const DeclGen = struct { const mod = self.module; const ty = mod.intern_pool.typeOf(val).toType(); - const ty_ref = try self.resolveType(ty, .indirect); - const ptr_ty_ref = try self.spv.ptrType(ty_ref, storage_class); + const ptr_ty_ref = try self.ptrType(ty, storage_class); const var_id = self.spv.declPtr(spv_decl_index).result_id; @@ -623,25 +622,15 @@ const DeclGen = struct { /// result_ty_ref must be an array type. /// Constituents should be in `indirect` representation (as the elements of an array should be). /// Result is in `direct` representation. - fn constructArray(self: *DeclGen, result_ty_ref: CacheRef, constituents: []const IdRef) !IdRef { + fn constructArray(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef { // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which' // operands are not constant. // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349 // For now, just initialize the struct by setting the fields manually... // TODO: Make this OpCompositeConstruct when we can - // TODO: Make this Function storage type - const ptr_ty_ref = try self.spv.ptrType(result_ty_ref, .Function); - const ptr_composite_id = self.spv.allocId(); - try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ - .id_result_type = self.typeId(ptr_ty_ref), - .id_result = ptr_composite_id, - .storage_class = .Function, - }); - - const spv_composite_ty = self.spv.cache.lookup(result_ty_ref).array_type; - const elem_ty_ref = spv_composite_ty.element_type; - const ptr_elem_ty_ref = try self.spv.ptrType(elem_ty_ref, .Function); - + const mod = self.module; + const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function }); + const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function); for (constituents, 0..) |constitent_id, index| { const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))}); try self.func.body.emit(self.spv.gpa, .OpStore, .{ @@ -649,13 +638,8 @@ const DeclGen = struct { .object = constitent_id, }); } - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLoad, .{ - .id_result_type = self.typeId(result_ty_ref), - .id_result = result_id, - .pointer = ptr_composite_id, - }); - return result_id; + + return try self.load(ty, ptr_composite_id, .{}); } /// This function generates a load for a constant in direct (ie, non-memory) representation. @@ -857,7 +841,7 @@ const DeclGen = struct { else => {}, } - return try self.constructArray(result_ty_ref, constituents); + return try self.constructArray(ty, constituents); }, .struct_type => { const struct_type = mod.typeToStruct(ty).?; @@ -892,7 +876,7 @@ const DeclGen = struct { const active_field = ty.unionTagFieldIndex(un.tag.toValue(), mod).?; const layout = self.unionLayout(ty, active_field); const payload = if (layout.active_field_size != 0) - try self.constant(layout.active_field_ty, un.val.toValue(), .indirect) + try self.constant(layout.active_field_ty, un.val.toValue(), .direct) else null; @@ -934,8 +918,7 @@ const DeclGen = struct { // TODO: Can we consolidate this in ptrElemPtr? const elem_ty = parent_ptr_ty.elemType2(mod); // use elemType() so that we get T for *[N]T. - const elem_ty_ref = try self.resolveType(elem_ty, .direct); - const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, spvStorageClass(parent_ptr_ty.ptrAddressSpace(mod))); + const elem_ptr_ty_ref = try self.ptrType(elem_ty, spvStorageClass(parent_ptr_ty.ptrAddressSpace(mod))); if (elem_ptr_ty_ref == result_ty_ref) { return elem_ptr_id; @@ -992,8 +975,7 @@ const DeclGen = struct { }; const decl_id = try self.resolveAnonDecl(decl_val, actual_storage_class); - const decl_ty_ref = try self.resolveType(decl_ty, .indirect); - const decl_ptr_ty_ref = try self.spv.ptrType(decl_ty_ref, final_storage_class); + const decl_ptr_ty_ref = try self.ptrType(decl_ty, final_storage_class); const ptr_id = switch (final_storage_class) { .Generic => blk: { @@ -1049,8 +1031,7 @@ const DeclGen = struct { const final_storage_class = spvStorageClass(decl.@"addrspace"); - const decl_ty_ref = try self.resolveType(decl.ty, .indirect); - const decl_ptr_ty_ref = try self.spv.ptrType(decl_ty_ref, final_storage_class); + const decl_ptr_ty_ref = try self.ptrType(decl.ty, final_storage_class); const ptr_id = switch (final_storage_class) { .Generic => blk: { @@ -1118,6 +1099,12 @@ const DeclGen = struct { return try self.intType(.unsigned, self.getTarget().ptrBitWidth()); } + fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !CacheRef { + // TODO: This function will be rewritten so that forward declarations work properly + const child_ty_ref = try self.resolveType(child_ty, .indirect); + return try self.spv.ptrType(child_ty_ref, storage_class); + } + /// Generate a union type, optionally with a known field. If the tag alignment is greater /// than that of the payload, a regular union (non-packed, with both tag and payload), will /// be generated as follows: @@ -1678,7 +1665,7 @@ const DeclGen = struct { /// the name of an error in the text executor. fn generateTestEntryPoint(self: *DeclGen, name: []const u8, spv_test_decl_index: SpvModule.Decl.Index) !void { const anyerror_ty_ref = try self.resolveType(Type.anyerror, .direct); - const ptr_anyerror_ty_ref = try self.spv.ptrType(anyerror_ty_ref, .CrossWorkgroup); + const ptr_anyerror_ty_ref = try self.ptrType(Type.anyerror, .CrossWorkgroup); const void_ty_ref = try self.resolveType(Type.void, .direct); const kernel_proto_ty_ref = try self.spv.resolve(.{ .function_type = .{ @@ -1713,6 +1700,7 @@ const DeclGen = struct { .id_result = error_id, .function = test_id, }); + // Note: Convert to direct not required. try section.emit(self.spv.gpa, .OpStore, .{ .pointer = p_error_id, .object = error_id, @@ -1817,8 +1805,7 @@ const DeclGen = struct { else => final_storage_class, }; - const ty_ref = try self.resolveType(decl.ty, .indirect); - const ptr_ty_ref = try self.spv.ptrType(ty_ref, actual_storage_class); + const ptr_ty_ref = try self.ptrType(decl.ty, actual_storage_class); const begin = self.spv.beginGlobal(); try self.spv.globals.section.emit(self.spv.gpa, .OpVariable, .{ @@ -2113,9 +2100,7 @@ const DeclGen = struct { constituent.* = try self.convertToIndirect(child_ty, result_id); } - const result_ty = try self.resolveType(child_ty, .indirect); - const result_ty_ref = try self.spv.arrayType(vector_len, result_ty); - return try self.constructArray(result_ty_ref, constituents); + return try self.constructArray(ty, constituents); } const result_id = self.spv.allocId(); @@ -2176,7 +2161,7 @@ const DeclGen = struct { const info = try self.arithmeticTypeInfo(result_ty); // TODO: Use fmin for OpenCL - const cmp_id = try self.cmp(op, result_ty, lhs_id, rhs_id); + const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id); const selection_id = switch (info.class) { .float => blk: { // cmp uses OpFOrd. When we have 0 [<>] nan this returns false, @@ -2311,7 +2296,7 @@ const DeclGen = struct { constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular); } - return self.constructArray(result_ty_ref, constituents); + return self.constructArray(ty, constituents); } // Binary operations are generally applicable to both scalar and vector operations @@ -2629,6 +2614,7 @@ const DeclGen = struct { fn cmp( self: *DeclGen, op: std.math.CompareOperator, + result_ty: Type, ty: Type, lhs_id: IdRef, rhs_id: IdRef, @@ -2669,7 +2655,7 @@ const DeclGen = struct { if (ty.optionalReprIsPayload(mod)) { assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod)); assert(!payload_ty.isSlice(mod)); - return self.cmp(op, payload_ty, lhs_id, rhs_id); + return self.cmp(op, Type.bool, payload_ty, lhs_id, rhs_id); } const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) @@ -2682,7 +2668,7 @@ const DeclGen = struct { else try self.convertToDirect(Type.bool, rhs_id); - const valid_cmp_id = try self.cmp(op, Type.bool, lhs_valid_id, rhs_valid_id); + const valid_cmp_id = try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id); if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) { return valid_cmp_id; } @@ -2693,7 +2679,7 @@ const DeclGen = struct { const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0); const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0); - const pl_cmp_id = try self.cmp(op, payload_ty, lhs_pl_id, rhs_pl_id); + const pl_cmp_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id); // op == .eq => lhs_valid == rhs_valid && lhs_pl == rhs_pl // op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl @@ -2715,7 +2701,6 @@ const DeclGen = struct { .Vector => { const child_ty = ty.childType(mod); const vector_len = ty.vectorLen(mod); - const bool_ty_ref_indirect = try self.resolveType(Type.bool, .indirect); var constituents = try self.gpa.alloc(IdRef, vector_len); defer self.gpa.free(constituents); @@ -2723,12 +2708,11 @@ const DeclGen = struct { for (constituents, 0..) |*constituent, i| { const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i)); const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i)); - const result_id = try self.cmp(op, child_ty, lhs_index_id, rhs_index_id); + const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id); constituent.* = try self.convertToIndirect(Type.bool, result_id); } - const result_ty_ref = try self.spv.arrayType(vector_len, bool_ty_ref_indirect); - return try self.constructArray(result_ty_ref, constituents); + return try self.constructArray(result_ty, constituents); }, else => unreachable, }; @@ -2801,8 +2785,9 @@ const DeclGen = struct { const lhs_id = try self.resolve(bin_op.lhs); const rhs_id = try self.resolve(bin_op.rhs); const ty = self.typeOf(bin_op.lhs); + const result_ty = self.typeOfIndex(inst); - return try self.cmp(op, ty, lhs_id, rhs_id); + return try self.cmp(op, result_ty, ty, lhs_id, rhs_id); } fn airVectorCmp(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -2814,8 +2799,9 @@ const DeclGen = struct { const rhs_id = try self.resolve(vec_cmp.rhs); const op = vec_cmp.compareOperator(); const ty = self.typeOf(vec_cmp.lhs); + const result_ty = self.typeOfIndex(inst); - return try self.cmp(op, ty, lhs_id, rhs_id); + return try self.cmp(op, result_ty, ty, lhs_id, rhs_id); } fn bitCast( @@ -2860,15 +2846,9 @@ const DeclGen = struct { return result_id; } - const src_ptr_ty_ref = try self.spv.ptrType(src_ty_ref, .Function); - const dst_ptr_ty_ref = try self.spv.ptrType(dst_ty_ref, .Function); + const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function); - const tmp_id = self.spv.allocId(); - try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ - .id_result_type = self.typeId(src_ptr_ty_ref), - .id_result = tmp_id, - .storage_class = .Function, - }); + const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function }); try self.store(src_ty, tmp_id, src_id, false); const casted_ptr_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ @@ -3154,7 +3134,7 @@ const DeclGen = struct { elem_ids[n_elems - 1] = try self.constant(array_info.elem_type, sentinel_val, .indirect); } - return try self.constructArray(result_ty_ref, elem_ids); + return try self.constructArray(result_ty, elem_ids); }, else => unreachable, } @@ -3246,8 +3226,7 @@ const DeclGen = struct { const mod = self.module; // Construct new pointer type for the resulting pointer const elem_ty = ptr_ty.elemType2(mod); // use elemType() so that we get T for *[N]T. - const elem_ty_ref = try self.resolveType(elem_ty, .direct); - const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, spvStorageClass(ptr_ty.ptrAddressSpace(mod))); + const elem_ptr_ty_ref = try self.ptrType(elem_ty, spvStorageClass(ptr_ty.ptrAddressSpace(mod))); if (ptr_ty.isSinglePointer(mod)) { // Pointer-to-array. In this case, the resulting pointer is not of the same type // as the ptr_ty (we want a *T, not a *[N]T), and hence we need to use accessChain. @@ -3283,9 +3262,7 @@ const DeclGen = struct { const mod = self.module; const bin_op = self.air.instructions.items(.data)[inst].bin_op; const array_ty = self.typeOf(bin_op.lhs); - const array_ty_ref = try self.resolveType(array_ty, .direct); const elem_ty = array_ty.childType(mod); - const elem_ty_ref = try self.resolveType(elem_ty, .indirect); const array_id = try self.resolve(bin_op.lhs); const index_id = try self.resolve(bin_op.rhs); @@ -3293,20 +3270,10 @@ const DeclGen = struct { // For now, just generate a temporary and use that. // TODO: This backend probably also should use isByRef from llvm... - const array_ptr_ty_ref = try self.spv.ptrType(array_ty_ref, .Function); - const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, .Function); - - const tmp_id = self.spv.allocId(); - try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ - .id_result_type = self.typeId(array_ptr_ty_ref), - .id_result = tmp_id, - .storage_class = .Function, - }); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = tmp_id, - .object = array_id, - }); + const elem_ptr_ty_ref = try self.ptrType(elem_ty, .Function); + const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function }); + try self.store(array_ty, tmp_id, array_id, false); const elem_ptr_id = try self.accessChainId(elem_ptr_ty_ref, tmp_id, &.{index_id}); return try self.load(elem_ty, elem_ptr_id, false); } @@ -3334,8 +3301,7 @@ const DeclGen = struct { if (layout.tag_size == 0) return; const tag_ty = un_ty.unionTagTypeSafety(mod).?; - const tag_ty_ref = try self.resolveType(tag_ty, .indirect); - const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod))); + const tag_ptr_ty_ref = try self.ptrType(tag_ty, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod))); const union_ptr_id = try self.resolve(bin_op.lhs); const new_tag_id = try self.resolve(bin_op.rhs); @@ -3400,6 +3366,7 @@ const DeclGen = struct { return try self.constInt(tag_ty_ref, tag_int); } + // TODO: Make this use self.ptrType const un_active_ty_ref = try self.resolveUnionType(ty, active_field); const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function); const un_general_ty_ref = try self.resolveType(ty, .direct); @@ -3414,23 +3381,16 @@ const DeclGen = struct { if (layout.tag_size != 0) { const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct); - const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function); + const tag_ptr_ty_ref = try self.ptrType(maybe_tag_ty.?, .Function); const ptr_id = try self.accessChain(tag_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.tag_index))}); const tag_id = try self.constInt(tag_ty_ref, tag_int); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = ptr_id, - .object = tag_id, - }); + try self.store(maybe_tag_ty.?, ptr_id, tag_id, false); } if (layout.active_field_size != 0) { - const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect); - const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function); + const active_field_ptr_ty_ref = try self.ptrType(layout.active_field_ty, .Function); const ptr_id = try self.accessChain(active_field_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.active_field_index))}); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = ptr_id, - .object = payload.?, - }); + try self.store(layout.active_field_ty, ptr_id, payload.?, false); } else { assert(payload == null); } @@ -3603,23 +3563,13 @@ const DeclGen = struct { return try self.structFieldPtr(result_ptr_ty, struct_ptr_ty, struct_ptr, field_index); } - /// We cannot use an OpVariable directly in an OpSpecConstantOp, but we can - /// after we insert a dummy AccessChain... - /// TODO: Get rid of this - fn makePointerConstant( - self: *DeclGen, - section: *SpvSection, - ptr_ty_ref: CacheRef, - ptr_id: IdRef, - ) !IdRef { - const result_id = self.spv.allocId(); - try section.emitSpecConstantOp(self.spv.gpa, .OpInBoundsAccessChain, .{ - .id_result_type = self.typeId(ptr_ty_ref), - .id_result = result_id, - .base = ptr_id, - }); - return result_id; - } + const AllocOptions = struct { + initializer: ?IdRef = null, + /// The final storage class of the pointer. This may be either `.Generic` or `.Function`. + /// In either case, the local is allocated in the `.Function` storage class, and optionally + /// cast back to `.Generic`. + storage_class: StorageClass = .Generic, + }; // Allocate a function-local variable, with possible initializer. // This function returns a pointer to a variable of type `ty_ref`, @@ -3627,30 +3577,36 @@ const DeclGen = struct { // placed in the Function address space. fn alloc( self: *DeclGen, - ty_ref: CacheRef, - initializer: ?IdRef, + ty: Type, + options: AllocOptions, ) !IdRef { - const fn_ptr_ty_ref = try self.spv.ptrType(ty_ref, .Function); - const general_ptr_ty_ref = try self.spv.ptrType(ty_ref, .Generic); + const ptr_fn_ty_ref = try self.ptrType(ty, .Function); // SPIR-V requires that OpVariable declarations for locals go into the first block, so we are just going to // directly generate them into func.prologue instead of the body. const var_id = self.spv.allocId(); try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ - .id_result_type = self.typeId(fn_ptr_ty_ref), + .id_result_type = self.typeId(ptr_fn_ty_ref), .id_result = var_id, .storage_class = .Function, - .initializer = initializer, + .initializer = options.initializer, }); - // Convert to a generic pointer - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpPtrCastToGeneric, .{ - .id_result_type = self.typeId(general_ptr_ty_ref), - .id_result = result_id, - .pointer = var_id, - }); - return result_id; + switch (options.storage_class) { + .Generic => { + const ptr_gn_ty_ref = try self.ptrType(ty, .Generic); + // Convert to a generic pointer + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpPtrCastToGeneric, .{ + .id_result_type = self.typeId(ptr_gn_ty_ref), + .id_result = result_id, + .pointer = var_id, + }); + return result_id; + }, + .Function => return var_id, + else => unreachable, + } } fn airAlloc(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -3659,8 +3615,7 @@ const DeclGen = struct { const ptr_ty = self.typeOfIndex(inst); assert(ptr_ty.ptrAddressSpace(mod) == .generic); const child_ty = ptr_ty.childType(mod); - const child_ty_ref = try self.resolveType(child_ty, .indirect); - return try self.alloc(child_ty_ref, null); + return try self.alloc(child_ty, .{}); } fn airArg(self: *DeclGen) IdRef { @@ -4032,7 +3987,7 @@ const DeclGen = struct { .is_null => .eq, .is_non_null => .neq, }; - return try self.cmp(op, ptr_ty, ptr_id, null_id); + return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id); } const is_non_null_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) |
