diff options
| author | dweiller <4678790+dweiller@users.noreply.github.com> | 2023-12-15 13:35:49 +1100 |
|---|---|---|
| committer | dweiller <4678790+dweiller@users.noreply.github.com> | 2024-01-09 14:42:12 +1100 |
| commit | 6a18cee3af8021bcebbca40413056b18f33af8c7 (patch) | |
| tree | e7675a562337ed5f1fed4d86c62872c7b5893413 /src | |
| parent | b7eb59fc140f3263b608a80fbe4e1ab56e04b318 (diff) | |
| download | zig-6a18cee3af8021bcebbca40413056b18f33af8c7.tar.gz zig-6a18cee3af8021bcebbca40413056b18f33af8c7.zip | |
astgen/sema: use switch_block_err_union for if-else-switch
Diffstat (limited to 'src')
| -rw-r--r-- | src/AstGen.zig | 165 | ||||
| -rw-r--r-- | src/Sema.zig | 48 | ||||
| -rw-r--r-- | src/Zir.zig | 3 |
3 files changed, 186 insertions, 30 deletions
diff --git a/src/AstGen.zig b/src/AstGen.zig index f6397fe9b1..698d3ce950 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -839,7 +839,18 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE .if_simple, .@"if", - => return ifExpr(gz, scope, ri.br(), node, tree.fullIf(node).?), + => { + const if_full = tree.fullIf(node).?; + if (if_full.error_token) |error_token| { + const tag = node_tags[if_full.ast.else_expr]; + if ((tag == .@"switch" or tag == .switch_comma) and + std.mem.eql(u8, tree.tokenSlice(error_token), tree.tokenSlice(error_token + 4))) + { + return switchExprErrUnion(gz, scope, ri.br(), node, .@"if"); + } + } + return ifExpr(gz, scope, ri.br(), node, if_full); + }, .while_simple, .while_cont, @@ -1020,7 +1031,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE token_tags[catch_token + 4] == .keyword_switch) { if (std.mem.eql(u8, tree.tokenSlice(catch_token + 2), tree.tokenSlice(catch_token + 6))) { - return switchExprErrUnion(gz, scope, ri.br(), node); + return switchExprErrUnion(gz, scope, ri.br(), node, .@"catch"); } } break :blk catch_token + 2; @@ -6869,7 +6880,8 @@ fn switchExprErrUnion( parent_gz: *GenZir, scope: *Scope, ri: ResultInfo, - catch_node: Ast.Node.Index, + catch_or_if_node: Ast.Node.Index, + node_ty: enum { @"catch", @"if" }, ) InnerError!Zir.Inst.Ref { const astgen = parent_gz.astgen; const gpa = astgen.gpa; @@ -6878,21 +6890,42 @@ fn switchExprErrUnion( const node_tags = tree.nodes.items(.tag); const main_tokens = tree.nodes.items(.main_token); const token_tags = tree.tokens.items(.tag); - const operand_node = node_datas[catch_node].lhs; - const switch_node = node_datas[catch_node].rhs; + + const if_full = switch (node_ty) { + .@"catch" => undefined, + .@"if" => tree.fullIf(catch_or_if_node).?, + }; + + const switch_node, const operand_node, const error_payload = switch (node_ty) { + .@"catch" => .{ + node_datas[catch_or_if_node].rhs, + node_datas[catch_or_if_node].lhs, + main_tokens[catch_or_if_node] + 2, + }, + .@"if" => .{ + if_full.ast.else_expr, + if_full.ast.cond_expr, + if_full.error_token.?, + }, + }; + assert(node_tags[switch_node] == .@"switch" or node_tags[switch_node] == .switch_comma); + const extra = tree.extraData(node_datas[switch_node].rhs, Ast.Node.SubRange); const case_nodes = tree.extra_data[extra.start..extra.end]; - const need_rl = astgen.nodes_need_rl.contains(catch_node); + const need_rl = astgen.nodes_need_rl.contains(catch_or_if_node); const block_ri: ResultInfo = if (need_rl) ri else .{ .rl = switch (ri.rl) { - .ptr => .{ .ty = (try ri.rl.resultType(parent_gz, catch_node)).? }, + .ptr => .{ .ty = (try ri.rl.resultType(parent_gz, catch_or_if_node)).? }, .inferred_ptr => .none, else => ri.rl, }, .ctx = ri.ctx, }; + const payload_is_ref = node_ty == .@"if" and + if_full.payload_token != null and token_tags[if_full.payload_token.?] == .asterisk; + // We need to call `rvalue` to write through to the pointer only if we had a // result pointer and aren't forwarding it. const LocTag = @typeInfo(ResultInfo.Loc).Union.tag_type.?; @@ -6960,12 +6993,15 @@ fn switchExprErrUnion( } } - const operand_ri: ResultInfo = .{ .rl = .none, .ctx = .error_handling_expr }; + const operand_ri: ResultInfo = .{ + .rl = if (payload_is_ref) .ref else .none, + .ctx = .error_handling_expr, + }; astgen.advanceSourceCursorToNode(operand_node); const operand_lc = LineColumn{ astgen.source_line - parent_gz.decl_line, astgen.source_column }; - const raw_operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, node_datas[catch_node].rhs); + const raw_operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, switch_node); const item_ri: ResultInfo = .{ .rl = .none }; // This contains the data that goes into the `extra` array for the SwitchBlockErrUnion, except @@ -7006,13 +7042,93 @@ fn switchExprErrUnion( try case_scope.addDbgBlockBegin(); - const unwrapped_payload = try case_scope.addUnNode(.err_union_payload_unsafe, raw_operand, catch_node); - const case_result = switch (ri.rl) { - .ref, .ref_coerced_ty => unwrapped_payload, - else => try rvalue(&case_scope, block_scope.break_result_info, unwrapped_payload, catch_node), - }; - try case_scope.addDbgBlockEnd(); - _ = try case_scope.addBreakWithSrcNode(.@"break", switch_block, case_result, catch_node); + const unwrap_payload_tag: Zir.Inst.Tag = if (payload_is_ref) + .err_union_payload_unsafe_ptr + else + .err_union_payload_unsafe; + + const unwrapped_payload = try case_scope.addUnNode( + unwrap_payload_tag, + raw_operand, + catch_or_if_node, + ); + + switch (node_ty) { + .@"catch" => { + const case_result = switch (ri.rl) { + .ref, .ref_coerced_ty => unwrapped_payload, + else => try rvalue( + &case_scope, + block_scope.break_result_info, + unwrapped_payload, + catch_or_if_node, + ), + }; + try case_scope.addDbgBlockEnd(); + _ = try case_scope.addBreakWithSrcNode( + .@"break", + switch_block, + case_result, + catch_or_if_node, + ); + }, + .@"if" => { + var payload_val_scope: Scope.LocalVal = undefined; + + try case_scope.addDbgBlockBegin(); + const then_node = if_full.ast.then_expr; + const then_sub_scope = s: { + assert(if_full.error_token != null); + if (if_full.payload_token) |payload_token| { + const token_name_index = payload_token + @intFromBool(payload_is_ref); + const ident_name = try astgen.identAsString(token_name_index); + const token_name_str = tree.tokenSlice(token_name_index); + if (mem.eql(u8, "_", token_name_str)) + break :s &case_scope.base; + try astgen.detectLocalShadowing( + &case_scope.base, + ident_name, + token_name_index, + token_name_str, + .capture, + ); + payload_val_scope = .{ + .parent = &case_scope.base, + .gen_zir = &case_scope, + .name = ident_name, + .inst = unwrapped_payload, + .token_src = payload_token, + .id_cat = .capture, + }; + try case_scope.addDbgVar(.dbg_var_val, ident_name, unwrapped_payload); + break :s &payload_val_scope.base; + } else { + _ = try case_scope.addUnNode( + .ensure_err_union_payload_void, + raw_operand, + catch_or_if_node, + ); + break :s &case_scope.base; + } + }; + const then_result = try expr( + &case_scope, + then_sub_scope, + block_scope.break_result_info, + then_node, + ); + try checkUsed(parent_gz, &case_scope.base, then_sub_scope); + if (!case_scope.endsWithNoReturn()) { + try case_scope.addDbgBlockEnd(); + _ = try case_scope.addBreakWithSrcNode( + .@"break", + switch_block, + then_result, + then_node, + ); + } + }, + } const case_slice = case_scope.instructionsSlice(); // Since we use the switch_block_err_union instruction itself to refer @@ -7029,9 +7145,18 @@ fn switchExprErrUnion( }; const body_len = refs_len + astgen.countBodyLenAfterFixups(case_slice); try payloads.ensureUnusedCapacity(gpa, body_len); + const capture: Zir.Inst.SwitchBlock.ProngInfo.Capture = switch (node_ty) { + .@"catch" => .none, + .@"if" => if (if_full.payload_token == null) + .none + else if (payload_is_ref) + .by_ref + else + .by_val, + }; payloads.items[body_len_index] = @bitCast(Zir.Inst.SwitchBlock.ProngInfo{ .body_len = @intCast(body_len), - .capture = .none, + .capture = capture, .is_inline = false, .has_tag_capture = false, }); @@ -7041,8 +7166,7 @@ fn switchExprErrUnion( appendBodyWithFixupsArrayList(astgen, payloads, case_slice); } - const err_name, const error_payload = blk: { - const error_payload = main_tokens[catch_node] + 2; + const err_name = blk: { const err_str = tree.tokenSlice(error_payload); if (mem.eql(u8, err_str, "_")) { return astgen.failTok(error_payload, "discard of error capture; omit it instead", .{}); @@ -7050,7 +7174,7 @@ fn switchExprErrUnion( const err_name = try astgen.identAsString(error_payload); try astgen.detectLocalShadowing(scope, err_name, error_payload, err_str, .capture); - break :blk .{ err_name, error_payload }; + break :blk err_name; }; // allocate a shared dummy instruction for the error capture @@ -7241,6 +7365,7 @@ fn switchExprErrUnion( .has_else = has_else, .scalar_cases_len = @intCast(scalar_cases_len), .any_uses_err_capture = any_uses_err_capture, + .payload_is_ref = payload_is_ref, }, }); diff --git a/src/Sema.zig b/src/Sema.zig index 84ea83f56b..fc8e9d801a 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -8939,10 +8939,14 @@ fn zirErrUnionCodePtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE const tracy = trace(@src()); defer tracy.end(); - const mod = sema.mod; const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; const src = inst_data.src(); const operand = try sema.resolveInst(inst_data.operand); + return sema.analyzeErrUnionCodePtr(block, src, operand); +} + +fn analyzeErrUnionCodePtr(sema: *Sema, block: *Block, src: LazySrcLoc, operand: Air.Inst.Ref) CompileError!Air.Inst.Ref { + const mod = sema.mod; const operand_ty = sema.typeOf(operand); assert(operand_ty.zigTypeTag(mod) == .Pointer); @@ -8957,7 +8961,10 @@ fn zirErrUnionCodePtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE if (try sema.resolveDefinedValue(block, src, operand)) |pointer_val| { if (try sema.pointerDeref(block, src, pointer_val, operand_ty)) |val| { assert(val.getErrorName(mod) != .none); - return Air.internedToRef(val.toIntern()); + return Air.internedToRef((try mod.intern(.{ .err = .{ + .ty = result_ty.toIntern(), + .name = mod.intern_pool.indexToKey(val.toIntern()).error_union.val.err_name, + } }))); } } @@ -11174,7 +11181,6 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp const extra = sema.code.extraData(Zir.Inst.SwitchBlockErrUnion, inst_data.payload_index); const raw_operand_val = try sema.resolveInst(extra.data.operand); - assert(sema.typeOf(raw_operand_val).zigTypeTag(mod) == .ErrorUnion); // AstGen guarantees that the instruction immediately preceding // switch_block_err_union is a dbg_stmt @@ -11205,6 +11211,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp const NonError = struct { body: []const Zir.Inst.Index, end: usize, + capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, }; const non_error_case: NonError = non_error: { @@ -11213,6 +11220,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp break :non_error .{ .body = sema.code.bodySlice(extra_body_start, info.body_len), .end = extra_body_start + info.body_len, + .capture = info.capture, }; }; @@ -11237,7 +11245,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp .body = sema.code.bodySlice(extra_body_start, info.body_len), .end = extra_body_start + info.body_len, .is_inline = info.is_inline, - .has_capture = info.capture == .by_val, + .has_capture = info.capture != .none, }; }; @@ -11245,7 +11253,10 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp defer seen_errors.deinit(); const operand_ty = sema.typeOf(raw_operand_val); - const operand_err_set_ty = operand_ty.errorUnionSet(mod); + const operand_err_set_ty = if (extra.data.bits.payload_is_ref) + operand_ty.childType(mod).errorUnionSet(mod) + else + operand_ty.errorUnionSet(mod); const block_inst: Air.Inst.Index = @enumFromInt(sema.air_instructions.len); try sema.air_instructions.append(gpa, .{ @@ -11313,7 +11324,12 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp .tag_capture_inst = undefined, }; - if (try sema.resolveDefinedValue(&child_block, src, raw_operand_val)) |operand_val| { + if (try sema.resolveDefinedValue(&child_block, src, raw_operand_val)) |ov| { + const operand_val = if (extra.data.bits.payload_is_ref) + (try sema.pointerDeref(&child_block, src, ov, operand_ty)).? + else + ov; + if (operand_val.errorUnionIsPayload(mod)) { return sema.resolveBlockBody(block, operand_src, &child_block, non_error_case.body, inst, merges); } else { @@ -11323,7 +11339,10 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp .name = operand_val.getErrorName(mod).unwrap().?, }, })); - spa.operand = try sema.analyzeErrUnionCode(block, operand_src, raw_operand_val); + spa.operand = if (extra.data.bits.payload_is_ref) + try sema.analyzeErrUnionCodePtr(block, operand_src, raw_operand_val) + else + try sema.analyzeErrUnionCode(block, operand_src, raw_operand_val); if (extra.data.bits.any_uses_err_capture) { sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand); @@ -11367,7 +11386,14 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp unreachable; } - const cond = try sema.analyzeIsNonErr(block, src, raw_operand_val); + const cond = if (extra.data.bits.payload_is_ref) blk: { + try sema.checkErrorType(block, src, sema.typeOf(raw_operand_val).elemType2(mod)); + const loaded = try sema.analyzeLoad(block, src, raw_operand_val, src); + break :blk try sema.analyzeIsNonErr(block, src, loaded); + } else blk: { + try sema.checkErrorType(block, src, sema.typeOf(raw_operand_val)); + break :blk try sema.analyzeIsNonErr(block, src, raw_operand_val); + }; var sub_block = child_block.makeSubBlock(); sub_block.runtime_loop = null; @@ -11379,7 +11405,11 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp const true_instructions = try sub_block.instructions.toOwnedSlice(gpa); defer gpa.free(true_instructions); - spa.operand = try sema.analyzeErrUnionCode(&sub_block, operand_src, raw_operand_val); + spa.operand = if (extra.data.bits.payload_is_ref) + try sema.analyzeErrUnionCodePtr(&sub_block, operand_src, raw_operand_val) + else + try sema.analyzeErrUnionCode(&sub_block, operand_src, raw_operand_val); + if (extra.data.bits.any_uses_err_capture) { sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand); } diff --git a/src/Zir.zig b/src/Zir.zig index 7eb9bd1209..94eb3cd4c4 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -2793,9 +2793,10 @@ pub const Inst = struct { /// If true, there is an else prong. This is mutually exclusive with `has_under`. has_else: bool, any_uses_err_capture: bool, + payload_is_ref: bool, scalar_cases_len: ScalarCasesLen, - pub const ScalarCasesLen = u29; + pub const ScalarCasesLen = u28; }; pub const MultiProng = struct { |
