diff options
| author | Pavel Verigo <paul.verigo@gmail.com> | 2024-09-18 16:55:50 +0200 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2025-02-22 18:34:00 -0500 |
| commit | b25d93e7d95418ea92b388ff8b58a04673c04539 (patch) | |
| tree | 81a486b469fdffd0078d6289b1fe5390e1ab614a /src/arch/wasm/CodeGen.zig | |
| parent | 61b69a418db2f525c4be4e19029487f61fb3234e (diff) | |
| download | zig-b25d93e7d95418ea92b388ff8b58a04673c04539.tar.gz zig-b25d93e7d95418ea92b388ff8b58a04673c04539.zip | |
stage2-wasm: implement switch_dispatch + handle > 32 bit integers in switches
The updated solution is future-proof for arbitrary-size integer handling with both strategies: `.br_table` lowering if the switch cases are dense, and a `.br_if`-based jump table if the values are too sparse.
Diffstat (limited to 'src/arch/wasm/CodeGen.zig')
| -rw-r--r-- | src/arch/wasm/CodeGen.zig | 388 |
1 files changed, 185 insertions, 203 deletions
diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index 92b860926d..641347bee1 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1925,7 +1925,7 @@ fn genInst(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { .breakpoint => cg.airBreakpoint(inst), .br => cg.airBr(inst), .repeat => cg.airRepeat(inst), - .switch_dispatch => return cg.fail("TODO implement `switch_dispatch`", .{}), + .switch_dispatch => cg.airSwitchDispatch(inst), .cond_br => cg.airCondBr(inst), .intcast => cg.airIntcast(inst), .fptrunc => cg.airFptrunc(inst), @@ -2005,8 +2005,8 @@ fn genInst(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { .struct_field_val => cg.airStructFieldVal(inst), .field_parent_ptr => cg.airFieldParentPtr(inst), - .switch_br => cg.airSwitchBr(inst), - .loop_switch_br => return cg.fail("TODO implement `loop_switch_br`", .{}), + .switch_br => cg.airSwitchBr(inst, false), + .loop_switch_br => cg.airSwitchBr(inst, true), .trunc => cg.airTrunc(inst), .unreach => cg.airUnreachable(inst), @@ -3356,43 +3356,6 @@ fn emitUndefined(cg: *CodeGen, ty: Type) InnerError!WValue { } } -/// Returns a `Value` as a signed 32 bit value. -/// It's illegal to provide a value with a type that cannot be represented -/// as an integer value. 
-fn valueAsI32(cg: *const CodeGen, val: Value) i32 { - const zcu = cg.pt.zcu; - const ip = &zcu.intern_pool; - - switch (val.toIntern()) { - .bool_true => return 1, - .bool_false => return 0, - else => return switch (ip.indexToKey(val.ip_index)) { - .enum_tag => |enum_tag| intIndexAsI32(ip, enum_tag.int, zcu), - .int => |int| intStorageAsI32(int.storage, zcu), - .ptr => |ptr| { - assert(ptr.base_addr == .int); - return @intCast(ptr.byte_offset); - }, - .err => |err| @bitCast(ip.getErrorValueIfExists(err.name).?), - else => unreachable, - }, - } -} - -fn intIndexAsI32(ip: *const InternPool, int: InternPool.Index, zcu: *const Zcu) i32 { - return intStorageAsI32(ip.indexToKey(int).int.storage, zcu); -} - -fn intStorageAsI32(storage: InternPool.Key.Int.Storage, zcu: *const Zcu) i32 { - return switch (storage) { - .i64 => |x| @as(i32, @intCast(x)), - .u64 => |x| @as(i32, @bitCast(@as(u32, @intCast(x)))), - .big_int => unreachable, - .lazy_align => |ty| @as(i32, @bitCast(@as(u32, @intCast(Type.fromInterned(ty).abiAlignment(zcu).toByteUnits() orelse 0)))), - .lazy_size => |ty| @as(i32, @bitCast(@as(u32, @intCast(Type.fromInterned(ty).abiSize(zcu))))), - }; -} - fn airBlock(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { const ty_pl = cg.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = cg.air.extraData(Air.Block, ty_pl.payload); @@ -3401,13 +3364,11 @@ fn airBlock(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { fn lowerBlock(cg: *CodeGen, inst: Air.Inst.Index, block_ty: Type, body: []const Air.Inst.Index) InnerError!void { const zcu = cg.pt.zcu; - const wasm_block_ty = genBlockType(block_ty, zcu, cg.target); - // if wasm_block_ty is non-empty, we create a register to store the temporary value - const block_result: WValue = if (wasm_block_ty != .empty) blk: { - const ty: Type = if (isByRef(block_ty, zcu, cg.target)) Type.u32 else block_ty; - break :blk try cg.ensureAllocLocal(ty); // make sure it's a clean local as it may never get 
overwritten - } else .none; + const block_result: WValue = if (block_ty.hasRuntimeBitsIgnoreComptime(zcu)) + try cg.allocLocal(block_ty) + else + .none; try cg.startBlock(.block, .empty); // Here we set the current block idx, so breaks know the depth to jump @@ -3621,18 +3582,14 @@ fn airCmpLtErrorsLen(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { } fn airBr(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { - const zcu = cg.pt.zcu; const br = cg.air.instructions.items(.data)[@intFromEnum(inst)].br; const block = cg.blocks.get(br.block_inst).?; // if operand has codegen bits we should break with a value - if (cg.typeOf(br.operand).hasRuntimeBitsIgnoreComptime(zcu)) { + if (block.value != .none) { const operand = try cg.resolveInst(br.operand); try cg.lowerToStack(operand); - - if (block.value != .none) { - try cg.addLocal(.local_set, block.value.local.value); - } + try cg.addLocal(.local_set, block.value.local.value); } // We map every block to its block index. @@ -3957,189 +3914,192 @@ fn airStructFieldVal(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { return cg.finishAir(inst, result, &.{struct_field.struct_operand}); } -fn airSwitchBr(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { +fn airSwitchBr(cg: *CodeGen, inst: Air.Inst.Index, is_dispatch_loop: bool) InnerError!void { const pt = cg.pt; const zcu = pt.zcu; - // result type is always 'noreturn' - const blocktype: std.wasm.BlockType = .empty; + const switch_br = cg.air.unwrapSwitch(inst); - const target = try cg.resolveInst(switch_br.operand); const target_ty = cg.typeOf(switch_br.operand); + + assert(target_ty.hasRuntimeBitsIgnoreComptime(zcu)); + + // swap target value with placeholder local, for dispatching + const target = if (is_dispatch_loop) target: { + const initial_target = try cg.resolveInst(switch_br.operand); + const target: WValue = try cg.allocLocal(target_ty); + try cg.lowerToStack(initial_target); + try cg.addLocal(.local_set, target.local.value); + + try 
cg.startBlock(.loop, .empty); // dispatch loop start + try cg.blocks.putNoClobber(cg.gpa, inst, .{ + .label = cg.block_depth, + .value = target, + }); + + break :target target; + } else try cg.resolveInst(switch_br.operand); + const liveness = try cg.liveness.getSwitchBr(cg.gpa, inst, switch_br.cases_len + 1); defer cg.gpa.free(liveness.deaths); - // a list that maps each value with its value and body based on the order inside the list. - const CaseValue = union(enum) { - singular: struct { integer: i32, value: Value }, - range: struct { min: i32, min_value: Value, max: i32, max_value: Value }, - }; - var case_list = try std.ArrayList(struct { - values: []const CaseValue, - body: []const Air.Inst.Index, - }).initCapacity(cg.gpa, switch_br.cases_len); - defer for (case_list.items) |case| { - cg.gpa.free(case.values); - } else case_list.deinit(); - - var lowest_maybe: ?i32 = null; - var highest_maybe: ?i32 = null; - var it = switch_br.iterateCases(); - while (it.next()) |case| { - const values = try cg.gpa.alloc(CaseValue, case.items.len + case.ranges.len); - errdefer cg.gpa.free(values); - - for (case.items, 0..) |ref, i| { - const item_val = (try cg.air.value(ref, pt)).?; - const int_val = cg.valueAsI32(item_val); - if (lowest_maybe == null or int_val < lowest_maybe.?) { - lowest_maybe = int_val; - } - if (highest_maybe == null or int_val > highest_maybe.?) 
{ - highest_maybe = int_val; - } - values[i] = .{ .singular = .{ .integer = int_val, .value = item_val } }; + const has_else_body = switch_br.else_body_len != 0; + const branch_count = switch_br.cases_len + 1; // if else branch is missing, we trap when failing all conditions + try cg.branches.ensureUnusedCapacity(cg.gpa, switch_br.cases_len + @intFromBool(has_else_body)); + + if (switch_br.cases_len == 0) { + assert(has_else_body); + + var it = switch_br.iterateCases(); + const else_body = it.elseBody(); + + cg.branches.appendAssumeCapacity(.{}); + const else_deaths = liveness.deaths.len - 1; + try cg.currentBranch().values.ensureUnusedCapacity(cg.gpa, liveness.deaths[else_deaths].len); + defer { + var else_branch = cg.branches.pop().?; + else_branch.deinit(cg.gpa); + } + try cg.genBody(else_body); + + if (is_dispatch_loop) { + try cg.endBlock(); // dispatch loop end } + return cg.finishAir(inst, .none, &.{}); + } - for (case.ranges, 0..) |range, i| { - const min_val = (try cg.air.value(range[0], pt)).?; - const int_min_val = cg.valueAsI32(min_val); + var min: ?Value = null; + var max: ?Value = null; + var branching_size: u32 = 0; // single item +1, range +2 - if (lowest_maybe == null or int_min_val < lowest_maybe.?) 
{ - lowest_maybe = int_min_val; + { + var cases_it = switch_br.iterateCases(); + while (cases_it.next()) |case| { + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + if (min == null or val.compareHetero(.lt, min.?, zcu)) min = val; + if (max == null or val.compareHetero(.gt, max.?, zcu)) max = val; + branching_size += 1; + } + for (case.ranges) |range| { + const low = Value.fromInterned(range[0].toInterned().?); + if (min == null or low.compareHetero(.lt, min.?, zcu)) min = low; + const high = Value.fromInterned(range[1].toInterned().?); + if (max == null or high.compareHetero(.gt, max.?, zcu)) max = high; + branching_size += 2; } + } + } - const max_val = (try cg.air.value(range[1], pt)).?; - const int_max_val = cg.valueAsI32(max_val); + var min_space: Value.BigIntSpace = undefined; + const min_bigint = min.?.toBigInt(&min_space, zcu); + var max_space: Value.BigIntSpace = undefined; + const max_bigint = max.?.toBigInt(&max_space, zcu); + const limbs = try cg.gpa.alloc( + std.math.big.Limb, + @max(min_bigint.limbs.len, max_bigint.limbs.len) + 1, + ); + defer cg.gpa.free(limbs); - if (highest_maybe == null or int_max_val > highest_maybe.?) 
{ - highest_maybe = int_max_val; - } + const width_maybe: ?u32 = width: { + var width_bigint: std.math.big.int.Mutable = .{ .limbs = limbs, .positive = undefined, .len = undefined }; + width_bigint.sub(max_bigint, min_bigint); + width_bigint.addScalar(width_bigint.toConst(), 1); + break :width width_bigint.toConst().to(u32) catch null; + }; - values[i + case.items.len] = .{ .range = .{ - .min = int_min_val, - .min_value = min_val, - .max = int_max_val, - .max_value = max_val, - } }; - } + try cg.startBlock(.block, .empty); // whole switch block start - case_list.appendAssumeCapacity(.{ .values = values, .body = case.body }); - try cg.startBlock(.block, blocktype); + for (0..branch_count) |_| { + try cg.startBlock(.block, .empty); } - // When highest and lowest are null, we have no cases and can use a jump table - const lowest = lowest_maybe orelse 0; - const highest = highest_maybe orelse 0; - // When the highest and lowest values are seperated by '50', - // we define it as sparse and use an if/else-chain, rather than a jump table. - // When the target is an integer size larger than u32, we have no way to use the value - // as an index, therefore we also use an if/else-chain for those cases. - // TODO: Benchmark this to find a proper value, LLVM seems to draw the line at '40~45'. - const is_sparse = highest - lowest > 50 or target_ty.bitSize(zcu) > 32; + // Heuristic on deciding when to use .br_table instead of .br_if jump table + // 1. Differences between lowest and highest values should fit into u32 + // 2. .br_table should be applied for "dense" switch, we test it by checking .br_if jumps will need more instructions + // 3. 
Do not use .br_table for tiny switches + const use_br_table = cond: { + const width = width_maybe orelse break :cond false; + if (width > 2 * branching_size) break :cond false; + if (width < 2 or branch_count < 2) break :cond false; + break :cond true; + }; - const else_body = it.elseBody(); - const has_else_body = else_body.len != 0; - if (has_else_body) { - try cg.startBlock(.block, blocktype); - } - - if (!is_sparse) { - // Generate the jump table 'br_table' when the prongs are not sparse. - // The value 'target' represents the index into the table. - // Each index in the table represents a label to the branch - // to jump to. - try cg.startBlock(.block, blocktype); - try cg.emitWValue(target); - if (lowest < 0) { - // since br_table works using indexes, starting from '0', we must ensure all values - // we put inside, are atleast 0. - try cg.addImm32(@bitCast(lowest * -1)); - try cg.addTag(.i32_add); - } else if (lowest > 0) { - // make the index start from 0 by substracting the lowest value - try cg.addImm32(@bitCast(lowest)); - try cg.addTag(.i32_sub); - } + if (use_br_table) { + const width = width_maybe.?; + + const br_value_original = try cg.binOp(target, try cg.resolveInst(Air.internedToRef(min.?.toIntern())), target_ty, .sub); + _ = try cg.intcast(br_value_original, target_ty, Type.u32); - // Account for default branch so always add '1' - const depth = @as(u32, @intCast(highest - lowest + @intFromBool(has_else_body))) + 1; - const jump_table: Mir.JumpTable = .{ .length = depth }; + const jump_table: Mir.JumpTable = .{ .length = width + 1 }; const table_extra_index = try cg.addExtra(jump_table); try cg.addInst(.{ .tag = .br_table, .data = .{ .payload = table_extra_index } }); - try cg.mir_extra.ensureUnusedCapacity(cg.gpa, depth); - var value = lowest; - while (value <= highest) : (value += 1) { - // idx represents the branch we jump to - const idx = blk: { - for (case_list.items, 0..) 
|case, idx| { - for (case.values) |case_value| { - switch (case_value) { - .singular => |val| if (val.integer == value) break :blk @as(u32, @intCast(idx)), - .range => |range_val| if (value >= range_val.min and value <= range_val.max) { - break :blk @as(u32, @intCast(idx)); - }, - } - } - } - // error sets are almost always sparse so we use the default case - // for errors that are not present in any branch. This is fine as this default - // case will never be hit for those cases but we do save runtime cost and size - // by using a jump table for this instead of if-else chains. - break :blk if (has_else_body or target_ty.zigTypeTag(zcu) == .error_set) switch_br.cases_len else unreachable; - }; - cg.mir_extra.appendAssumeCapacity(idx); - } else if (has_else_body) { - cg.mir_extra.appendAssumeCapacity(switch_br.cases_len); // default branch - } - try cg.endBlock(); - } - try cg.branches.ensureUnusedCapacity(cg.gpa, case_list.items.len + @intFromBool(has_else_body)); - for (case_list.items, 0..) |case, index| { - // when sparse, we use if/else-chain, so emit conditional checks - if (is_sparse) { - // for single value prong we can emit a simple condition - if (case.values.len == 1 and case.values[0] == .singular) { - const val = try cg.lowerConstant(case.values[0].singular.value, target_ty); - // not equal, because we want to jump out of this block if it does not match the condition. - _ = try cg.cmp(target, val, target_ty, .neq); - try cg.addLabel(.br_if, 0); - } else { - // in multi-value prongs we must check if any prongs match the target value. 
- try cg.startBlock(.block, blocktype); - for (case.values) |value| { - switch (value) { - .singular => |single_val| { - const val = try cg.lowerConstant(single_val.value, target_ty); - _ = try cg.cmp(target, val, target_ty, .eq); - }, - .range => |range| { - const min_val = try cg.lowerConstant(range.min_value, target_ty); - const max_val = try cg.lowerConstant(range.max_value, target_ty); - - const gte = try cg.cmp(target, min_val, target_ty, .gte); - const lte = try cg.cmp(target, max_val, target_ty, .lte); - _ = try cg.binOp(gte, lte, Type.bool, .@"and"); - }, - } - try cg.addLabel(.br_if, 0); - } - // value did not match any of the prong values - try cg.addLabel(.br, 1); - try cg.endBlock(); + const branch_list = try cg.mir_extra.addManyAsSlice(cg.gpa, width + 1); + @memset(branch_list, branch_count - 1); + + var cases_it = switch_br.iterateCases(); + while (cases_it.next()) |case| { + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + var val_space: Value.BigIntSpace = undefined; + const val_bigint = val.toBigInt(&val_space, zcu); + var index_bigint: std.math.big.int.Mutable = .{ .limbs = limbs, .positive = undefined, .len = undefined }; + index_bigint.sub(val_bigint, min_bigint); + branch_list[index_bigint.toConst().to(u32) catch unreachable] = case.idx; + } + for (case.ranges) |range| { + var low_space: Value.BigIntSpace = undefined; + const low_bigint = Value.fromInterned(range[0].toInterned().?).toBigInt(&low_space, zcu); + var high_space: Value.BigIntSpace = undefined; + const high_bigint = Value.fromInterned(range[1].toInterned().?).toBigInt(&high_space, zcu); + var index_bigint: std.math.big.int.Mutable = .{ .limbs = limbs, .positive = undefined, .len = undefined }; + index_bigint.sub(low_bigint, min_bigint); + const start = index_bigint.toConst().to(u32) catch unreachable; + index_bigint.sub(high_bigint, min_bigint); + const end = (index_bigint.toConst().to(u32) catch unreachable) + 1; + @memset(branch_list[start..end], 
case.idx); + } + } + } else { + var cases_it = switch_br.iterateCases(); + while (cases_it.next()) |case| { + for (case.items) |ref| { + const val = try cg.resolveInst(ref); + _ = try cg.cmp(target, val, target_ty, .eq); + try cg.addLabel(.br_if, case.idx); // item match found + } + for (case.ranges) |range| { + const low = try cg.resolveInst(range[0]); + const high = try cg.resolveInst(range[1]); + + const gte = try cg.cmp(target, low, target_ty, .gte); + const lte = try cg.cmp(target, high, target_ty, .lte); + _ = try cg.binOp(gte, lte, Type.bool, .@"and"); + try cg.addLabel(.br_if, case.idx); // range match found } } + try cg.addLabel(.br, branch_count - 1); + } + + var cases_it = switch_br.iterateCases(); + while (cases_it.next()) |case| { + try cg.endBlock(); + cg.branches.appendAssumeCapacity(.{}); - try cg.currentBranch().values.ensureUnusedCapacity(cg.gpa, liveness.deaths[index].len); + try cg.currentBranch().values.ensureUnusedCapacity(cg.gpa, liveness.deaths[case.idx].len); defer { var case_branch = cg.branches.pop().?; case_branch.deinit(cg.gpa); } try cg.genBody(case.body); - try cg.endBlock(); + + try cg.addLabel(.br, branch_count - case.idx - 1); // matching case found and executed => exit switch } + try cg.endBlock(); if (has_else_body) { + const else_body = cases_it.elseBody(); + cg.branches.appendAssumeCapacity(.{}); const else_deaths = liveness.deaths.len - 1; try cg.currentBranch().values.ensureUnusedCapacity(cg.gpa, liveness.deaths[else_deaths].len); @@ -4148,11 +4108,33 @@ fn airSwitchBr(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { else_branch.deinit(cg.gpa); } try cg.genBody(else_body); - try cg.endBlock(); + } else { + try cg.addTag(.@"unreachable"); } + + try cg.endBlock(); // whole switch block end + + if (is_dispatch_loop) { + try cg.endBlock(); // dispatch loop end + } + return cg.finishAir(inst, .none, &.{}); } +fn airSwitchDispatch(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void { + const br = 
cg.air.instructions.items(.data)[@intFromEnum(inst)].br; + const switch_loop = cg.blocks.get(br.block_inst).?; + + const operand = try cg.resolveInst(br.operand); + try cg.lowerToStack(operand); + try cg.addLocal(.local_set, switch_loop.value.local.value); + + const idx: u32 = cg.block_depth - switch_loop.label; + try cg.addLabel(.br, idx); + + return cg.finishAir(inst, .none, &.{br.operand}); +} + fn airIsErr(cg: *CodeGen, inst: Air.Inst.Index, opcode: std.wasm.Opcode) InnerError!void { const zcu = cg.pt.zcu; const un_op = cg.air.instructions.items(.data)[@intFromEnum(inst)].un_op; |
