aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormlugg <mlugg@mlugg.co.uk>2023-11-07 07:13:25 +0000
committerAndrew Kelley <andrew@ziglang.org>2023-11-08 23:55:53 -0700
commitd99bed1b10f85f2becca2a1c2587e1c2cb1968e6 (patch)
treee23b0ea094d2c4541c8dfa17d2a8681e892349ef
parenta1d688b86aeba7ab72ae12679ff04a7730931af3 (diff)
downloadzig-d99bed1b10f85f2becca2a1c2587e1c2cb1968e6.tar.gz
zig-d99bed1b10f85f2becca2a1c2587e1c2cb1968e6.zip
Sema: optimize runtime array_mul
There are two optimizations here, which work together to avoid a pathological case. The first optimization is that AstGen now records the result type of an array multiplication expression where possible. This type is not used according to the language specification, but instead as an optimization. In the expression '.{x} ** 1000', if we know that the result must be an array, then it is much more efficient to coerce the LHS to an array with length 1 before doing the multiplication. Otherwise, we end up with a 1000-element tuple which we must coerce to an array by individually extracting each field. Secondly, the previous logic would repeatedly extract element/field values from the LHS when initializing the result. This is unnecessary: each element must only be extracted once, and the result reused. These changes together give huge improvements to compiler performance on a pathological case: AIR instructions go from 65551 to 15, and total AIR bytes go from 1.86MiB to 264.57KiB. Codegen time spent on this function (in a debug compiler build) goes from minutes to essentially zero. Resolves: #17586
-rw-r--r--src/AstGen.zig6
-rw-r--r--src/Autodoc.zig1
-rw-r--r--src/Sema.zig70
-rw-r--r--src/Zir.zig11
-rw-r--r--src/print_zir.zig15
5 files changed, 80 insertions, 23 deletions
diff --git a/src/AstGen.zig b/src/AstGen.zig
index 245ec45ea0..5b957e48c5 100644
--- a/src/AstGen.zig
+++ b/src/AstGen.zig
@@ -758,7 +758,11 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
.array_cat => return simpleBinOp(gz, scope, ri, node, .array_cat),
.array_mult => {
- const result = try gz.addPlNode(.array_mul, node, Zir.Inst.Bin{
+ // This syntax form does not currently use the result type in the language specification.
+ // However, the result type can be used to emit more optimal code for large multiplications by
+ // having Sema perform a coercion before the multiplication operation.
+ const result = try gz.addPlNode(.array_mul, node, Zir.Inst.ArrayMul{
+ .res_ty = if (try ri.rl.resultType(gz, node)) |t| t else .none,
.lhs = try expr(gz, scope, .{ .rl = .none }, node_datas[node].lhs),
.rhs = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .usize_type } }, node_datas[node].rhs),
});
diff --git a/src/Autodoc.zig b/src/Autodoc.zig
index 500f42dfd3..cd64b5e2cf 100644
--- a/src/Autodoc.zig
+++ b/src/Autodoc.zig
@@ -1567,7 +1567,6 @@ fn walkInstruction(
.bit_and,
.xor,
.array_cat,
- .array_mul,
=> {
const pl_node = data[@intFromEnum(inst)].pl_node;
const extra = file.zir.extraData(Zir.Inst.Bin, pl_node.payload_index);
diff --git a/src/Sema.zig b/src/Sema.zig
index 9801ce0040..f79f29dc0c 100644
--- a/src/Sema.zig
+++ b/src/Sema.zig
@@ -13998,14 +13998,49 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
const mod = sema.mod;
const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
- const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
- const lhs = try sema.resolveInst(extra.lhs);
- const lhs_ty = sema.typeOf(lhs);
+ const extra = sema.code.extraData(Zir.Inst.ArrayMul, inst_data.payload_index).data;
+ const uncoerced_lhs = try sema.resolveInst(extra.lhs);
+ const uncoerced_lhs_ty = sema.typeOf(uncoerced_lhs);
const src: LazySrcLoc = inst_data.src();
const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
const operator_src: LazySrcLoc = .{ .node_offset_main_token = inst_data.src_node };
const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
+ const lhs, const lhs_ty = coerced_lhs: {
+ // If we have a result type, we might be able to do this more efficiently
+ // by coercing the LHS first. Specifically, if we want an array or vector
+ // and have a tuple, coerce the tuple immediately.
+ no_coerce: {
+ if (extra.res_ty == .none) break :no_coerce;
+ const res_ty_inst = try sema.resolveInst(extra.res_ty);
+ const res_ty = try sema.analyzeAsType(block, src, res_ty_inst);
+ if (res_ty.isGenericPoison()) break :no_coerce;
+ if (!uncoerced_lhs_ty.isTuple(mod)) break :no_coerce;
+ const lhs_len = uncoerced_lhs_ty.structFieldCount(mod);
+ const lhs_dest_ty = switch (res_ty.zigTypeTag(mod)) {
+ else => break :no_coerce,
+ .Array => try mod.arrayType(.{
+ .child = res_ty.childType(mod).toIntern(),
+ .len = lhs_len,
+ .sentinel = if (res_ty.sentinel(mod)) |s| s.toIntern() else .none,
+ }),
+ .Vector => try mod.vectorType(.{
+ .child = res_ty.childType(mod).toIntern(),
+ .len = lhs_len,
+ }),
+ };
+ // Attempt to coerce to this type, but don't emit an error if it fails. Instead,
+ // just exit out of this path and let the usual error happen later, so that error
+ // messages are consistent.
+ const coerced = sema.coerceExtra(block, lhs_dest_ty, uncoerced_lhs, lhs_src, .{ .report_err = false }) catch |err| switch (err) {
+ error.NotCoercible => break :no_coerce,
+ else => |e| return e,
+ };
+ break :coerced_lhs .{ coerced, lhs_dest_ty };
+ }
+ break :coerced_lhs .{ uncoerced_lhs, uncoerced_lhs_ty };
+ };
+
if (lhs_ty.isTuple(mod)) {
// In `**` rhs must be comptime-known, but lhs can be runtime-known
const factor = try sema.resolveInt(block, rhs_src, extra.rhs, Type.usize, .{
@@ -14086,6 +14121,14 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
try sema.requireRuntimeBlock(block, src, lhs_src);
+ // Grab all the LHS values ahead of time, rather than repeatedly emitting instructions
+ // to get the same elem values.
+ const lhs_vals = try sema.arena.alloc(Air.Inst.Ref, lhs_len);
+ for (lhs_vals, 0..) |*lhs_val, idx| {
+ const idx_ref = try mod.intRef(Type.usize, idx);
+ lhs_val.* = try sema.elemVal(block, lhs_src, lhs, idx_ref, src, false);
+ }
+
if (ptr_addrspace) |ptr_as| {
const alloc_ty = try sema.ptrType(.{
.child = result_ty.toIntern(),
@@ -14099,14 +14142,11 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
var elem_i: usize = 0;
while (elem_i < result_len) {
- var lhs_i: usize = 0;
- while (lhs_i < lhs_len) : (lhs_i += 1) {
+ for (lhs_vals) |lhs_val| {
const elem_index = try mod.intRef(Type.usize, elem_i);
- elem_i += 1;
- const lhs_index = try mod.intRef(Type.usize, lhs_i);
const elem_ptr = try block.addPtrElemPtr(alloc, elem_index, elem_ptr_ty);
- const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
- try sema.storePtr2(block, src, elem_ptr, src, init, lhs_src, .store);
+ try sema.storePtr2(block, src, elem_ptr, src, lhs_val, lhs_src, .store);
+ elem_i += 1;
}
}
if (lhs_info.sentinel) |sent_val| {
@@ -14120,17 +14160,9 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
}
const element_refs = try sema.arena.alloc(Air.Inst.Ref, result_len);
- var elem_i: usize = 0;
- while (elem_i < result_len) {
- var lhs_i: usize = 0;
- while (lhs_i < lhs_len) : (lhs_i += 1) {
- const lhs_index = try mod.intRef(Type.usize, lhs_i);
- const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
- element_refs[elem_i] = init;
- elem_i += 1;
- }
+ for (0..try sema.usizeCast(block, rhs_src, factor)) |i| {
+ @memcpy(element_refs[i * lhs_len ..][0..lhs_len], lhs_vals);
}
-
return block.addAggregateInit(result_ty, element_refs);
}
diff --git a/src/Zir.zig b/src/Zir.zig
index 2aa6c4514c..4fecfd3c50 100644
--- a/src/Zir.zig
+++ b/src/Zir.zig
@@ -250,7 +250,7 @@ pub const Inst = struct {
/// Uses the `pl_node` union field. Payload is `Bin`.
array_cat,
/// Array multiplication `a ** b`
- /// Uses the `pl_node` union field. Payload is `Bin`.
+ /// Uses the `pl_node` union field. Payload is `ArrayMul`.
array_mul,
/// `[N]T` syntax. No source location provided.
/// Uses the `pl_node` union field. Payload is `Bin`. lhs is length, rhs is element type.
@@ -3373,6 +3373,15 @@ pub const Inst = struct {
/// The expected field count.
expect_len: u32,
};
+
+ pub const ArrayMul = struct {
+ /// The result type of the array multiplication operation, or `.none` if none was available.
+ res_ty: Ref,
+ /// The LHS of the array multiplication.
+ lhs: Ref,
+ /// The RHS of the array multiplication.
+ rhs: Ref,
+ };
};
pub const SpecialProng = enum { none, @"else", under };
diff --git a/src/print_zir.zig b/src/print_zir.zig
index 82eca87e15..3f2334e18d 100644
--- a/src/print_zir.zig
+++ b/src/print_zir.zig
@@ -370,7 +370,6 @@ const Writer = struct {
.add_sat,
.add_unsafe,
.array_cat,
- .array_mul,
.mul,
.mulwrap,
.mul_sat,
@@ -431,6 +430,8 @@ const Writer = struct {
.for_len => try self.writePlNodeMultiOp(stream, inst),
+ .array_mul => try self.writeArrayMul(stream, inst),
+
.elem_val_imm => try self.writeElemValImm(stream, inst),
.@"export" => try self.writePlNodeExport(stream, inst),
@@ -977,6 +978,18 @@ const Writer = struct {
try self.writeSrc(stream, inst_data.src());
}
+ fn writeArrayMul(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
+ const inst_data = self.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
+ const extra = self.code.extraData(Zir.Inst.ArrayMul, inst_data.payload_index).data;
+ try self.writeInstRef(stream, extra.res_ty);
+ try stream.writeAll(", ");
+ try self.writeInstRef(stream, extra.lhs);
+ try stream.writeAll(", ");
+ try self.writeInstRef(stream, extra.rhs);
+ try stream.writeAll(") ");
+ try self.writeSrc(stream, inst_data.src());
+ }
+
fn writeElemValImm(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
const inst_data = self.code.instructions.items(.data)[@intFromEnum(inst)].elem_val_imm;
try self.writeInstRef(stream, inst_data.operand);