From cb9e20da00a2c33706e2c7bf2008887c6c72a896 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Mon, 15 Jan 2024 23:06:54 +0100 Subject: spirv: element-wise operation helper --- src/codegen/spirv.zig | 123 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 96 insertions(+), 27 deletions(-) (limited to 'src/codegen/spirv.zig') diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 580c3d959a..a8bc385f7a 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -1760,6 +1760,92 @@ const DeclGen = struct { return union_layout; } + /// This structure is used as helper for element-wise operations. It is intended + /// to be used with both vectors and single elements. + const WipElementWise = struct { + dg: *DeclGen, + result_ty: Type, + /// Always in direct representation. + result_ty_ref: CacheRef, + scalar_ty: Type, + /// Always in direct representation. + scalar_ty_ref: CacheRef, + scalar_ty_id: IdRef, + /// True if the input is actually a vector type. + is_vector: bool, + /// The element-wise operation should fill these results before calling finalize(). + /// These should all be in **direct** representation! `finalize()` will convert + /// them to indirect if required. + results: []IdRef, + + fn deinit(wip: *WipElementWise) void { + wip.dg.gpa.free(wip.results); + } + + /// Utility function to extract the element at a particular index in an + /// input vector. This type is expected to be a vector if `wip.is_vector`, and + /// a scalar otherwise. + fn elementAt(wip: WipElementWise, ty: Type, value: IdRef, index: usize) !IdRef { + const mod = wip.dg.module; + if (wip.is_vector) { + assert(ty.isVector(mod)); + return try wip.dg.extractField(ty, value, @intCast(index)); + } else { + assert(!ty.isVector(mod)); + assert(index == 0); + return value; + } + } + + /// Turns the results of this WipElementWise into a result. This can either + /// be a vector or single element, depending on `result_ty`. + /// After calling this function, this WIP is no longer usable. + /// Results is in `direct` representation. + fn finalize(wip: *WipElementWise) !IdRef { + if (wip.is_vector) { + // Convert all the constituents to indirect, as required for the array. + for (wip.results) |*result| { + result.* = try wip.dg.convertToIndirect(wip.scalar_ty, result.*); + } + return try wip.dg.constructArray(wip.result_ty, wip.results); + } else { + return wip.results[0]; + } + } + + /// Allocate a result id at a particular index, and return it. + fn allocId(wip: *WipElementWise, index: usize) IdRef { + assert(wip.is_vector or index == 0); + wip.results[index] = wip.dg.spv.allocId(); + return wip.results[index]; + } + }; + + /// Create a new element-wise operation. + fn elementWise(self: *DeclGen, result_ty: Type) !WipElementWise { + const mod = self.module; + // For now, this operation also reasons in terms of `.direct` representation. + const result_ty_ref = try self.resolveType(result_ty, .direct); + const is_vector = result_ty.isVector(mod); + const num_results = if (is_vector) result_ty.vectorLen(mod) else 1; + const results = try self.gpa.alloc(IdRef, num_results); + for (results) |*result| result.* = undefined; + + const scalar_ty = if (is_vector) result_ty.childType(mod) else result_ty; + const scalar_ty_ref = try self.resolveType(scalar_ty, .direct); + + return .{ + .dg = self, + .result_ty = result_ty, + .result_ty_ref = result_ty_ref, + .scalar_ty = scalar_ty, + .scalar_ty_ref = scalar_ty_ref, + .scalar_ty_id = self.typeId(scalar_ty_ref), + .is_vector = is_vector, + .results = results, + }; + } + /// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure. /// In order to be able to run tests, we "temporarily" lower test kernels into separate entry- /// points. The test executor will then be able to invoke these to run the tests. @@ -2214,34 +2300,17 @@ const DeclGen = struct { } fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef { - const mod = self.module; - - if (ty.isVector(mod)) { - const child_ty = ty.childType(mod); - const vector_len = ty.vectorLen(mod); - - const constituents = try self.gpa.alloc(IdRef, vector_len); - defer self.gpa.free(constituents); - - for (constituents, 0..) |*constituent, i| { - const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i)); - const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i)); - const result_id = try self.binOpSimple(child_ty, lhs_index_id, rhs_index_id, opcode); - constituent.* = try self.convertToIndirect(child_ty, result_id); - } - - return try self.constructArray(ty, constituents); + var wip = try self.elementWise(ty); + defer wip.deinit(); + for (0..wip.results.len) |i| { + try self.func.body.emit(self.spv.gpa, opcode, .{ + .id_result_type = wip.scalar_ty_id, + .id_result = wip.allocId(i), + .operand_1 = try wip.elementAt(ty, lhs_id, i), + .operand_2 = try wip.elementAt(ty, rhs_id, i), + }); } - - const result_id = self.spv.allocId(); - const result_type_id = try self.resolveTypeId(ty); - try self.func.body.emit(self.spv.gpa, opcode, .{ - .id_result_type = result_type_id, - .id_result = result_id, - .operand_1 = lhs_id, - .operand_2 = rhs_id, - }); - return result_id; + return try wip.finalize(); } fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef { -- cgit v1.2.3