diff options
| author | mlugg <mlugg@mlugg.co.uk> | 2025-05-26 05:07:13 +0100 |
|---|---|---|
| committer | mlugg <mlugg@mlugg.co.uk> | 2025-06-01 08:24:01 +0100 |
| commit | add2976a9ba76ec661ae5668eb2a8dca2ccfad42 (patch) | |
| tree | 54ea8377660395007c1ac5188863f5cbe3a3310e /src/codegen | |
| parent | b48d6ff619de424e664cf11b43e2a03fecbca6ce (diff) | |
| download | zig-add2976a9ba76ec661ae5668eb2a8dca2ccfad42.tar.gz zig-add2976a9ba76ec661ae5668eb2a8dca2ccfad42.zip | |
compiler: implement better shuffle AIR
Runtime `@shuffle` has two cases which backends generally want to handle
differently for efficiency:
* One runtime vector operand; some result elements may be comptime-known
* Two runtime vector operands; some result elements may be undefined
The latter case happens if both vectors given to `@shuffle` are
runtime-known and they are both used (i.e. the mask refers to them).
Otherwise, if the result is not entirely comptime-known, we are in the
former case. `Sema` now diffentiates these two cases in the AIR so that
backends can easily handle them however they want to. Note that this
*doesn't* really involve Sema doing any more work than it would
otherwise need to, so there's not really a negative here!
Most existing backends have their lowerings for `@shuffle` migrated in
this commit. The LLVM backend uses new lowerings suggested by Jacob as
ones which it will handle effectively. The x86_64 backend has not yet
been migrated; for now there's a panic in there. Jacob will implement
that before this is merged anywhere.
Diffstat (limited to 'src/codegen')
| -rw-r--r-- | src/codegen/c.zig | 74 | ||||
| -rw-r--r-- | src/codegen/llvm.zig | 215 | ||||
| -rw-r--r-- | src/codegen/spirv.zig | 74 |
3 files changed, 286 insertions, 77 deletions
diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 8d947ce56a..c68abc06ce 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -3374,7 +3374,8 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .error_name => try airErrorName(f, inst), .splat => try airSplat(f, inst), .select => try airSelect(f, inst), - .shuffle => try airShuffle(f, inst), + .shuffle_one => try airShuffleOne(f, inst), + .shuffle_two => try airShuffleTwo(f, inst), .reduce => try airReduce(f, inst), .aggregate_init => try airAggregateInit(f, inst), .union_init => try airUnionInit(f, inst), @@ -7163,34 +7164,73 @@ fn airSelect(f: *Function, inst: Air.Inst.Index) !CValue { return local; } -fn airShuffle(f: *Function, inst: Air.Inst.Index) !CValue { +fn airShuffleOne(f: *Function, inst: Air.Inst.Index) !CValue { const pt = f.object.dg.pt; const zcu = pt.zcu; - const ty_pl = f.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; - const extra = f.air.extraData(Air.Shuffle, ty_pl.payload).data; - - const mask = Value.fromInterned(extra.mask); - const lhs = try f.resolveInst(extra.a); - const rhs = try f.resolveInst(extra.b); - const inst_ty = f.typeOfIndex(inst); + const unwrapped = f.air.unwrapShuffleOne(zcu, inst); + const mask = unwrapped.mask; + const operand = try f.resolveInst(unwrapped.operand); + const inst_ty = unwrapped.result_ty; const writer = f.object.writer(); const local = try f.allocLocal(inst, inst_ty); - try reap(f, inst, &.{ extra.a, extra.b }); // local cannot alias operands - for (0..extra.mask_len) |index| { + try reap(f, inst, &.{unwrapped.operand}); // local cannot alias operand + for (mask, 0..) |mask_elem, out_idx| { try f.writeCValue(writer, local, .Other); try writer.writeByte('['); - try f.object.dg.renderValue(writer, try pt.intValue(.usize, index), .Other); + try f.object.dg.renderValue(writer, try pt.intValue(.usize, out_idx), .Other); try writer.writeAll("] = "); + switch (mask_elem.unwrap()) { + .elem => |src_idx| { + try f.writeCValue(writer, operand, .Other); + try writer.writeByte('['); + try f.object.dg.renderValue(writer, try pt.intValue(.usize, src_idx), .Other); + try writer.writeByte(']'); + }, + .value => |val| try f.object.dg.renderValue(writer, .fromInterned(val), .Other), + } + try writer.writeAll(";\n"); + } - const mask_elem = (try mask.elemValue(pt, index)).toSignedInt(zcu); - const src_val = try pt.intValue(.usize, @as(u64, @intCast(mask_elem ^ mask_elem >> 63))); + return local; +} - try f.writeCValue(writer, if (mask_elem >= 0) lhs else rhs, .Other); +fn airShuffleTwo(f: *Function, inst: Air.Inst.Index) !CValue { + const pt = f.object.dg.pt; + const zcu = pt.zcu; + + const unwrapped = f.air.unwrapShuffleTwo(zcu, inst); + const mask = unwrapped.mask; + const operand_a = try f.resolveInst(unwrapped.operand_a); + const operand_b = try f.resolveInst(unwrapped.operand_b); + const inst_ty = unwrapped.result_ty; + const elem_ty = inst_ty.childType(zcu); + + const writer = f.object.writer(); + const local = try f.allocLocal(inst, inst_ty); + try reap(f, inst, &.{ unwrapped.operand_a, unwrapped.operand_b }); // local cannot alias operands + for (mask, 0..) |mask_elem, out_idx| { + try f.writeCValue(writer, local, .Other); try writer.writeByte('['); - try f.object.dg.renderValue(writer, src_val, .Other); - try writer.writeAll("];\n"); + try f.object.dg.renderValue(writer, try pt.intValue(.usize, out_idx), .Other); + try writer.writeAll("] = "); + switch (mask_elem.unwrap()) { + .a_elem => |src_idx| { + try f.writeCValue(writer, operand_a, .Other); + try writer.writeByte('['); + try f.object.dg.renderValue(writer, try pt.intValue(.usize, src_idx), .Other); + try writer.writeByte(']'); + }, + .b_elem => |src_idx| { + try f.writeCValue(writer, operand_b, .Other); + try writer.writeByte('['); + try f.object.dg.renderValue(writer, try pt.intValue(.usize, src_idx), .Other); + try writer.writeByte(']'); + }, + .undef => try f.object.dg.renderUndefValue(writer, elem_ty, .Other), + } + try writer.writeAll(";\n"); } return local; diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 77d8f3ff47..960d5f819b 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -4969,7 +4969,8 @@ pub const FuncGen = struct { .error_name => try self.airErrorName(inst), .splat => try self.airSplat(inst), .select => try self.airSelect(inst), - .shuffle => try self.airShuffle(inst), + .shuffle_one => try self.airShuffleOne(inst), + .shuffle_two => try self.airShuffleTwo(inst), .aggregate_init => try self.airAggregateInit(inst), .union_init => try self.airUnionInit(inst), .prefetch => try self.airPrefetch(inst), @@ -9666,7 +9667,7 @@ pub const FuncGen = struct { const zcu = o.pt.zcu; const ip = &zcu.intern_pool; for (body_tail[1..]) |body_inst| { - switch (fg.liveness.categorizeOperand(fg.air, body_inst, body_tail[0], ip)) { + switch (fg.liveness.categorizeOperand(fg.air, zcu, body_inst, body_tail[0], ip)) { .none => continue, .write, .noret, .complex => return false, .tomb => return true, @@ -10421,42 +10422,192 @@ pub const FuncGen = struct { return self.wip.select(.normal, pred, a, b, ""); } - fn airShuffle(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { - const o = self.ng.object; + fn airShuffleOne(fg: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + const o = fg.ng.object; const pt = o.pt; const zcu = pt.zcu; - const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; - const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data; - const a = try self.resolveInst(extra.a); - const b = try self.resolveInst(extra.b); - const mask = Value.fromInterned(extra.mask); - const mask_len = extra.mask_len; - const a_len = self.typeOf(extra.a).vectorLen(zcu); - - // LLVM uses integers larger than the length of the first array to - // index into the second array. This was deemed unnecessarily fragile - // when changing code, so Zig uses negative numbers to index the - // second vector. These start at -1 and go down, and are easiest to use - // with the ~ operator. Here we convert between the two formats. - const values = try self.gpa.alloc(Builder.Constant, mask_len); - defer self.gpa.free(values); - - for (values, 0..) |*val, i| { - const elem = try mask.elemValue(pt, i); - if (elem.isUndef(zcu)) { - val.* = try o.builder.undefConst(.i32); - } else { - const int = elem.toSignedInt(zcu); - const unsigned: u32 = @intCast(if (int >= 0) int else ~int + a_len); - val.* = try o.builder.intConst(.i32, unsigned); + const gpa = zcu.gpa; + + const unwrapped = fg.air.unwrapShuffleOne(zcu, inst); + + const operand = try fg.resolveInst(unwrapped.operand); + const mask = unwrapped.mask; + const operand_ty = fg.typeOf(unwrapped.operand); + const llvm_operand_ty = try o.lowerType(operand_ty); + const llvm_result_ty = try o.lowerType(unwrapped.result_ty); + const llvm_elem_ty = try o.lowerType(unwrapped.result_ty.childType(zcu)); + const llvm_poison_elem = try o.builder.poisonConst(llvm_elem_ty); + const llvm_poison_mask_elem = try o.builder.poisonConst(.i32); + const llvm_mask_ty = try o.builder.vectorType(.normal, @intCast(mask.len), .i32); + + // LLVM requires that the two input vectors have the same length, so lowering isn't trivial. + // And, in the words of jacobly0: "llvm sucks at shuffles so we do have to hold its hand at + // least a bit". So, there are two cases here. + // + // If the operand length equals the mask length, we do just the one `shufflevector`, where + // the second operand is a constant vector with comptime-known elements at the right indices + // and poison values elsewhere (in the indices which won't be selected). + // + // Otherwise, we lower to *two* `shufflevector` instructions. The first shuffles the runtime + // operand with an all-poison vector to extract and correctly position all of the runtime + // elements. We also make a constant vector with all of the comptime elements correctly + // positioned. Then, our second instruction selects elements from those "runtime-or-poison" + // and "comptime-or-poison" vectors to compute the result. + + // This buffer is used primarily for the mask constants. + const llvm_elem_buf = try gpa.alloc(Builder.Constant, mask.len); + defer gpa.free(llvm_elem_buf); + + // ...but first, we'll collect all of the comptime-known values. + var any_defined_comptime_value = false; + for (mask, llvm_elem_buf) |mask_elem, *llvm_elem| { + llvm_elem.* = switch (mask_elem.unwrap()) { + .elem => llvm_poison_elem, + .value => |val| if (!Value.fromInterned(val).isUndef(zcu)) elem: { + any_defined_comptime_value = true; + break :elem try o.lowerValue(val); + } else llvm_poison_elem, + }; + } + // This vector is like the result, but runtime elements are replaced with poison. + const comptime_and_poison: Builder.Value = if (any_defined_comptime_value) vec: { + break :vec try o.builder.vectorValue(llvm_result_ty, llvm_elem_buf); + } else try o.builder.poisonValue(llvm_result_ty); + + if (operand_ty.vectorLen(zcu) == mask.len) { + // input length equals mask/output length, so we lower to one instruction + for (mask, llvm_elem_buf, 0..) |mask_elem, *llvm_elem, elem_idx| { + llvm_elem.* = switch (mask_elem.unwrap()) { + .elem => |idx| try o.builder.intConst(.i32, idx), + .value => |val| if (!Value.fromInterned(val).isUndef(zcu)) mask_val: { + break :mask_val try o.builder.intConst(.i32, mask.len + elem_idx); + } else llvm_poison_mask_elem, + }; } + return fg.wip.shuffleVector( + operand, + comptime_and_poison, + try o.builder.vectorValue(llvm_mask_ty, llvm_elem_buf), + "", + ); + } + + for (mask, llvm_elem_buf) |mask_elem, *llvm_elem| { + llvm_elem.* = switch (mask_elem.unwrap()) { + .elem => |idx| try o.builder.intConst(.i32, idx), + .value => llvm_poison_mask_elem, + }; + } + // This vector is like our result, but all comptime-known elements are poison. + const runtime_and_poison = try fg.wip.shuffleVector( + operand, + try o.builder.poisonValue(llvm_operand_ty), + try o.builder.vectorValue(llvm_mask_ty, llvm_elem_buf), + "", + ); + + if (!any_defined_comptime_value) { + // `comptime_and_poison` is just poison; a second shuffle would be a nop. + return runtime_and_poison; + } + + // In this second shuffle, the inputs, the mask, and the output all have the same length. + for (mask, llvm_elem_buf, 0..) |mask_elem, *llvm_elem, elem_idx| { + llvm_elem.* = switch (mask_elem.unwrap()) { + .elem => try o.builder.intConst(.i32, elem_idx), + .value => |val| if (!Value.fromInterned(val).isUndef(zcu)) mask_val: { + break :mask_val try o.builder.intConst(.i32, mask.len + elem_idx); + } else llvm_poison_mask_elem, + }; } + // Merge the runtime and comptime elements with the mask we just built. + return fg.wip.shuffleVector( + runtime_and_poison, + comptime_and_poison, + try o.builder.vectorValue(llvm_mask_ty, llvm_elem_buf), + "", + ); + } + + fn airShuffleTwo(fg: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + const o = fg.ng.object; + const pt = o.pt; + const zcu = pt.zcu; + const gpa = zcu.gpa; + + const unwrapped = fg.air.unwrapShuffleTwo(zcu, inst); + + const mask = unwrapped.mask; + const llvm_elem_ty = try o.lowerType(unwrapped.result_ty.childType(zcu)); + const llvm_mask_ty = try o.builder.vectorType(.normal, @intCast(mask.len), .i32); + const llvm_poison_mask_elem = try o.builder.poisonConst(.i32); + + // This is kind of simpler than in `airShuffleOne`. We extend the shorter vector to the + // length of the longer one with an initial `shufflevector` if necessary, and then do the + // actual computation with a second `shufflevector`. + + const operand_a_len = fg.typeOf(unwrapped.operand_a).vectorLen(zcu); + const operand_b_len = fg.typeOf(unwrapped.operand_b).vectorLen(zcu); + const operand_len: u32 = @max(operand_a_len, operand_b_len); + + // If we need to extend an operand, this is the type that mask will have. + const llvm_operand_mask_ty = try o.builder.vectorType(.normal, operand_len, .i32); + + const llvm_elem_buf = try gpa.alloc(Builder.Constant, @max(mask.len, operand_len)); + defer gpa.free(llvm_elem_buf); - const llvm_mask_value = try o.builder.vectorValue( - try o.builder.vectorType(.normal, mask_len, .i32), - values, + const operand_a: Builder.Value = extend: { + const raw = try fg.resolveInst(unwrapped.operand_a); + if (operand_a_len == operand_len) break :extend raw; + // Extend with a `shufflevector`, with a mask `<0, 1, ..., n, poison, poison, ..., poison>` + const mask_elems = llvm_elem_buf[0..operand_len]; + for (mask_elems[0..operand_a_len], 0..) |*llvm_elem, elem_idx| { + llvm_elem.* = try o.builder.intConst(.i32, elem_idx); + } + @memset(mask_elems[operand_a_len..], llvm_poison_mask_elem); + const llvm_this_operand_ty = try o.builder.vectorType(.normal, operand_a_len, llvm_elem_ty); + break :extend try fg.wip.shuffleVector( + raw, + try o.builder.poisonValue(llvm_this_operand_ty), + try o.builder.vectorValue(llvm_operand_mask_ty, mask_elems), + "", + ); + }; + const operand_b: Builder.Value = extend: { + const raw = try fg.resolveInst(unwrapped.operand_b); + if (operand_b_len == operand_len) break :extend raw; + // Extend with a `shufflevector`, with a mask `<0, 1, ..., n, poison, poison, ..., poison>` + const mask_elems = llvm_elem_buf[0..operand_len]; + for (mask_elems[0..operand_b_len], 0..) |*llvm_elem, elem_idx| { + llvm_elem.* = try o.builder.intConst(.i32, elem_idx); + } + @memset(mask_elems[operand_b_len..], llvm_poison_mask_elem); + const llvm_this_operand_ty = try o.builder.vectorType(.normal, operand_b_len, llvm_elem_ty); + break :extend try fg.wip.shuffleVector( + raw, + try o.builder.poisonValue(llvm_this_operand_ty), + try o.builder.vectorValue(llvm_operand_mask_ty, mask_elems), + "", + ); + }; + + // `operand_a` and `operand_b` now have the same length (we've extended the shorter one with + // an initial shuffle if necessary). Now for the easy bit. + + const mask_elems = llvm_elem_buf[0..mask.len]; + for (mask, mask_elems) |mask_elem, *llvm_mask_elem| { + llvm_mask_elem.* = switch (mask_elem.unwrap()) { + .a_elem => |idx| try o.builder.intConst(.i32, idx), + .b_elem => |idx| try o.builder.intConst(.i32, operand_len + idx), + .undef => llvm_poison_mask_elem, + }; + } + return fg.wip.shuffleVector( + operand_a, + operand_b, + try o.builder.vectorValue(llvm_mask_ty, mask_elems), + "", ); - return self.wip.shuffleVector(a, b, llvm_mask_value, ""); } /// Reduce a vector by repeatedly applying `llvm_fn` to produce an accumulated result. diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 1381a79075..f83c6979ff 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -3252,7 +3252,8 @@ const NavGen = struct { .splat => try self.airSplat(inst), .reduce, .reduce_optimized => try self.airReduce(inst), - .shuffle => try self.airShuffle(inst), + .shuffle_one => try self.airShuffleOne(inst), + .shuffle_two => try self.airShuffleTwo(inst), .ptr_add => try self.airPtrAdd(inst), .ptr_sub => try self.airPtrSub(inst), @@ -4047,40 +4048,57 @@ const NavGen = struct { return result_id; } - fn airShuffle(self: *NavGen, inst: Air.Inst.Index) !?IdRef { - const pt = self.pt; + fn airShuffleOne(ng: *NavGen, inst: Air.Inst.Index) !?IdRef { + const pt = ng.pt; const zcu = pt.zcu; - const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; - const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data; - const a = try self.resolve(extra.a); - const b = try self.resolve(extra.b); - const mask = Value.fromInterned(extra.mask); + const gpa = zcu.gpa; - // Note: number of components in the result, a, and b may differ. - const result_ty = self.typeOfIndex(inst); - const scalar_ty = result_ty.scalarType(zcu); - const scalar_ty_id = try self.resolveType(scalar_ty, .direct); + const unwrapped = ng.air.unwrapShuffleOne(zcu, inst); + const mask = unwrapped.mask; + const result_ty = unwrapped.result_ty; + const elem_ty = result_ty.childType(zcu); + const operand = try ng.resolve(unwrapped.operand); - const constituents = try self.gpa.alloc(IdRef, result_ty.vectorLen(zcu)); - defer self.gpa.free(constituents); + const constituents = try gpa.alloc(IdRef, mask.len); + defer gpa.free(constituents); - for (constituents, 0..) |*id, i| { - const elem = try mask.elemValue(pt, i); - if (elem.isUndef(zcu)) { - id.* = try self.spv.constUndef(scalar_ty_id); - continue; - } + for (constituents, mask) |*id, mask_elem| { + id.* = switch (mask_elem.unwrap()) { + .elem => |idx| try ng.extractVectorComponent(elem_ty, operand, idx), + .value => |val| try ng.constant(elem_ty, .fromInterned(val), .direct), + }; + } - const index = elem.toSignedInt(zcu); - if (index >= 0) { - id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index)); - } else { - id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index)); - } + const result_ty_id = try ng.resolveType(result_ty, .direct); + return try ng.constructComposite(result_ty_id, constituents); + } + + fn airShuffleTwo(ng: *NavGen, inst: Air.Inst.Index) !?IdRef { + const pt = ng.pt; + const zcu = pt.zcu; + const gpa = zcu.gpa; + + const unwrapped = ng.air.unwrapShuffleTwo(zcu, inst); + const mask = unwrapped.mask; + const result_ty = unwrapped.result_ty; + const elem_ty = result_ty.childType(zcu); + const elem_ty_id = try ng.resolveType(elem_ty, .direct); + const operand_a = try ng.resolve(unwrapped.operand_a); + const operand_b = try ng.resolve(unwrapped.operand_b); + + const constituents = try gpa.alloc(IdRef, mask.len); + defer gpa.free(constituents); + + for (constituents, mask) |*id, mask_elem| { + id.* = switch (mask_elem.unwrap()) { + .a_elem => |idx| try ng.extractVectorComponent(elem_ty, operand_a, idx), + .b_elem => |idx| try ng.extractVectorComponent(elem_ty, operand_b, idx), + .undef => try ng.spv.constUndef(elem_ty_id), + }; } - const result_ty_id = try self.resolveType(result_ty, .direct); - return try self.constructComposite(result_ty_id, constituents); + const result_ty_id = try ng.resolveType(result_ty, .direct); + return try ng.constructComposite(result_ty_id, constituents); } fn indicesToIds(self: *NavGen, indices: []const u32) ![]IdRef { |
