From add2976a9ba76ec661ae5668eb2a8dca2ccfad42 Mon Sep 17 00:00:00 2001 From: mlugg Date: Mon, 26 May 2025 05:07:13 +0100 Subject: compiler: implement better shuffle AIR Runtime `@shuffle` has two cases which backends generally want to handle differently for efficiency: * One runtime vector operand; some result elements may be comptime-known * Two runtime vector operands; some result elements may be undefined The latter case happens if both vectors given to `@shuffle` are runtime-known and they are both used (i.e. the mask refers to them). Otherwise, if the result is not entirely comptime-known, we are in the former case. `Sema` now diffentiates these two cases in the AIR so that backends can easily handle them however they want to. Note that this *doesn't* really involve Sema doing any more work than it would otherwise need to, so there's not really a negative here! Most existing backends have their lowerings for `@shuffle` migrated in this commit. The LLVM backend uses new lowerings suggested by Jacob as ones which it will handle effectively. The x86_64 backend has not yet been migrated; for now there's a panic in there. Jacob will implement that before this is merged anywhere. --- src/codegen/spirv.zig | 74 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 28 deletions(-) (limited to 'src/codegen/spirv.zig') diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 1381a79075..f83c6979ff 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -3252,7 +3252,8 @@ const NavGen = struct { .splat => try self.airSplat(inst), .reduce, .reduce_optimized => try self.airReduce(inst), - .shuffle => try self.airShuffle(inst), + .shuffle_one => try self.airShuffleOne(inst), + .shuffle_two => try self.airShuffleTwo(inst), .ptr_add => try self.airPtrAdd(inst), .ptr_sub => try self.airPtrSub(inst), @@ -4047,40 +4048,57 @@ const NavGen = struct { return result_id; } - fn airShuffle(self: *NavGen, inst: Air.Inst.Index) !?IdRef { - const pt = self.pt; + fn airShuffleOne(ng: *NavGen, inst: Air.Inst.Index) !?IdRef { + const pt = ng.pt; const zcu = pt.zcu; - const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; - const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data; - const a = try self.resolve(extra.a); - const b = try self.resolve(extra.b); - const mask = Value.fromInterned(extra.mask); + const gpa = zcu.gpa; - // Note: number of components in the result, a, and b may differ. - const result_ty = self.typeOfIndex(inst); - const scalar_ty = result_ty.scalarType(zcu); - const scalar_ty_id = try self.resolveType(scalar_ty, .direct); + const unwrapped = ng.air.unwrapShuffleOne(zcu, inst); + const mask = unwrapped.mask; + const result_ty = unwrapped.result_ty; + const elem_ty = result_ty.childType(zcu); + const operand = try ng.resolve(unwrapped.operand); - const constituents = try self.gpa.alloc(IdRef, result_ty.vectorLen(zcu)); - defer self.gpa.free(constituents); + const constituents = try gpa.alloc(IdRef, mask.len); + defer gpa.free(constituents); - for (constituents, 0..) |*id, i| { - const elem = try mask.elemValue(pt, i); - if (elem.isUndef(zcu)) { - id.* = try self.spv.constUndef(scalar_ty_id); - continue; - } + for (constituents, mask) |*id, mask_elem| { + id.* = switch (mask_elem.unwrap()) { + .elem => |idx| try ng.extractVectorComponent(elem_ty, operand, idx), + .value => |val| try ng.constant(elem_ty, .fromInterned(val), .direct), + }; + } - const index = elem.toSignedInt(zcu); - if (index >= 0) { - id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index)); - } else { - id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index)); - } + const result_ty_id = try ng.resolveType(result_ty, .direct); + return try ng.constructComposite(result_ty_id, constituents); + } + + fn airShuffleTwo(ng: *NavGen, inst: Air.Inst.Index) !?IdRef { + const pt = ng.pt; + const zcu = pt.zcu; + const gpa = zcu.gpa; + + const unwrapped = ng.air.unwrapShuffleTwo(zcu, inst); + const mask = unwrapped.mask; + const result_ty = unwrapped.result_ty; + const elem_ty = result_ty.childType(zcu); + const elem_ty_id = try ng.resolveType(elem_ty, .direct); + const operand_a = try ng.resolve(unwrapped.operand_a); + const operand_b = try ng.resolve(unwrapped.operand_b); + + const constituents = try gpa.alloc(IdRef, mask.len); + defer gpa.free(constituents); + + for (constituents, mask) |*id, mask_elem| { + id.* = switch (mask_elem.unwrap()) { + .a_elem => |idx| try ng.extractVectorComponent(elem_ty, operand_a, idx), + .b_elem => |idx| try ng.extractVectorComponent(elem_ty, operand_b, idx), + .undef => try ng.spv.constUndef(elem_ty_id), + }; } - const result_ty_id = try self.resolveType(result_ty, .direct); - return try self.constructComposite(result_ty_id, constituents); + const result_ty_id = try ng.resolveType(result_ty, .direct); + return try ng.constructComposite(result_ty_id, constituents); } fn indicesToIds(self: *NavGen, indices: []const u32) ![]IdRef { -- cgit v1.2.3