diff options
| author | Robin Voetter <robin@voetter.nl> | 2024-06-04 22:09:15 +0200 |
|---|---|---|
| committer | Robin Voetter <robin@voetter.nl> | 2024-06-10 20:32:50 +0200 |
| commit | a567f3871ec06f3e6a8c0e6424aba556f1069ccc (patch) | |
| tree | e319b045727eeaa9b701c06bc7fcde198525365f /src/codegen | |
| parent | a3b1ba82f57d5d8981a471850cbbb0db29c3a479 (diff) | |
| download | zig-a567f3871ec06f3e6a8c0e6424aba556f1069ccc.tar.gz zig-a567f3871ec06f3e6a8c0e6424aba556f1069ccc.zip | |
spirv: improve shuffle codegen
Diffstat (limited to 'src/codegen')
| -rw-r--r-- | src/codegen/spirv.zig | 63 |
1 file changed, 55 insertions, 8 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 215a9421f1..09185211ef 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -4082,25 +4082,72 @@ const DeclGen = struct { const b = try self.resolve(extra.b); const mask = Value.fromInterned(extra.mask); - const ty = self.typeOfIndex(inst); + // Note: number of components in the result, a, and b may differ. + const result_ty = self.typeOfIndex(inst); + const a_ty = self.typeOf(extra.a); + const b_ty = self.typeOf(extra.b); + + const scalar_ty = result_ty.scalarType(mod); + const scalar_ty_id = try self.resolveType(scalar_ty, .direct); + + // If all of the types are SPIR-V vectors, we can use OpVectorShuffle. + if (self.isSpvVector(result_ty) and self.isSpvVector(a_ty) and self.isSpvVector(b_ty)) { + // The SPIR-V shuffle instruction is similar to the Air instruction, except that the elements are + // numbered consecutively instead of using negatives. + + const components = try self.gpa.alloc(Word, result_ty.vectorLen(mod)); + defer self.gpa.free(components); + + const a_len = a_ty.vectorLen(mod); + + for (components, 0..) |*component, i| { + const elem = try mask.elemValue(mod, i); + if (elem.isUndef(mod)) { + // This is explicitly valid for OpVectorShuffle, it indicates undefined. + component.* = 0xFFFF_FFFF; + continue; + } + + const index = elem.toSignedInt(mod); + if (index >= 0) { + component.* = @intCast(index); + } else { + component.* = @intCast(~index + a_len); + } + } - var wip = try self.elementWise(ty, true); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpVectorShuffle, .{ + .id_result_type = try self.resolveType(result_ty, .direct), + .id_result = result_id, + .vector_1 = a, + .vector_2 = b, + .components = components, + }); + return result_id; + } + + // Fall back to manually extracting and inserting components. 
+ + const components = try self.gpa.alloc(IdRef, result_ty.vectorLen(mod)); + defer self.gpa.free(components); + + for (components, 0..) |*id, i| { const elem = try mask.elemValue(mod, i); if (elem.isUndef(mod)) { - result_id.* = try self.spv.constUndef(wip.ty_id); + id.* = try self.spv.constUndef(scalar_ty_id); continue; } const index = elem.toSignedInt(mod); if (index >= 0) { - result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index)); + id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index)); } else { - result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index)); + id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index)); } } - return try wip.finalize(); + + return try self.constructVector(result_ty, components); } fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef { |
