about summary refs log tree commit diff
path: root/src/codegen
diff options
context:
space:
mode:
author	mlugg <mlugg@mlugg.co.uk>	2024-04-28 21:44:57 +0100
committer	mlugg <mlugg@mlugg.co.uk>	2024-09-01 18:30:31 +0100
commit	5e12ca9fe3c77ce1d2a3ea1c22c4bcb6d9b2bb0c (patch)
tree	a4badc5eab3da4901e1c0c3f3239b07628fc339f /src/codegen
parent	5fb4a7df38deb705f77088d7788f0acc09da613d (diff)
download	zig-5e12ca9fe3c77ce1d2a3ea1c22c4bcb6d9b2bb0c.tar.gz
download	zig-5e12ca9fe3c77ce1d2a3ea1c22c4bcb6d9b2bb0c.zip
compiler: implement labeled switch/continue
Diffstat (limited to 'src/codegen')
-rw-r--r--	src/codegen/c.zig	127
-rw-r--r--	src/codegen/llvm.zig	488
2 files changed, 509 insertions, 106 deletions
diff --git a/src/codegen/c.zig b/src/codegen/c.zig
index 91dec12279..c761aa7225 100644
--- a/src/codegen/c.zig
+++ b/src/codegen/c.zig
@@ -321,6 +321,9 @@ pub const Function = struct {
/// by type alignment.
/// The value is whether the alloc needs to be emitted in the header.
allocs: std.AutoArrayHashMapUnmanaged(LocalIndex, bool) = .{},
+ /// Maps from `loop_switch_br` instructions to the allocated local used
+ /// for the switch cond. Dispatches should set this local to the new cond.
+ loop_switch_conds: std.AutoHashMapUnmanaged(Air.Inst.Index, LocalIndex) = .{},
fn resolveInst(f: *Function, ref: Air.Inst.Ref) !CValue {
const gop = try f.value_map.getOrPut(ref);
@@ -531,6 +534,7 @@ pub const Function = struct {
f.blocks.deinit(gpa);
f.value_map.deinit();
f.lazy_fns.deinit(gpa);
+ f.loop_switch_conds.deinit(gpa);
}
fn typeOf(f: *Function, inst: Air.Inst.Ref) Type {
@@ -3376,16 +3380,18 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail,
=> unreachable,
// Instructions that are known to always be `noreturn` based on their tag.
- .br => return airBr(f, inst),
- .repeat => return airRepeat(f, inst),
- .cond_br => return airCondBr(f, inst),
- .switch_br => return airSwitchBr(f, inst),
- .loop => return airLoop(f, inst),
- .ret => return airRet(f, inst, false),
- .ret_safe => return airRet(f, inst, false), // TODO
- .ret_load => return airRet(f, inst, true),
- .trap => return airTrap(f, f.object.writer()),
- .unreach => return airUnreach(f),
+ .br => return airBr(f, inst),
+ .repeat => return airRepeat(f, inst),
+ .switch_dispatch => return airSwitchDispatch(f, inst),
+ .cond_br => return airCondBr(f, inst),
+ .switch_br => return airSwitchBr(f, inst, false),
+ .loop_switch_br => return airSwitchBr(f, inst, true),
+ .loop => return airLoop(f, inst),
+ .ret => return airRet(f, inst, false),
+ .ret_safe => return airRet(f, inst, false), // TODO
+ .ret_load => return airRet(f, inst, true),
+ .trap => return airTrap(f, f.object.writer()),
+ .unreach => return airUnreach(f),
// Instructions which may be `noreturn`.
.block => res: {
@@ -4786,6 +4792,46 @@ fn airRepeat(f: *Function, inst: Air.Inst.Index) !void {
try writer.print("goto zig_loop_{d};\n", .{@intFromEnum(repeat.loop_inst)});
}
+fn airSwitchDispatch(f: *Function, inst: Air.Inst.Index) !void {
+ const pt = f.object.dg.pt;
+ const zcu = pt.zcu;
+ const br = f.air.instructions.items(.data)[@intFromEnum(inst)].br;
+ const writer = f.object.writer();
+
+ if (try f.air.value(br.operand, pt)) |cond_val| {
+ // Comptime-known dispatch. Iterate the cases to find the correct
+ // one, and branch directly to the corresponding case.
+ const switch_br = f.air.unwrapSwitch(br.block_inst);
+ var it = switch_br.iterateCases();
+ const target_case_idx: u32 = target: while (it.next()) |case| {
+ for (case.items) |item| {
+ const val = Value.fromInterned(item.toInterned().?);
+ if (cond_val.compareHetero(.eq, val, zcu)) break :target case.idx;
+ }
+ for (case.ranges) |range| {
+ const low = Value.fromInterned(range[0].toInterned().?);
+ const high = Value.fromInterned(range[1].toInterned().?);
+ if (cond_val.compareHetero(.gte, low, zcu) and
+ cond_val.compareHetero(.lte, high, zcu))
+ {
+ break :target case.idx;
+ }
+ }
+ } else switch_br.cases_len;
+ try writer.print("goto zig_switch_{d}_dispatch_{d};\n", .{ @intFromEnum(br.block_inst), target_case_idx });
+ return;
+ }
+
+ // Runtime-known dispatch. Set the switch condition, and branch back.
+ const cond = try f.resolveInst(br.operand);
+ const cond_local = f.loop_switch_conds.get(br.block_inst).?;
+ try f.writeCValue(writer, .{ .local = cond_local }, .Other);
+ try writer.writeAll(" = ");
+ try f.writeCValue(writer, cond, .Initializer);
+ try writer.writeAll(";\n");
+ try writer.print("goto zig_switch_{d}_loop;", .{@intFromEnum(br.block_inst)});
+}
+
fn airBitcast(f: *Function, inst: Air.Inst.Index) !CValue {
const ty_op = f.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const inst_ty = f.typeOfIndex(inst);
@@ -5004,15 +5050,34 @@ fn airCondBr(f: *Function, inst: Air.Inst.Index) !void {
try genBodyInner(f, else_body);
}
-fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void {
+fn airSwitchBr(f: *Function, inst: Air.Inst.Index, is_dispatch_loop: bool) !void {
const pt = f.object.dg.pt;
const zcu = pt.zcu;
+ const gpa = f.object.dg.gpa;
const switch_br = f.air.unwrapSwitch(inst);
- const condition = try f.resolveInst(switch_br.operand);
+ const init_condition = try f.resolveInst(switch_br.operand);
try reap(f, inst, &.{switch_br.operand});
const condition_ty = f.typeOf(switch_br.operand);
const writer = f.object.writer();
+ // For dispatches, we will create a local alloc to contain the condition value.
+ // This may not result in optimal codegen for switch loops, but it minimizes the
+ // amount of C code we generate, which is probably more desirable here (and is simpler).
+ const condition = if (is_dispatch_loop) cond: {
+ const new_local = try f.allocLocal(inst, condition_ty);
+ try f.writeCValue(writer, new_local, .Other);
+ try writer.writeAll(" = ");
+ try f.writeCValue(writer, init_condition, .Initializer);
+ try writer.writeAll(";\n");
+ try writer.print("zig_switch_{d}_loop:", .{@intFromEnum(inst)});
+ try f.loop_switch_conds.put(gpa, inst, new_local.new_local);
+ break :cond new_local;
+ } else init_condition;
+
+ defer if (is_dispatch_loop) {
+ assert(f.loop_switch_conds.remove(inst));
+ };
+
try writer.writeAll("switch (");
const lowered_condition_ty = if (condition_ty.toIntern() == .bool_type)
@@ -5030,7 +5095,6 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void {
try writer.writeAll(") {");
f.object.indent_writer.pushIndent();
- const gpa = f.object.dg.gpa;
const liveness = try f.liveness.getSwitchBr(gpa, inst, switch_br.cases_len + 1);
defer gpa.free(liveness.deaths);
@@ -5045,9 +5109,15 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void {
try f.object.indent_writer.insertNewline();
try writer.writeAll("case ");
const item_value = try f.air.value(item, pt);
- if (item_value.?.getUnsignedInt(zcu)) |item_int| try writer.print("{}\n", .{
- try f.fmtIntLiteral(try pt.intValue(lowered_condition_ty, item_int)),
- }) else {
+ // If `item_value` is a pointer with a known integer address, print the address
+ // with no cast to avoid a warning.
+ write_val: {
+ if (condition_ty.isPtrAtRuntime(zcu)) {
+ if (item_value.?.getUnsignedInt(zcu)) |item_int| {
+ try writer.print("{}", .{try f.fmtIntLiteral(try pt.intValue(lowered_condition_ty, item_int))});
+ break :write_val;
+ }
+ }
if (condition_ty.isPtrAtRuntime(zcu)) {
try writer.writeByte('(');
try f.renderType(writer, Type.usize);
@@ -5057,9 +5127,14 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void {
}
try writer.writeByte(':');
}
- try writer.writeByte(' ');
-
- try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, false);
+ try writer.writeAll(" {\n");
+ f.object.indent_writer.pushIndent();
+ if (is_dispatch_loop) {
+ try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), case.idx });
+ }
+ try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, true);
+ f.object.indent_writer.popIndent();
+ try writer.writeByte('}');
// The case body must be noreturn so we don't need to insert a break.
}
@@ -5095,11 +5170,19 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void {
try f.object.dg.renderValue(writer, (try f.air.value(range[1], pt)).?, .Other);
try writer.writeByte(')');
}
- try writer.writeAll(") ");
- try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, false);
+ try writer.writeAll(") {\n");
+ f.object.indent_writer.pushIndent();
+ if (is_dispatch_loop) {
+ try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), case.idx });
+ }
+ try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, true);
+ f.object.indent_writer.popIndent();
+ try writer.writeByte('}');
}
}
-
+ if (is_dispatch_loop) {
+ try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), switch_br.cases_len });
+ }
if (else_body.len > 0) {
// Note that this must be the last case, so we do not need to use `genBodyResolveState` since
// the parent block will do it (because the case body is noreturn).
diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
index d1ec8eca9f..e5e43bbbf8 100644
--- a/src/codegen/llvm.zig
+++ b/src/codegen/llvm.zig
@@ -1721,6 +1721,7 @@ pub const Object = struct {
.func_inst_table = .{},
.blocks = .{},
.loops = .{},
+ .switch_dispatch_info = .{},
.sync_scope = if (owner_mod.single_threaded) .singlethread else .system,
.file = file,
.scope = subprogram,
@@ -4845,6 +4846,10 @@ pub const FuncGen = struct {
/// Maps `loop` instructions to the bb to branch to to repeat the loop.
loops: std.AutoHashMapUnmanaged(Air.Inst.Index, Builder.Function.Block.Index),
+ /// Maps `loop_switch_br` instructions to the information required to lower
+ /// dispatches (`switch_dispatch` instructions).
+ switch_dispatch_info: std.AutoHashMapUnmanaged(Air.Inst.Index, SwitchDispatchInfo),
+
sync_scope: Builder.SyncScope,
const Fuzz = struct {
@@ -4857,6 +4862,33 @@ pub const FuncGen = struct {
}
};
+ const SwitchDispatchInfo = struct {
+ /// These are the blocks corresponding to each switch case.
+ /// The final element corresponds to the `else` case.
+ /// Slices allocated into `gpa`.
+ case_blocks: []Builder.Function.Block.Index,
+ /// This is `.none` if `jmp_table` is set, since we won't use a `switch` instruction to dispatch.
+ switch_weights: Builder.Function.Instruction.BrCond.Weights,
+ /// If not `null`, we have manually constructed a jump table to reach the desired block.
+ /// `table` can be used if the value is between `min` and `max` inclusive.
+ /// We perform this lowering manually to avoid some questionable behavior from LLVM.
+ /// See `airSwitchBr` for details.
+ jmp_table: ?JmpTable,
+
+ const JmpTable = struct {
+ min: Builder.Constant,
+ max: Builder.Constant,
+ in_bounds_hint: enum { none, unpredictable, likely, unlikely },
+ /// Pointer to the jump table itself, to be used with `indirectbr`.
+ /// The index into the jump table is the dispatch condition minus `min`.
+ /// The table values are `blockaddress` constants corresponding to blocks in `case_blocks`.
+ table: Builder.Constant,
+ /// `true` if `table` contains a reference to the `else` block.
+ /// In this case, the `indirectbr` must include the `else` block in its target list.
+ table_includes_else: bool,
+ };
+ };
+
const BreakList = union {
list: std.MultiArrayList(struct {
bb: Builder.Function.Block.Index,
@@ -4872,6 +4904,11 @@ pub const FuncGen = struct {
self.func_inst_table.deinit(gpa);
self.blocks.deinit(gpa);
self.loops.deinit(gpa);
+ var it = self.switch_dispatch_info.valueIterator();
+ while (it.next()) |info| {
+ self.gpa.free(info.case_blocks);
+ }
+ self.switch_dispatch_info.deinit(gpa);
}
fn todo(self: *FuncGen, comptime format: []const u8, args: anytype) Error {
@@ -5182,16 +5219,18 @@ pub const FuncGen = struct {
.work_group_id => try self.airWorkGroupId(inst),
// Instructions that are known to always be `noreturn` based on their tag.
- .br => return self.airBr(inst),
- .repeat => return self.airRepeat(inst),
- .cond_br => return self.airCondBr(inst),
- .switch_br => return self.airSwitchBr(inst),
- .loop => return self.airLoop(inst),
- .ret => return self.airRet(inst, false),
- .ret_safe => return self.airRet(inst, true),
- .ret_load => return self.airRetLoad(inst),
- .trap => return self.airTrap(inst),
- .unreach => return self.airUnreach(inst),
+ .br => return self.airBr(inst),
+ .repeat => return self.airRepeat(inst),
+ .switch_dispatch => return self.airSwitchDispatch(inst),
+ .cond_br => return self.airCondBr(inst),
+ .switch_br => return self.airSwitchBr(inst, false),
+ .loop_switch_br => return self.airSwitchBr(inst, true),
+ .loop => return self.airLoop(inst),
+ .ret => return self.airRet(inst, false),
+ .ret_safe => return self.airRet(inst, true),
+ .ret_load => return self.airRetLoad(inst),
+ .trap => return self.airTrap(inst),
+ .unreach => return self.airUnreach(inst),
// Instructions which may be `noreturn`.
.block => res: {
@@ -6093,6 +6132,202 @@ pub const FuncGen = struct {
_ = try self.wip.br(loop_bb);
}
+ fn lowerSwitchDispatch(
+ self: *FuncGen,
+ switch_inst: Air.Inst.Index,
+ cond_ref: Air.Inst.Ref,
+ dispatch_info: SwitchDispatchInfo,
+ ) !void {
+ const o = self.ng.object;
+ const pt = o.pt;
+ const zcu = pt.zcu;
+ const cond_ty = self.typeOf(cond_ref);
+ const switch_br = self.air.unwrapSwitch(switch_inst);
+
+ if (try self.air.value(cond_ref, pt)) |cond_val| {
+ // Comptime-known dispatch. Iterate the cases to find the correct
+ // one, and branch to the corresponding element of `case_blocks`.
+ var it = switch_br.iterateCases();
+ const target_case_idx = target: while (it.next()) |case| {
+ for (case.items) |item| {
+ const val = Value.fromInterned(item.toInterned().?);
+ if (cond_val.compareHetero(.eq, val, zcu)) break :target case.idx;
+ }
+ for (case.ranges) |range| {
+ const low = Value.fromInterned(range[0].toInterned().?);
+ const high = Value.fromInterned(range[1].toInterned().?);
+ if (cond_val.compareHetero(.gte, low, zcu) and
+ cond_val.compareHetero(.lte, high, zcu))
+ {
+ break :target case.idx;
+ }
+ }
+ } else dispatch_info.case_blocks.len - 1;
+ const target_block = dispatch_info.case_blocks[target_case_idx];
+ target_block.ptr(&self.wip).incoming += 1;
+ _ = try self.wip.br(target_block);
+ return;
+ }
+
+ // Runtime-known dispatch.
+ const cond = try self.resolveInst(cond_ref);
+
+ if (dispatch_info.jmp_table) |jmp_table| {
+ // We should use the constructed jump table.
+ // First, check the bounds to branch to the `else` case if needed.
+ const inbounds = try self.wip.bin(
+ .@"and",
+ try self.cmp(.normal, .gte, cond_ty, cond, jmp_table.min.toValue()),
+ try self.cmp(.normal, .lte, cond_ty, cond, jmp_table.max.toValue()),
+ "",
+ );
+ const jmp_table_block = try self.wip.block(1, "Then");
+ const else_block = dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1];
+ else_block.ptr(&self.wip).incoming += 1;
+ _ = try self.wip.brCond(inbounds, jmp_table_block, else_block, switch (jmp_table.in_bounds_hint) {
+ .none => .none,
+ .unpredictable => .unpredictable,
+ .likely => .then_likely,
+ .unlikely => .else_likely,
+ });
+
+ self.wip.cursor = .{ .block = jmp_table_block };
+
+ // Figure out the list of blocks we might branch to.
+ // This includes all case blocks, but it might not include the `else` block if
+ // the table is dense.
+ const target_blocks_len = dispatch_info.case_blocks.len - @intFromBool(!jmp_table.table_includes_else);
+ const target_blocks = dispatch_info.case_blocks[0..target_blocks_len];
+
+ // Make sure to cast the index to a usize so it's not treated as negative!
+ const table_index = try self.wip.cast(
+ .zext,
+ try self.wip.bin(.@"sub nuw", cond, jmp_table.min.toValue(), ""),
+ try o.lowerType(Type.usize),
+ "",
+ );
+ const target_ptr_ptr = try self.wip.gep(
+ .inbounds,
+ .ptr,
+ jmp_table.table.toValue(),
+ &.{table_index},
+ "",
+ );
+ const target_ptr = try self.wip.load(.normal, .ptr, target_ptr_ptr, .default, "");
+
+ // Do the branch!
+ _ = try self.wip.indirectbr(target_ptr, target_blocks);
+
+ // Mark all target blocks as having one more incoming branch.
+ for (target_blocks) |case_block| {
+ case_block.ptr(&self.wip).incoming += 1;
+ }
+
+ return;
+ }
+
+ // We must lower to an actual LLVM `switch` instruction.
+ // The switch prongs will correspond to our scalar cases. Ranges will
+ // be handled by conditional branches in the `else` prong.
+
+ const llvm_usize = try o.lowerType(Type.usize);
+ const cond_int = if (cond.typeOfWip(&self.wip).isPointer(&o.builder))
+ try self.wip.cast(.ptrtoint, cond, llvm_usize, "")
+ else
+ cond;
+
+ const llvm_cases_len, const last_range_case = info: {
+ var llvm_cases_len: u32 = 0;
+ var last_range_case: ?u32 = null;
+ var it = switch_br.iterateCases();
+ while (it.next()) |case| {
+ if (case.ranges.len > 0) last_range_case = case.idx;
+ llvm_cases_len += @intCast(case.items.len);
+ }
+ break :info .{ llvm_cases_len, last_range_case };
+ };
+
+ // The `else` of the LLVM `switch` is the actual `else` prong only
+ // if there are no ranges. Otherwise, the `else` will have a
+ // conditional chain before the "true" `else` prong.
+ const llvm_else_block = if (last_range_case == null)
+ dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1]
+ else
+ try self.wip.block(0, "RangeTest");
+
+ llvm_else_block.ptr(&self.wip).incoming += 1;
+
+ var wip_switch = try self.wip.@"switch"(cond_int, llvm_else_block, llvm_cases_len, dispatch_info.switch_weights);
+ defer wip_switch.finish(&self.wip);
+
+ // Construct the actual cases. Set the cursor to the `else` block so
+ // we can construct ranges at the same time as scalar cases.
+ self.wip.cursor = .{ .block = llvm_else_block };
+
+ var it = switch_br.iterateCases();
+ while (it.next()) |case| {
+ const case_block = dispatch_info.case_blocks[case.idx];
+
+ for (case.items) |item| {
+ const llvm_item = (try self.resolveInst(item)).toConst().?;
+ const llvm_int_item = if (llvm_item.typeOf(&o.builder).isPointer(&o.builder))
+ try o.builder.castConst(.ptrtoint, llvm_item, llvm_usize)
+ else
+ llvm_item;
+ try wip_switch.addCase(llvm_int_item, case_block, &self.wip);
+ }
+ case_block.ptr(&self.wip).incoming += @intCast(case.items.len);
+
+ if (case.ranges.len == 0) continue;
+
+ // Add a conditional for the ranges, directing to the relevant bb.
+ // We don't need to consider `cold` branch hints since that information is stored
+ // in the target bb body, but we do care about likely/unlikely/unpredictable.
+
+ const hint = switch_br.getHint(case.idx);
+
+ var range_cond: ?Builder.Value = null;
+ for (case.ranges) |range| {
+ const llvm_min = try self.resolveInst(range[0]);
+ const llvm_max = try self.resolveInst(range[1]);
+ const cond_part = try self.wip.bin(
+ .@"and",
+ try self.cmp(.normal, .gte, cond_ty, cond, llvm_min),
+ try self.cmp(.normal, .lte, cond_ty, cond, llvm_max),
+ "",
+ );
+ if (range_cond) |prev| {
+ range_cond = try self.wip.bin(.@"or", prev, cond_part, "");
+ } else range_cond = cond_part;
+ }
+
+ // If the check fails, we either branch to the "true" `else` case,
+ // or to the next range condition.
+ const range_else_block = if (case.idx == last_range_case.?)
+ dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1]
+ else
+ try self.wip.block(0, "RangeTest");
+
+ _ = try self.wip.brCond(range_cond.?, case_block, range_else_block, switch (hint) {
+ .none, .cold => .none,
+ .unpredictable => .unpredictable,
+ .likely => .then_likely,
+ .unlikely => .else_likely,
+ });
+ case_block.ptr(&self.wip).incoming += 1;
+ range_else_block.ptr(&self.wip).incoming += 1;
+
+ // Construct the next range conditional (if any) in the false branch.
+ self.wip.cursor = .{ .block = range_else_block };
+ }
+ }
+
+ fn airSwitchDispatch(self: *FuncGen, inst: Air.Inst.Index) !void {
+ const br = self.air.instructions.items(.data)[@intFromEnum(inst)].br;
+ const dispatch_info = self.switch_dispatch_info.get(br.block_inst).?;
+ return self.lowerSwitchDispatch(br.block_inst, br.operand, dispatch_info);
+ }
+
fn airCondBr(self: *FuncGen, inst: Air.Inst.Index) !void {
const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
const cond = try self.resolveInst(pl_op.operand);
@@ -6257,36 +6492,123 @@ pub const FuncGen = struct {
return fg.wip.extractValue(err_union, &.{offset}, "");
}
- fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index) !void {
+ fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index, is_dispatch_loop: bool) !void {
const o = self.ng.object;
+ const zcu = o.pt.zcu;
const switch_br = self.air.unwrapSwitch(inst);
- const cond = try self.resolveInst(switch_br.operand);
+ // For `loop_switch_br`, we need these BBs prepared ahead of time to generate dispatches.
+ // For `switch_br`, they allow us to sometimes generate better IR by sharing a BB between
+ // scalar and range cases in the same prong.
+ // +1 for `else` case. This is not the same as the LLVM `else` prong, as that may first contain
+ // conditionals to handle ranges.
+ const case_blocks = try self.gpa.alloc(Builder.Function.Block.Index, switch_br.cases_len + 1);
+ defer self.gpa.free(case_blocks);
+ // We set incoming as 0 for now, and increment it as we construct dispatches.
+ for (case_blocks[0 .. case_blocks.len - 1]) |*b| b.* = try self.wip.block(0, "Case");
+ case_blocks[case_blocks.len - 1] = try self.wip.block(0, "Default");
+
+ // There's a special case here to manually generate a jump table in some cases.
+ //
+ // Labeled switch in Zig is intended to follow the "direct threading" pattern. We would ideally use a jump
+ // table, and each `continue` has its own indirect `jmp`, to allow the branch predictor to more accurately
+ // use data patterns to predict future dispatches. The problem, however, is that LLVM emits fascinatingly
+ // bad asm for this. Not only does it not share the jump table -- which we really need it to do to prevent
+ // destroying the cache -- but it also actually generates slightly different jump tables for each case,
+ // and *a separate conditional branch beforehand* to handle dispatching back to the case we're currently
+ // within(!!).
+ //
+ // This asm is really, really, not what we want. As such, we will construct the jump table manually where
+ // appropriate (the values are dense and relatively few), and use it when lowering dispatches.
+
+ const jmp_table: ?SwitchDispatchInfo.JmpTable = jmp_table: {
+ if (!is_dispatch_loop) break :jmp_table null;
+ // On a 64-bit target, 1024 pointers in our jump table is about 8K of pointers. This seems just
+ // about acceptable - it won't fill L1d cache on most CPUs.
+ const max_table_len = 1024;
- // This is not necessarily the actual `else` prong; it first contains conditionals
- // for any range cases. It's just the `else` of the LLVM switch.
- const llvm_else_block = try self.wip.block(1, "Default");
+ const cond_ty = self.typeOf(switch_br.operand);
+ switch (cond_ty.zigTypeTag(zcu)) {
+ .bool, .pointer => break :jmp_table null,
+ .@"enum", .int, .error_set => {},
+ else => unreachable,
+ }
- const case_blocks = try self.gpa.alloc(Builder.Function.Block.Index, switch_br.cases_len);
- defer self.gpa.free(case_blocks);
- // We set incoming as 0 for now, and increment it as we construct the switch.
- for (case_blocks) |*b| b.* = try self.wip.block(0, "Case");
+ if (cond_ty.intInfo(zcu).signedness == .signed) break :jmp_table null;
- const llvm_usize = try o.lowerType(Type.usize);
- const cond_int = if (cond.typeOfWip(&self.wip).isPointer(&o.builder))
- try self.wip.cast(.ptrtoint, cond, llvm_usize, "")
- else
- cond;
+ // Don't worry about the size of the type -- it's irrelevant, because the prong values could be fairly dense.
+ // If they are, then we will construct a jump table.
+ const min, const max = self.switchCaseItemRange(switch_br);
+ const min_int = min.getUnsignedInt(zcu) orelse break :jmp_table null;
+ const max_int = max.getUnsignedInt(zcu) orelse break :jmp_table null;
+ const table_len = max_int - min_int + 1;
+ if (table_len > max_table_len) break :jmp_table null;
+
+ const table_elems = try self.gpa.alloc(Builder.Constant, @intCast(table_len));
+ defer self.gpa.free(table_elems);
- const llvm_cases_len = llvm_cases_len: {
- var len: u32 = 0;
+ // Set them all to the `else` branch, then iterate over the AIR switch
+ // and replace all values which correspond to other prongs.
+ @memset(table_elems, try o.builder.blockAddrConst(
+ self.wip.function,
+ case_blocks[case_blocks.len - 1],
+ ));
+ var item_count: u32 = 0;
var it = switch_br.iterateCases();
- while (it.next()) |case| len += @intCast(case.items.len);
- break :llvm_cases_len len;
+ while (it.next()) |case| {
+ const case_block = case_blocks[case.idx];
+ const case_block_addr = try o.builder.blockAddrConst(
+ self.wip.function,
+ case_block,
+ );
+ for (case.items) |item| {
+ const val = Value.fromInterned(item.toInterned().?);
+ const table_idx = val.toUnsignedInt(zcu) - min_int;
+ table_elems[@intCast(table_idx)] = case_block_addr;
+ item_count += 1;
+ }
+ for (case.ranges) |range| {
+ const low = Value.fromInterned(range[0].toInterned().?);
+ const high = Value.fromInterned(range[1].toInterned().?);
+ const low_idx = low.toUnsignedInt(zcu) - min_int;
+ const high_idx = high.toUnsignedInt(zcu) - min_int;
+ @memset(table_elems[@intCast(low_idx)..@intCast(high_idx + 1)], case_block_addr);
+ item_count += @intCast(high_idx + 1 - low_idx);
+ }
+ }
+
+ const table_llvm_ty = try o.builder.arrayType(table_elems.len, .ptr);
+ const table_val = try o.builder.arrayConst(table_llvm_ty, table_elems);
+
+ const table_variable = try o.builder.addVariable(
+ try o.builder.strtabStringFmt("__jmptab_{d}", .{@intFromEnum(inst)}),
+ table_llvm_ty,
+ .default,
+ );
+ try table_variable.setInitializer(table_val, &o.builder);
+ table_variable.setLinkage(.internal, &o.builder);
+ table_variable.setUnnamedAddr(.unnamed_addr, &o.builder);
+
+ const table_includes_else = item_count != table_len;
+
+ break :jmp_table .{
+ .min = try o.lowerValue(min.toIntern()),
+ .max = try o.lowerValue(max.toIntern()),
+ .in_bounds_hint = if (table_includes_else) .none else switch (switch_br.getElseHint()) {
+ .none, .cold => .none,
+ .unpredictable => .unpredictable,
+ .likely => .likely,
+ .unlikely => .unlikely,
+ },
+ .table = table_variable.toConst(&o.builder),
+ .table_includes_else = table_includes_else,
+ };
};
const weights: Builder.Function.Instruction.BrCond.Weights = weights: {
+ if (jmp_table != null) break :weights .none; // not used
+
// First pass. If any weights are `.unpredictable`, unpredictable.
// If all are `.none` or `.cold`, none.
var any_likely = false;
@@ -6304,6 +6626,13 @@ pub const FuncGen = struct {
}
if (!any_likely) break :weights .none;
+ const llvm_cases_len = llvm_cases_len: {
+ var len: u32 = 0;
+ var it = switch_br.iterateCases();
+ while (it.next()) |case| len += @intCast(case.items.len);
+ break :llvm_cases_len len;
+ };
+
var weights = try self.gpa.alloc(Builder.Metadata, llvm_cases_len + 1);
defer self.gpa.free(weights);
@@ -6336,75 +6665,66 @@ pub const FuncGen = struct {
break :weights @enumFromInt(@intFromEnum(tuple));
};
- var wip_switch = try self.wip.@"switch"(cond_int, llvm_else_block, llvm_cases_len, weights);
- defer wip_switch.finish(&self.wip);
+ const dispatch_info: SwitchDispatchInfo = .{
+ .case_blocks = case_blocks,
+ .switch_weights = weights,
+ .jmp_table = jmp_table,
+ };
+
+ if (is_dispatch_loop) {
+ try self.switch_dispatch_info.putNoClobber(self.gpa, inst, dispatch_info);
+ }
+ defer if (is_dispatch_loop) {
+ assert(self.switch_dispatch_info.remove(inst));
+ };
+
+ // Generate the initial dispatch.
+ // If this is a simple `switch_br`, this is the only dispatch.
+ try self.lowerSwitchDispatch(inst, switch_br.operand, dispatch_info);
+ // Iterate the cases and generate their bodies.
var it = switch_br.iterateCases();
- var any_ranges = false;
while (it.next()) |case| {
- if (case.ranges.len > 0) any_ranges = true;
const case_block = case_blocks[case.idx];
- case_block.ptr(&self.wip).incoming += @intCast(case.items.len);
- // Handle scalar items, and generate the block.
- // We'll generate conditionals for the ranges later on.
- for (case.items) |item| {
- const llvm_item = (try self.resolveInst(item)).toConst().?;
- const llvm_int_item = if (llvm_item.typeOf(&o.builder).isPointer(&o.builder))
- try o.builder.castConst(.ptrtoint, llvm_item, llvm_usize)
- else
- llvm_item;
- try wip_switch.addCase(llvm_int_item, case_block, &self.wip);
- }
self.wip.cursor = .{ .block = case_block };
if (switch_br.getHint(case.idx) == .cold) _ = try self.wip.callIntrinsicAssumeCold();
- try self.genBodyDebugScope(null, case.body, .poi);
+ try self.genBodyDebugScope(null, case.body, .none);
}
-
+ self.wip.cursor = .{ .block = case_blocks[case_blocks.len - 1] };
const else_body = it.elseBody();
- self.wip.cursor = .{ .block = llvm_else_block };
- if (any_ranges) {
- const cond_ty = self.typeOf(switch_br.operand);
- // Add conditionals for the ranges, directing to the relevant bb.
- // We don't need to consider `cold` branch hints since that information is stored
- // in the target bb body, but we do care about likely/unlikely/unpredictable.
- it = switch_br.iterateCases();
- while (it.next()) |case| {
- if (case.ranges.len == 0) continue;
- const case_block = case_blocks[case.idx];
- const hint = switch_br.getHint(case.idx);
- case_block.ptr(&self.wip).incoming += 1;
- const next_else_block = try self.wip.block(1, "Default");
- var range_cond: ?Builder.Value = null;
- for (case.ranges) |range| {
- const llvm_min = try self.resolveInst(range[0]);
- const llvm_max = try self.resolveInst(range[1]);
- const cond_part = try self.wip.bin(
- .@"and",
- try self.cmp(.normal, .gte, cond_ty, cond, llvm_min),
- try self.cmp(.normal, .lte, cond_ty, cond, llvm_max),
- "",
- );
- if (range_cond) |prev| {
- range_cond = try self.wip.bin(.@"or", prev, cond_part, "");
- } else range_cond = cond_part;
- }
- _ = try self.wip.brCond(range_cond.?, case_block, next_else_block, switch (hint) {
- .none, .cold => .none,
- .unpredictable => .unpredictable,
- .likely => .then_likely,
- .unlikely => .else_likely,
- });
- self.wip.cursor = .{ .block = next_else_block };
- }
- }
if (switch_br.getElseHint() == .cold) _ = try self.wip.callIntrinsicAssumeCold();
- if (else_body.len != 0) {
- try self.genBodyDebugScope(null, else_body, .poi);
+ if (else_body.len > 0) {
+ try self.genBodyDebugScope(null, it.elseBody(), .none);
} else {
_ = try self.wip.@"unreachable"();
}
+ }
- // No need to reset the insert cursor since this instruction is noreturn.
+ fn switchCaseItemRange(self: *FuncGen, switch_br: Air.UnwrappedSwitch) [2]Value {
+ const zcu = self.ng.object.pt.zcu;
+ var it = switch_br.iterateCases();
+ var min: ?Value = null;
+ var max: ?Value = null;
+ while (it.next()) |case| {
+ for (case.items) |item| {
+ const val = Value.fromInterned(item.toInterned().?);
+ const low = if (min) |m| val.compareHetero(.lt, m, zcu) else true;
+ const high = if (max) |m| val.compareHetero(.gt, m, zcu) else true;
+ if (low) min = val;
+ if (high) max = val;
+ }
+ for (case.ranges) |range| {
+ const vals: [2]Value = .{
+ Value.fromInterned(range[0].toInterned().?),
+ Value.fromInterned(range[1].toInterned().?),
+ };
+ const low = if (min) |m| vals[0].compareHetero(.lt, m, zcu) else true;
+ const high = if (max) |m| vals[1].compareHetero(.gt, m, zcu) else true;
+ if (low) min = vals[0];
+ if (high) max = vals[1];
+ }
+ }
+ return .{ min.?, max.? };
}
fn airLoop(self: *FuncGen, inst: Air.Inst.Index) !void {