diff options
| author | Veikka Tuominen <git@vexu.eu> | 2022-03-11 14:25:15 +0200 |
|---|---|---|
| committer | Veikka Tuominen <git@vexu.eu> | 2022-03-11 14:25:15 +0200 |
| commit | 01cd4119b032f13899e9b7b30c7e093620058dfd (patch) | |
| tree | bfafdc11324d3b77ff5af8b59b5ff324db66335b /src/Sema.zig | |
| parent | cba68090a60c3de8eadbf8eb53e37620a1d66683 (diff) | |
| download | zig-01cd4119b032f13899e9b7b30c7e093620058dfd.tar.gz zig-01cd4119b032f13899e9b7b30c7e093620058dfd.zip | |
Sema: implement `@shuffle` at comptime and for differing lengths
Diffstat (limited to 'src/Sema.zig')
| -rw-r--r-- | src/Sema.zig | 83 |
1 file changed, 73 insertions, 10 deletions
diff --git a/src/Sema.zig b/src/Sema.zig index f631bf36b5..fb45c9e9d3 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -13454,8 +13454,6 @@ fn zirShuffle(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air const inst_data = sema.code.instructions.items(.data)[inst].pl_node; const extra = sema.code.extraData(Zir.Inst.Shuffle, inst_data.payload_index).data; const elem_ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node }; - const a_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node }; - const b_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = inst_data.src_node }; const mask_src: LazySrcLoc = .{ .node_offset_builtin_call_arg3 = inst_data.src_node }; const elem_ty = try sema.resolveType(block, elem_ty_src, extra.elem_type); @@ -13474,6 +13472,25 @@ fn zirShuffle(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air .elem_type = Type.@"i32", }); mask = try sema.coerce(block, mask_ty, mask, mask_src); + const mask_val = try sema.resolveConstMaybeUndefVal(block, mask_src, mask); + return sema.analyzeShuffle(block, inst_data.src_node, elem_ty, a, b, mask_val, @intCast(u32, mask_len)); +} + +fn analyzeShuffle( + sema: *Sema, + block: *Block, + src_node: i32, + elem_ty: Type, + a_arg: Air.Inst.Ref, + b_arg: Air.Inst.Ref, + mask: Value, + mask_len: u32, +) CompileError!Air.Inst.Ref { + const a_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = src_node }; + const b_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = src_node }; + const mask_src: LazySrcLoc = .{ .node_offset_builtin_call_arg3 = src_node }; + var a = a_arg; + var b = b_arg; const res_ty = try Type.Tag.vector.create(sema.arena, .{ .len = mask_len, @@ -13485,7 +13502,7 @@ fn zirShuffle(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air .Undefined => null, else => return sema.fail(block, a_src, "expected vector or array with element type {}, found {}", .{ elem_ty, - sema.typeOf(mask), + sema.typeOf(a), }), }; var 
maybe_b_len = switch (sema.typeOf(b).zigTypeTag()) { @@ -13493,7 +13510,7 @@ fn zirShuffle(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air .Undefined => null, else => return sema.fail(block, b_src, "expected vector or array with element type {}, found {}", .{ elem_ty, - sema.typeOf(mask), + sema.typeOf(b), }), }; if (maybe_a_len == null and maybe_b_len == null) { @@ -13519,11 +13536,10 @@ fn zirShuffle(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air .{ b_len, b_src, b_ty }, }; - const mask_val = try sema.resolveConstMaybeUndefVal(block, mask_src, mask); var i: usize = 0; while (i < mask_len) : (i += 1) { var buf: Value.ElemValueBuffer = undefined; - const elem = mask_val.elemValueBuffer(i, &buf); + const elem = mask.elemValueBuffer(i, &buf); if (elem.isUndef()) continue; const int = elem.toSignedInt(); var unsigned: u32 = undefined; @@ -13555,14 +13571,61 @@ fn zirShuffle(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air } } - // TODO at comptime + if (try sema.resolveMaybeUndefVal(block, a_src, a)) |a_val| { + if (try sema.resolveMaybeUndefVal(block, b_src, b)) |b_val| { + const values = try sema.arena.alloc(Value, mask_len); + + i = 0; + while (i < mask_len) : (i += 1) { + var buf: Value.ElemValueBuffer = undefined; + const mask_elem_val = mask.elemValueBuffer(i, &buf); + if (mask_elem_val.isUndef()) { + values[i] = Value.undef; + continue; + } + const int = mask_elem_val.toSignedInt(); + const unsigned = if (int >= 0) @intCast(u32, int) else @intCast(u32, ~int); + if (int >= 0) { + values[i] = try a_val.elemValue(sema.arena, unsigned); + } else { + values[i] = try b_val.elemValue(sema.arena, unsigned); + } + } + const res_val = try Value.Tag.array.create(sema.arena, values); + return sema.addConstant(res_ty, res_val); + } + } + // All static analysis passed, and not comptime. + // For runtime codegen, vectors a and b must be the same length. 
Here we + // recursively @shuffle the smaller vector to append undefined elements + // to it up to the length of the longer vector. This recursion terminates + // in 1 call because these calls to analyzeShuffle guarantee a_len == b_len. if (a_len != b_len) { - return sema.fail(block, mask_src, "TODO handle shuffle a_len != b_len", .{}); + const min_len = std.math.min(a_len, b_len); + const max_len = std.math.max(a_len, b_len); + + const expand_mask_values = try sema.arena.alloc(Value, max_len); + i = 0; + while (i < min_len) : (i += 1) { + expand_mask_values[i] = try Value.Tag.int_u64.create(sema.arena, i); + } + while (i < max_len) : (i += 1) { + expand_mask_values[i] = Value.negative_one; + } + const expand_mask = try Value.Tag.array.create(sema.arena, expand_mask_values); + + if (a_len < b_len) { + const undef = try sema.addConstUndef(a_ty); + a = try sema.analyzeShuffle(block, src_node, elem_ty, a, undef, expand_mask, @intCast(u32, max_len)); + } else { + const undef = try sema.addConstUndef(b_ty); + b = try sema.analyzeShuffle(block, src_node, elem_ty, b, undef, expand_mask, @intCast(u32, max_len)); + } } const mask_index = @intCast(u32, sema.air_values.items.len); - try sema.air_values.append(sema.gpa, mask_val); + try sema.air_values.append(sema.gpa, mask); return block.addInst(.{ .tag = .shuffle, .data = .{ .ty_pl = .{ @@ -13571,7 +13634,7 @@ fn zirShuffle(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air .a = a, .b = b, .mask = mask_index, - .mask_len = @intCast(u32, mask_len), + .mask_len = mask_len, }), } }, }); |
