| | | |
|---|---|---|
| author | mlugg <mlugg@mlugg.co.uk> | 2024-04-28 21:44:57 +0100 |
| committer | mlugg <mlugg@mlugg.co.uk> | 2024-09-01 18:30:31 +0100 |
| commit | 5e12ca9fe3c77ce1d2a3ea1c22c4bcb6d9b2bb0c | |
| tree | a4badc5eab3da4901e1c0c3f3239b07628fc339f /src/codegen | |
| parent | 5fb4a7df38deb705f77088d7788f0acc09da613d | |
| download | zig-5e12ca9fe3c77ce1d2a3ea1c22c4bcb6d9b2bb0c.tar.gz, zig-5e12ca9fe3c77ce1d2a3ea1c22c4bcb6d9b2bb0c.zip | |
compiler: implement labeled switch/continue
Diffstat (limited to 'src/codegen')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | src/codegen/c.zig | 127 |
| -rw-r--r-- | src/codegen/llvm.zig | 488 |
2 files changed, 509 insertions, 106 deletions
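For context, the feature lowered by this commit is Zig's labeled switch with `continue`, which re-dispatches the switch on a new operand. A minimal illustrative example of the syntax (not taken from this commit):

```zig
const std = @import("std");

const State = enum { start, middle, done };

// Illustrative only: `continue :fsm <operand>` re-dispatches the labeled
// switch on a new condition instead of looping around a plain `while`.
fn run(initial: State) u32 {
    var steps: u32 = 0;
    fsm: switch (initial) {
        .start => {
            steps += 1;
            continue :fsm .middle;
        },
        .middle => {
            steps += 1;
            continue :fsm .done;
        },
        .done => {},
    }
    return steps;
}

pub fn main() void {
    std.debug.print("steps from .start: {d}\n", .{run(.start)});
}
```

Each such `continue` is what the new `switch_dispatch` AIR instruction represents: the C backend below emits it as an assignment to the switch-condition local followed by `goto zig_switch_*_loop`, while the LLVM backend emits either a `switch` or an `indirectbr` through a per-switch jump table.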
diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 91dec12279..c761aa7225 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -321,6 +321,9 @@ pub const Function = struct { /// by type alignment. /// The value is whether the alloc needs to be emitted in the header. allocs: std.AutoArrayHashMapUnmanaged(LocalIndex, bool) = .{}, + /// Maps from `loop_switch_br` instructions to the allocated local used + /// for the switch cond. Dispatches should set this local to the new cond. + loop_switch_conds: std.AutoHashMapUnmanaged(Air.Inst.Index, LocalIndex) = .{}, fn resolveInst(f: *Function, ref: Air.Inst.Ref) !CValue { const gop = try f.value_map.getOrPut(ref); @@ -531,6 +534,7 @@ pub const Function = struct { f.blocks.deinit(gpa); f.value_map.deinit(); f.lazy_fns.deinit(gpa); + f.loop_switch_conds.deinit(gpa); } fn typeOf(f: *Function, inst: Air.Inst.Ref) Type { @@ -3376,16 +3380,18 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, => unreachable, // Instructions that are known to always be `noreturn` based on their tag. - .br => return airBr(f, inst), - .repeat => return airRepeat(f, inst), - .cond_br => return airCondBr(f, inst), - .switch_br => return airSwitchBr(f, inst), - .loop => return airLoop(f, inst), - .ret => return airRet(f, inst, false), - .ret_safe => return airRet(f, inst, false), // TODO - .ret_load => return airRet(f, inst, true), - .trap => return airTrap(f, f.object.writer()), - .unreach => return airUnreach(f), + .br => return airBr(f, inst), + .repeat => return airRepeat(f, inst), + .switch_dispatch => return airSwitchDispatch(f, inst), + .cond_br => return airCondBr(f, inst), + .switch_br => return airSwitchBr(f, inst, false), + .loop_switch_br => return airSwitchBr(f, inst, true), + .loop => return airLoop(f, inst), + .ret => return airRet(f, inst, false), + .ret_safe => return airRet(f, inst, false), // TODO + .ret_load => return airRet(f, inst, true), + .trap => return airTrap(f, f.object.writer()), + .unreach => return airUnreach(f), // Instructions which may be `noreturn`. .block => res: { @@ -4786,6 +4792,46 @@ fn airRepeat(f: *Function, inst: Air.Inst.Index) !void { try writer.print("goto zig_loop_{d};\n", .{@intFromEnum(repeat.loop_inst)}); } +fn airSwitchDispatch(f: *Function, inst: Air.Inst.Index) !void { + const pt = f.object.dg.pt; + const zcu = pt.zcu; + const br = f.air.instructions.items(.data)[@intFromEnum(inst)].br; + const writer = f.object.writer(); + + if (try f.air.value(br.operand, pt)) |cond_val| { + // Comptime-known dispatch. Iterate the cases to find the correct + // one, and branch directly to the corresponding case. + const switch_br = f.air.unwrapSwitch(br.block_inst); + var it = switch_br.iterateCases(); + const target_case_idx: u32 = target: while (it.next()) |case| { + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + if (cond_val.compareHetero(.eq, val, zcu)) break :target case.idx; + } + for (case.ranges) |range| { + const low = Value.fromInterned(range[0].toInterned().?); + const high = Value.fromInterned(range[1].toInterned().?); + if (cond_val.compareHetero(.gte, low, zcu) and + cond_val.compareHetero(.lte, high, zcu)) + { + break :target case.idx; + } + } + } else switch_br.cases_len; + try writer.print("goto zig_switch_{d}_dispatch_{d};\n", .{ @intFromEnum(br.block_inst), target_case_idx }); + return; + } + + // Runtime-known dispatch. Set the switch condition, and branch back. 
+ const cond = try f.resolveInst(br.operand); + const cond_local = f.loop_switch_conds.get(br.block_inst).?; + try f.writeCValue(writer, .{ .local = cond_local }, .Other); + try writer.writeAll(" = "); + try f.writeCValue(writer, cond, .Initializer); + try writer.writeAll(";\n"); + try writer.print("goto zig_switch_{d}_loop;", .{@intFromEnum(br.block_inst)}); +} + fn airBitcast(f: *Function, inst: Air.Inst.Index) !CValue { const ty_op = f.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; const inst_ty = f.typeOfIndex(inst); @@ -5004,15 +5050,34 @@ fn airCondBr(f: *Function, inst: Air.Inst.Index) !void { try genBodyInner(f, else_body); } -fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void { +fn airSwitchBr(f: *Function, inst: Air.Inst.Index, is_dispatch_loop: bool) !void { const pt = f.object.dg.pt; const zcu = pt.zcu; + const gpa = f.object.dg.gpa; const switch_br = f.air.unwrapSwitch(inst); - const condition = try f.resolveInst(switch_br.operand); + const init_condition = try f.resolveInst(switch_br.operand); try reap(f, inst, &.{switch_br.operand}); const condition_ty = f.typeOf(switch_br.operand); const writer = f.object.writer(); + // For dispatches, we will create a local alloc to contain the condition value. + // This may not result in optimal codegen for switch loops, but it minimizes the + // amount of C code we generate, which is probably more desirable here (and is simpler). + const condition = if (is_dispatch_loop) cond: { + const new_local = try f.allocLocal(inst, condition_ty); + try f.writeCValue(writer, new_local, .Other); + try writer.writeAll(" = "); + try f.writeCValue(writer, init_condition, .Initializer); + try writer.writeAll(";\n"); + try writer.print("zig_switch_{d}_loop:", .{@intFromEnum(inst)}); + try f.loop_switch_conds.put(gpa, inst, new_local.new_local); + break :cond new_local; + } else init_condition; + + defer if (is_dispatch_loop) { + assert(f.loop_switch_conds.remove(inst)); + }; + try writer.writeAll("switch ("); const lowered_condition_ty = if (condition_ty.toIntern() == .bool_type) @@ -5030,7 +5095,6 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void { try writer.writeAll(") {"); f.object.indent_writer.pushIndent(); - const gpa = f.object.dg.gpa; const liveness = try f.liveness.getSwitchBr(gpa, inst, switch_br.cases_len + 1); defer gpa.free(liveness.deaths); @@ -5045,9 +5109,15 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void { try f.object.indent_writer.insertNewline(); try writer.writeAll("case "); const item_value = try f.air.value(item, pt); - if (item_value.?.getUnsignedInt(zcu)) |item_int| try writer.print("{}\n", .{ - try f.fmtIntLiteral(try pt.intValue(lowered_condition_ty, item_int)), - }) else { + // If `item_value` is a pointer with a known integer address, print the address + // with no cast to avoid a warning. 
+ write_val: { + if (condition_ty.isPtrAtRuntime(zcu)) { + if (item_value.?.getUnsignedInt(zcu)) |item_int| { + try writer.print("{}", .{try f.fmtIntLiteral(try pt.intValue(lowered_condition_ty, item_int))}); + break :write_val; + } + } if (condition_ty.isPtrAtRuntime(zcu)) { try writer.writeByte('('); try f.renderType(writer, Type.usize); @@ -5057,9 +5127,14 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void { } try writer.writeByte(':'); } - try writer.writeByte(' '); - - try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, false); + try writer.writeAll(" {\n"); + f.object.indent_writer.pushIndent(); + if (is_dispatch_loop) { + try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), case.idx }); + } + try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, true); + f.object.indent_writer.popIndent(); + try writer.writeByte('}'); // The case body must be noreturn so we don't need to insert a break. } @@ -5095,11 +5170,19 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !void { try f.object.dg.renderValue(writer, (try f.air.value(range[1], pt)).?, .Other); try writer.writeByte(')'); } - try writer.writeAll(") "); - try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, false); + try writer.writeAll(") {\n"); + f.object.indent_writer.pushIndent(); + if (is_dispatch_loop) { + try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), case.idx }); + } + try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, true); + f.object.indent_writer.popIndent(); + try writer.writeByte('}'); } } - + if (is_dispatch_loop) { + try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), switch_br.cases_len }); + } if (else_body.len > 0) { // Note that this must be the last case, so we do not need to use `genBodyResolveState` since // the parent block will do it (because the case body is noreturn). diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index d1ec8eca9f..e5e43bbbf8 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -1721,6 +1721,7 @@ pub const Object = struct { .func_inst_table = .{}, .blocks = .{}, .loops = .{}, + .switch_dispatch_info = .{}, .sync_scope = if (owner_mod.single_threaded) .singlethread else .system, .file = file, .scope = subprogram, @@ -4845,6 +4846,10 @@ pub const FuncGen = struct { /// Maps `loop` instructions to the bb to branch to to repeat the loop. loops: std.AutoHashMapUnmanaged(Air.Inst.Index, Builder.Function.Block.Index), + /// Maps `loop_switch_br` instructions to the information required to lower + /// dispatches (`switch_dispatch` instructions). + switch_dispatch_info: std.AutoHashMapUnmanaged(Air.Inst.Index, SwitchDispatchInfo), + sync_scope: Builder.SyncScope, const Fuzz = struct { @@ -4857,6 +4862,33 @@ pub const FuncGen = struct { } }; + const SwitchDispatchInfo = struct { + /// These are the blocks corresponding to each switch case. + /// The final element corresponds to the `else` case. + /// Slices allocated into `gpa`. + case_blocks: []Builder.Function.Block.Index, + /// This is `.none` if `jmp_table` is set, since we won't use a `switch` instruction to dispatch. + switch_weights: Builder.Function.Instruction.BrCond.Weights, + /// If not `null`, we have manually constructed a jump table to reach the desired block. + /// `table` can be used if the value is between `min` and `max` inclusive. + /// We perform this lowering manually to avoid some questionable behavior from LLVM. + /// See `airSwitchBr` for details. 
+ jmp_table: ?JmpTable, + + const JmpTable = struct { + min: Builder.Constant, + max: Builder.Constant, + in_bounds_hint: enum { none, unpredictable, likely, unlikely }, + /// Pointer to the jump table itself, to be used with `indirectbr`. + /// The index into the jump table is the dispatch condition minus `min`. + /// The table values are `blockaddress` constants corresponding to blocks in `case_blocks`. + table: Builder.Constant, + /// `true` if `table` conatins a reference to the `else` block. + /// In this case, the `indirectbr` must include the `else` block in its target list. + table_includes_else: bool, + }; + }; + const BreakList = union { list: std.MultiArrayList(struct { bb: Builder.Function.Block.Index, @@ -4872,6 +4904,11 @@ pub const FuncGen = struct { self.func_inst_table.deinit(gpa); self.blocks.deinit(gpa); self.loops.deinit(gpa); + var it = self.switch_dispatch_info.valueIterator(); + while (it.next()) |info| { + self.gpa.free(info.case_blocks); + } + self.switch_dispatch_info.deinit(gpa); } fn todo(self: *FuncGen, comptime format: []const u8, args: anytype) Error { @@ -5182,16 +5219,18 @@ pub const FuncGen = struct { .work_group_id => try self.airWorkGroupId(inst), // Instructions that are known to always be `noreturn` based on their tag. - .br => return self.airBr(inst), - .repeat => return self.airRepeat(inst), - .cond_br => return self.airCondBr(inst), - .switch_br => return self.airSwitchBr(inst), - .loop => return self.airLoop(inst), - .ret => return self.airRet(inst, false), - .ret_safe => return self.airRet(inst, true), - .ret_load => return self.airRetLoad(inst), - .trap => return self.airTrap(inst), - .unreach => return self.airUnreach(inst), + .br => return self.airBr(inst), + .repeat => return self.airRepeat(inst), + .switch_dispatch => return self.airSwitchDispatch(inst), + .cond_br => return self.airCondBr(inst), + .switch_br => return self.airSwitchBr(inst, false), + .loop_switch_br => return self.airSwitchBr(inst, true), + .loop => return self.airLoop(inst), + .ret => return self.airRet(inst, false), + .ret_safe => return self.airRet(inst, true), + .ret_load => return self.airRetLoad(inst), + .trap => return self.airTrap(inst), + .unreach => return self.airUnreach(inst), // Instructions which may be `noreturn`. .block => res: { @@ -6093,6 +6132,202 @@ pub const FuncGen = struct { _ = try self.wip.br(loop_bb); } + fn lowerSwitchDispatch( + self: *FuncGen, + switch_inst: Air.Inst.Index, + cond_ref: Air.Inst.Ref, + dispatch_info: SwitchDispatchInfo, + ) !void { + const o = self.ng.object; + const pt = o.pt; + const zcu = pt.zcu; + const cond_ty = self.typeOf(cond_ref); + const switch_br = self.air.unwrapSwitch(switch_inst); + + if (try self.air.value(cond_ref, pt)) |cond_val| { + // Comptime-known dispatch. Iterate the cases to find the correct + // one, and branch to the corresponding element of `case_blocks`. 
+ var it = switch_br.iterateCases(); + const target_case_idx = target: while (it.next()) |case| { + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + if (cond_val.compareHetero(.eq, val, zcu)) break :target case.idx; + } + for (case.ranges) |range| { + const low = Value.fromInterned(range[0].toInterned().?); + const high = Value.fromInterned(range[1].toInterned().?); + if (cond_val.compareHetero(.gte, low, zcu) and + cond_val.compareHetero(.lte, high, zcu)) + { + break :target case.idx; + } + } + } else dispatch_info.case_blocks.len - 1; + const target_block = dispatch_info.case_blocks[target_case_idx]; + target_block.ptr(&self.wip).incoming += 1; + _ = try self.wip.br(target_block); + return; + } + + // Runtime-known dispatch. + const cond = try self.resolveInst(cond_ref); + + if (dispatch_info.jmp_table) |jmp_table| { + // We should use the constructed jump table. + // First, check the bounds to branch to the `else` case if needed. + const inbounds = try self.wip.bin( + .@"and", + try self.cmp(.normal, .gte, cond_ty, cond, jmp_table.min.toValue()), + try self.cmp(.normal, .lte, cond_ty, cond, jmp_table.max.toValue()), + "", + ); + const jmp_table_block = try self.wip.block(1, "Then"); + const else_block = dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1]; + else_block.ptr(&self.wip).incoming += 1; + _ = try self.wip.brCond(inbounds, jmp_table_block, else_block, switch (jmp_table.in_bounds_hint) { + .none => .none, + .unpredictable => .unpredictable, + .likely => .then_likely, + .unlikely => .else_likely, + }); + + self.wip.cursor = .{ .block = jmp_table_block }; + + // Figure out the list of blocks we might branch to. + // This includes all case blocks, but it might not include the `else` block if + // the table is dense. + const target_blocks_len = dispatch_info.case_blocks.len - @intFromBool(!jmp_table.table_includes_else); + const target_blocks = dispatch_info.case_blocks[0..target_blocks_len]; + + // Make sure to cast the index to a usize so it's not treated as negative! + const table_index = try self.wip.cast( + .zext, + try self.wip.bin(.@"sub nuw", cond, jmp_table.min.toValue(), ""), + try o.lowerType(Type.usize), + "", + ); + const target_ptr_ptr = try self.wip.gep( + .inbounds, + .ptr, + jmp_table.table.toValue(), + &.{table_index}, + "", + ); + const target_ptr = try self.wip.load(.normal, .ptr, target_ptr_ptr, .default, ""); + + // Do the branch! + _ = try self.wip.indirectbr(target_ptr, target_blocks); + + // Mark all target blocks as having one more incoming branch. + for (target_blocks) |case_block| { + case_block.ptr(&self.wip).incoming += 1; + } + + return; + } + + // We must lower to an actual LLVM `switch` instruction. + // The switch prongs will correspond to our scalar cases. Ranges will + // be handled by conditional branches in the `else` prong. + + const llvm_usize = try o.lowerType(Type.usize); + const cond_int = if (cond.typeOfWip(&self.wip).isPointer(&o.builder)) + try self.wip.cast(.ptrtoint, cond, llvm_usize, "") + else + cond; + + const llvm_cases_len, const last_range_case = info: { + var llvm_cases_len: u32 = 0; + var last_range_case: ?u32 = null; + var it = switch_br.iterateCases(); + while (it.next()) |case| { + if (case.ranges.len > 0) last_range_case = case.idx; + llvm_cases_len += @intCast(case.items.len); + } + break :info .{ llvm_cases_len, last_range_case }; + }; + + // The `else` of the LLVM `switch` is the actual `else` prong only + // if there are no ranges. 
Otherwise, the `else` will have a + // conditional chain before the "true" `else` prong. + const llvm_else_block = if (last_range_case == null) + dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1] + else + try self.wip.block(0, "RangeTest"); + + llvm_else_block.ptr(&self.wip).incoming += 1; + + var wip_switch = try self.wip.@"switch"(cond_int, llvm_else_block, llvm_cases_len, dispatch_info.switch_weights); + defer wip_switch.finish(&self.wip); + + // Construct the actual cases. Set the cursor to the `else` block so + // we can construct ranges at the same time as scalar cases. + self.wip.cursor = .{ .block = llvm_else_block }; + + var it = switch_br.iterateCases(); + while (it.next()) |case| { + const case_block = dispatch_info.case_blocks[case.idx]; + + for (case.items) |item| { + const llvm_item = (try self.resolveInst(item)).toConst().?; + const llvm_int_item = if (llvm_item.typeOf(&o.builder).isPointer(&o.builder)) + try o.builder.castConst(.ptrtoint, llvm_item, llvm_usize) + else + llvm_item; + try wip_switch.addCase(llvm_int_item, case_block, &self.wip); + } + case_block.ptr(&self.wip).incoming += @intCast(case.items.len); + + if (case.ranges.len == 0) continue; + + // Add a conditional for the ranges, directing to the relevant bb. + // We don't need to consider `cold` branch hints since that information is stored + // in the target bb body, but we do care about likely/unlikely/unpredictable. + + const hint = switch_br.getHint(case.idx); + + var range_cond: ?Builder.Value = null; + for (case.ranges) |range| { + const llvm_min = try self.resolveInst(range[0]); + const llvm_max = try self.resolveInst(range[1]); + const cond_part = try self.wip.bin( + .@"and", + try self.cmp(.normal, .gte, cond_ty, cond, llvm_min), + try self.cmp(.normal, .lte, cond_ty, cond, llvm_max), + "", + ); + if (range_cond) |prev| { + range_cond = try self.wip.bin(.@"or", prev, cond_part, ""); + } else range_cond = cond_part; + } + + // If the check fails, we either branch to the "true" `else` case, + // or to the next range condition. + const range_else_block = if (case.idx == last_range_case.?) + dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1] + else + try self.wip.block(0, "RangeTest"); + + _ = try self.wip.brCond(range_cond.?, case_block, range_else_block, switch (hint) { + .none, .cold => .none, + .unpredictable => .unpredictable, + .likely => .then_likely, + .unlikely => .else_likely, + }); + case_block.ptr(&self.wip).incoming += 1; + range_else_block.ptr(&self.wip).incoming += 1; + + // Construct the next range conditional (if any) in the false branch. 
+ self.wip.cursor = .{ .block = range_else_block }; + } + } + + fn airSwitchDispatch(self: *FuncGen, inst: Air.Inst.Index) !void { + const br = self.air.instructions.items(.data)[@intFromEnum(inst)].br; + const dispatch_info = self.switch_dispatch_info.get(br.block_inst).?; + return self.lowerSwitchDispatch(br.block_inst, br.operand, dispatch_info); + } + fn airCondBr(self: *FuncGen, inst: Air.Inst.Index) !void { const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const cond = try self.resolveInst(pl_op.operand); @@ -6257,36 +6492,123 @@ pub const FuncGen = struct { return fg.wip.extractValue(err_union, &.{offset}, ""); } - fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index) !void { + fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index, is_dispatch_loop: bool) !void { const o = self.ng.object; + const zcu = o.pt.zcu; const switch_br = self.air.unwrapSwitch(inst); - const cond = try self.resolveInst(switch_br.operand); + // For `loop_switch_br`, we need these BBs prepared ahead of time to generate dispatches. + // For `switch_br`, they allow us to sometimes generate better IR by sharing a BB between + // scalar and range cases in the same prong. + // +1 for `else` case. This is not the same as the LLVM `else` prong, as that may first contain + // conditionals to handle ranges. + const case_blocks = try self.gpa.alloc(Builder.Function.Block.Index, switch_br.cases_len + 1); + defer self.gpa.free(case_blocks); + // We set incoming as 0 for now, and increment it as we construct dispatches. + for (case_blocks[0 .. case_blocks.len - 1]) |*b| b.* = try self.wip.block(0, "Case"); + case_blocks[case_blocks.len - 1] = try self.wip.block(0, "Default"); + + // There's a special case here to manually generate a jump table in some cases. + // + // Labeled switch in Zig is intended to follow the "direct threading" pattern. We would ideally use a jump + // table, and each `continue` has its own indirect `jmp`, to allow the branch predictor to more accurately + // use data patterns to predict future dispatches. The problem, however, is that LLVM emits fascinatingly + // bad asm for this. Not only does it not share the jump table -- which we really need it to do to prevent + // destroying the cache -- but it also actually generates slightly different jump tables for each case, + // and *a separate conditional branch beforehand* to handle dispatching back to the case we're currently + // within(!!). + // + // This asm is really, really, not what we want. As such, we will construct the jump table manually where + // appropriate (the values are dense and relatively few), and use it when lowering dispatches. + + const jmp_table: ?SwitchDispatchInfo.JmpTable = jmp_table: { + if (!is_dispatch_loop) break :jmp_table null; + // On a 64-bit target, 1024 pointers in our jump table is about 8K of pointers. This seems just + // about acceptable - it won't fill L1d cache on most CPUs. + const max_table_len = 1024; - // This is not necessarily the actual `else` prong; it first contains conditionals - // for any range cases. It's just the `else` of the LLVM switch. 
- const llvm_else_block = try self.wip.block(1, "Default"); + const cond_ty = self.typeOf(switch_br.operand); + switch (cond_ty.zigTypeTag(zcu)) { + .bool, .pointer => break :jmp_table null, + .@"enum", .int, .error_set => {}, + else => unreachable, + } - const case_blocks = try self.gpa.alloc(Builder.Function.Block.Index, switch_br.cases_len); - defer self.gpa.free(case_blocks); - // We set incoming as 0 for now, and increment it as we construct the switch. - for (case_blocks) |*b| b.* = try self.wip.block(0, "Case"); + if (cond_ty.intInfo(zcu).signedness == .signed) break :jmp_table null; - const llvm_usize = try o.lowerType(Type.usize); - const cond_int = if (cond.typeOfWip(&self.wip).isPointer(&o.builder)) - try self.wip.cast(.ptrtoint, cond, llvm_usize, "") - else - cond; + // Don't worry about the size of the type -- it's irrelevant, because the prong values could be fairly dense. + // If they are, then we will construct a jump table. + const min, const max = self.switchCaseItemRange(switch_br); + const min_int = min.getUnsignedInt(zcu) orelse break :jmp_table null; + const max_int = max.getUnsignedInt(zcu) orelse break :jmp_table null; + const table_len = max_int - min_int + 1; + if (table_len > max_table_len) break :jmp_table null; + + const table_elems = try self.gpa.alloc(Builder.Constant, @intCast(table_len)); + defer self.gpa.free(table_elems); - const llvm_cases_len = llvm_cases_len: { - var len: u32 = 0; + // Set them all to the `else` branch, then iterate over the AIR switch + // and replace all values which correspond to other prongs. + @memset(table_elems, try o.builder.blockAddrConst( + self.wip.function, + case_blocks[case_blocks.len - 1], + )); + var item_count: u32 = 0; var it = switch_br.iterateCases(); - while (it.next()) |case| len += @intCast(case.items.len); - break :llvm_cases_len len; + while (it.next()) |case| { + const case_block = case_blocks[case.idx]; + const case_block_addr = try o.builder.blockAddrConst( + self.wip.function, + case_block, + ); + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + const table_idx = val.toUnsignedInt(zcu) - min_int; + table_elems[@intCast(table_idx)] = case_block_addr; + item_count += 1; + } + for (case.ranges) |range| { + const low = Value.fromInterned(range[0].toInterned().?); + const high = Value.fromInterned(range[1].toInterned().?); + const low_idx = low.toUnsignedInt(zcu) - min_int; + const high_idx = high.toUnsignedInt(zcu) - min_int; + @memset(table_elems[@intCast(low_idx)..@intCast(high_idx + 1)], case_block_addr); + item_count += @intCast(high_idx + 1 - low_idx); + } + } + + const table_llvm_ty = try o.builder.arrayType(table_elems.len, .ptr); + const table_val = try o.builder.arrayConst(table_llvm_ty, table_elems); + + const table_variable = try o.builder.addVariable( + try o.builder.strtabStringFmt("__jmptab_{d}", .{@intFromEnum(inst)}), + table_llvm_ty, + .default, + ); + try table_variable.setInitializer(table_val, &o.builder); + table_variable.setLinkage(.internal, &o.builder); + table_variable.setUnnamedAddr(.unnamed_addr, &o.builder); + + const table_includes_else = item_count != table_len; + + break :jmp_table .{ + .min = try o.lowerValue(min.toIntern()), + .max = try o.lowerValue(max.toIntern()), + .in_bounds_hint = if (table_includes_else) .none else switch (switch_br.getElseHint()) { + .none, .cold => .none, + .unpredictable => .unpredictable, + .likely => .likely, + .unlikely => .unlikely, + }, + .table = table_variable.toConst(&o.builder), + .table_includes_else = 
table_includes_else, + }; }; const weights: Builder.Function.Instruction.BrCond.Weights = weights: { + if (jmp_table != null) break :weights .none; // not used + // First pass. If any weights are `.unpredictable`, unpredictable. // If all are `.none` or `.cold`, none. var any_likely = false; @@ -6304,6 +6626,13 @@ pub const FuncGen = struct { } if (!any_likely) break :weights .none; + const llvm_cases_len = llvm_cases_len: { + var len: u32 = 0; + var it = switch_br.iterateCases(); + while (it.next()) |case| len += @intCast(case.items.len); + break :llvm_cases_len len; + }; + var weights = try self.gpa.alloc(Builder.Metadata, llvm_cases_len + 1); defer self.gpa.free(weights); @@ -6336,75 +6665,66 @@ pub const FuncGen = struct { break :weights @enumFromInt(@intFromEnum(tuple)); }; - var wip_switch = try self.wip.@"switch"(cond_int, llvm_else_block, llvm_cases_len, weights); - defer wip_switch.finish(&self.wip); + const dispatch_info: SwitchDispatchInfo = .{ + .case_blocks = case_blocks, + .switch_weights = weights, + .jmp_table = jmp_table, + }; + + if (is_dispatch_loop) { + try self.switch_dispatch_info.putNoClobber(self.gpa, inst, dispatch_info); + } + defer if (is_dispatch_loop) { + assert(self.switch_dispatch_info.remove(inst)); + }; + + // Generate the initial dispatch. + // If this is a simple `switch_br`, this is the only dispatch. + try self.lowerSwitchDispatch(inst, switch_br.operand, dispatch_info); + // Iterate the cases and generate their bodies. var it = switch_br.iterateCases(); - var any_ranges = false; while (it.next()) |case| { - if (case.ranges.len > 0) any_ranges = true; const case_block = case_blocks[case.idx]; - case_block.ptr(&self.wip).incoming += @intCast(case.items.len); - // Handle scalar items, and generate the block. - // We'll generate conditionals for the ranges later on. - for (case.items) |item| { - const llvm_item = (try self.resolveInst(item)).toConst().?; - const llvm_int_item = if (llvm_item.typeOf(&o.builder).isPointer(&o.builder)) - try o.builder.castConst(.ptrtoint, llvm_item, llvm_usize) - else - llvm_item; - try wip_switch.addCase(llvm_int_item, case_block, &self.wip); - } self.wip.cursor = .{ .block = case_block }; if (switch_br.getHint(case.idx) == .cold) _ = try self.wip.callIntrinsicAssumeCold(); - try self.genBodyDebugScope(null, case.body, .poi); + try self.genBodyDebugScope(null, case.body, .none); } - + self.wip.cursor = .{ .block = case_blocks[case_blocks.len - 1] }; const else_body = it.elseBody(); - self.wip.cursor = .{ .block = llvm_else_block }; - if (any_ranges) { - const cond_ty = self.typeOf(switch_br.operand); - // Add conditionals for the ranges, directing to the relevant bb. - // We don't need to consider `cold` branch hints since that information is stored - // in the target bb body, but we do care about likely/unlikely/unpredictable. 
- it = switch_br.iterateCases(); - while (it.next()) |case| { - if (case.ranges.len == 0) continue; - const case_block = case_blocks[case.idx]; - const hint = switch_br.getHint(case.idx); - case_block.ptr(&self.wip).incoming += 1; - const next_else_block = try self.wip.block(1, "Default"); - var range_cond: ?Builder.Value = null; - for (case.ranges) |range| { - const llvm_min = try self.resolveInst(range[0]); - const llvm_max = try self.resolveInst(range[1]); - const cond_part = try self.wip.bin( - .@"and", - try self.cmp(.normal, .gte, cond_ty, cond, llvm_min), - try self.cmp(.normal, .lte, cond_ty, cond, llvm_max), - "", - ); - if (range_cond) |prev| { - range_cond = try self.wip.bin(.@"or", prev, cond_part, ""); - } else range_cond = cond_part; - } - _ = try self.wip.brCond(range_cond.?, case_block, next_else_block, switch (hint) { - .none, .cold => .none, - .unpredictable => .unpredictable, - .likely => .then_likely, - .unlikely => .else_likely, - }); - self.wip.cursor = .{ .block = next_else_block }; - } - } if (switch_br.getElseHint() == .cold) _ = try self.wip.callIntrinsicAssumeCold(); - if (else_body.len != 0) { - try self.genBodyDebugScope(null, else_body, .poi); + if (else_body.len > 0) { + try self.genBodyDebugScope(null, it.elseBody(), .none); } else { _ = try self.wip.@"unreachable"(); } + } - // No need to reset the insert cursor since this instruction is noreturn. + fn switchCaseItemRange(self: *FuncGen, switch_br: Air.UnwrappedSwitch) [2]Value { + const zcu = self.ng.object.pt.zcu; + var it = switch_br.iterateCases(); + var min: ?Value = null; + var max: ?Value = null; + while (it.next()) |case| { + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + const low = if (min) |m| val.compareHetero(.lt, m, zcu) else true; + const high = if (max) |m| val.compareHetero(.gt, m, zcu) else true; + if (low) min = val; + if (high) max = val; + } + for (case.ranges) |range| { + const vals: [2]Value = .{ + Value.fromInterned(range[0].toInterned().?), + Value.fromInterned(range[1].toInterned().?), + }; + const low = if (min) |m| vals[0].compareHetero(.lt, m, zcu) else true; + const high = if (max) |m| vals[1].compareHetero(.gt, m, zcu) else true; + if (low) min = vals[0]; + if (high) max = vals[1]; + } + } + return .{ min.?, max.? }; } fn airLoop(self: *FuncGen, inst: Air.Inst.Index) !void { |
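For reference, the manual jump-table path added to `airSwitchBr` in llvm.zig applies when the dispatch condition is an unsigned int/enum/error-set whose case values are dense enough (at most 1024 table slots). A hedged sketch of Zig user code that could hit that path, with illustrative names not taken from the commit:

```zig
// Illustrative user code, not from this commit: the scalar/range prong
// values span only '0' through '9', a dense unsigned range, so the lowering
// above can turn each `continue` into a range check that either branches
// straight to the `else` prong or does an `indirectbr` through the shared
// `__jmptab_*` table indexed by `byte - '0'`.
fn countDigits(bytes: []const u8) usize {
    if (bytes.len == 0) return 0;
    var i: usize = 0;
    var digits: usize = 0;
    scan: switch (bytes[i]) {
        '0'...'9' => {
            digits += 1;
            i += 1;
            if (i == bytes.len) break :scan;
            continue :scan bytes[i];
        },
        else => {
            i += 1;
            if (i == bytes.len) break :scan;
            continue :scan bytes[i];
        },
    }
    return digits;
}
```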
