aboutsummaryrefslogtreecommitdiff
path: root/src/codegen
diff options
context:
space:
mode:
authorRobin Voetter <robin@voetter.nl>2024-06-04 22:09:15 +0200
committerRobin Voetter <robin@voetter.nl>2024-06-10 20:32:50 +0200
commita567f3871ec06f3e6a8c0e6424aba556f1069ccc (patch)
treee319b045727eeaa9b701c06bc7fcde198525365f /src/codegen
parenta3b1ba82f57d5d8981a471850cbbb0db29c3a479 (diff)
downloadzig-a567f3871ec06f3e6a8c0e6424aba556f1069ccc.tar.gz
zig-a567f3871ec06f3e6a8c0e6424aba556f1069ccc.zip
spirv: improve shuffle codegen
Diffstat (limited to 'src/codegen')
-rw-r--r--src/codegen/spirv.zig63
1 files changed, 55 insertions, 8 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index 215a9421f1..09185211ef 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -4082,25 +4082,72 @@ const DeclGen = struct {
const b = try self.resolve(extra.b);
const mask = Value.fromInterned(extra.mask);
- const ty = self.typeOfIndex(inst);
+ // Note: number of components in the result, a, and b may differ.
+ const result_ty = self.typeOfIndex(inst);
+ const a_ty = self.typeOf(extra.a);
+ const b_ty = self.typeOf(extra.b);
+
+ const scalar_ty = result_ty.scalarType(mod);
+ const scalar_ty_id = try self.resolveType(scalar_ty, .direct);
+
+ // If all of the types are SPIR-V vectors, we can use OpVectorShuffle.
+ if (self.isSpvVector(result_ty) and self.isSpvVector(a_ty) and self.isSpvVector(b_ty)) {
+ // The SPIR-V shuffle instruction is similar to the Air instruction, except that the elements are
+ // numbered consecutively instead of using negatives.
+
+ const components = try self.gpa.alloc(Word, result_ty.vectorLen(mod));
+ defer self.gpa.free(components);
+
+ const a_len = a_ty.vectorLen(mod);
+
+ for (components, 0..) |*component, i| {
+ const elem = try mask.elemValue(mod, i);
+ if (elem.isUndef(mod)) {
+ // This is explicitly valid for OpVectorShuffle, it indicates undefined.
+ component.* = 0xFFFF_FFFF;
+ continue;
+ }
+
+ const index = elem.toSignedInt(mod);
+ if (index >= 0) {
+ component.* = @intCast(index);
+ } else {
+ component.* = @intCast(~index + a_len);
+ }
+ }
- var wip = try self.elementWise(ty, true);
- defer wip.deinit();
- for (wip.results, 0..) |*result_id, i| {
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpVectorShuffle, .{
+ .id_result_type = try self.resolveType(result_ty, .direct),
+ .id_result = result_id,
+ .vector_1 = a,
+ .vector_2 = b,
+ .components = components,
+ });
+ return result_id;
+ }
+
+ // Fall back to manually extracting and inserting components.
+
+ const components = try self.gpa.alloc(IdRef, result_ty.vectorLen(mod));
+ defer self.gpa.free(components);
+
+ for (components, 0..) |*id, i| {
const elem = try mask.elemValue(mod, i);
if (elem.isUndef(mod)) {
- result_id.* = try self.spv.constUndef(wip.ty_id);
+ id.* = try self.spv.constUndef(scalar_ty_id);
continue;
}
const index = elem.toSignedInt(mod);
if (index >= 0) {
- result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index));
+ id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index));
} else {
- result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index));
+ id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index));
}
}
- return try wip.finalize();
+
+ return try self.constructVector(result_ty, components);
}
fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {