diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2021-01-18 19:29:23 -0700 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2021-01-18 19:29:53 -0700 |
| commit | 7e56028bc7c00b884c92e2948728cbc47e5a8a09 (patch) | |
| tree | 660b8f5e027066e728fbd88db5e58385940aa2d7 | |
| parent | 6c7e66613d57aec2f2949c065ea6431ff6c31f88 (diff) | |
| parent | ecc246efa2c133aaab73032a18fed5b2c15e08ce (diff) | |
| download | zig-7e56028bc7c00b884c92e2948728cbc47e5a8a09.tar.gz zig-7e56028bc7c00b884c92e2948728cbc47e5a8a09.zip | |
Merge branch 'stage2: rework ZIR/TZIR for optionals and error unions'
closes #7730
closes #7662
| -rw-r--r-- | src/Module.zig | 2 | ||||
| -rw-r--r-- | src/astgen.zig | 122 | ||||
| -rw-r--r-- | src/codegen.zig | 37 | ||||
| -rw-r--r-- | src/ir.zig | 32 | ||||
| -rw-r--r-- | src/test.zig | 5 | ||||
| -rw-r--r-- | src/zir.zig | 130 | ||||
| -rw-r--r-- | src/zir_sema.zig | 135 | ||||
| -rw-r--r-- | test/stage2/test.zig | 47 |
8 files changed, 404 insertions, 106 deletions
diff --git a/src/Module.zig b/src/Module.zig index d75c1d2a0d..747e60f970 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -2453,7 +2453,7 @@ pub fn analyzeIsNull( return self.constBool(scope, src, bool_value); } const b = try self.requireRuntimeBlock(scope, src); - const inst_tag: Inst.Tag = if (invert_logic) .isnonnull else .isnull; + const inst_tag: Inst.Tag = if (invert_logic) .is_non_null else .is_null; return self.addUnOp(b, src, Type.initTag(.bool), inst_tag, operand); } diff --git a/src/astgen.zig b/src/astgen.zig index 6c697897aa..f24d078b4b 100644 --- a/src/astgen.zig +++ b/src/astgen.zig @@ -118,7 +118,6 @@ fn lvalExpr(mod: *Module, scope: *Scope, node: *ast.Node) InnerError!*zir.Inst { .LabeledBlock, .Break, .PtrType, - .GroupedExpression, .ArrayType, .ArrayTypeSentinel, .EnumLiteral, @@ -129,7 +128,6 @@ fn lvalExpr(mod: *Module, scope: *Scope, node: *ast.Node) InnerError!*zir.Inst { .ErrorUnion, .MergeErrorSets, .Range, - .OrElse, .Await, .BitNot, .Negation, @@ -168,7 +166,14 @@ fn lvalExpr(mod: *Module, scope: *Scope, node: *ast.Node) InnerError!*zir.Inst { }, // can be assigned to - .UnwrapOptional, .Deref, .Period, .ArrayAccess, .Identifier => {}, + .UnwrapOptional, + .Deref, + .Period, + .ArrayAccess, + .Identifier, + .GroupedExpression, + .OrElse, + => {}, } return expr(mod, scope, .ref, node); } @@ -913,8 +918,12 @@ fn unwrapOptional(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.Si const tree = scope.tree(); const src = tree.token_locs[node.rtoken].start; - const operand = try expr(mod, scope, .ref, node.lhs); - return rlWrapPtr(mod, scope, rl, try addZIRUnOp(mod, scope, src, .unwrap_optional_safe, operand)); + const operand = try expr(mod, scope, rl, node.lhs); + const op: zir.Inst.Tag = switch (rl) { + .ref => .optional_payload_safe_ptr, + else => .optional_payload_safe, + }; + return addZIRUnOp(mod, scope, src, op, operand); } fn containerField( @@ -1110,6 +1119,7 @@ fn errorSetDecl(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.Erro } // analyzing the error set results in a decl ref, so we might need to dereference it + // TODO remove all callsites to rlWrapPtr return rlWrapPtr(mod, scope, rl, try addZIRInst(mod, scope, src, zir.Inst.ErrorSet, .{ .fields = fields }, .{})); } @@ -1123,11 +1133,61 @@ fn errorType(mod: *Module, scope: *Scope, node: *ast.Node.OneToken) InnerError!* } fn catchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.Catch) InnerError!*zir.Inst { - return orelseCatchExpr(mod, scope, rl, node.lhs, node.op_token, .iserr, .unwrap_err_unsafe, node.rhs, node.payload); + switch (rl) { + .ref => return orelseCatchExpr( + mod, + scope, + rl, + node.lhs, + node.op_token, + .is_err_ptr, + .err_union_payload_unsafe_ptr, + .err_union_code_ptr, + node.rhs, + node.payload, + ), + else => return orelseCatchExpr( + mod, + scope, + rl, + node.lhs, + node.op_token, + .is_err, + .err_union_payload_unsafe, + .err_union_code, + node.rhs, + node.payload, + ), + } } fn orelseExpr(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.SimpleInfixOp) InnerError!*zir.Inst { - return orelseCatchExpr(mod, scope, rl, node.lhs, node.op_token, .isnull, .unwrap_optional_unsafe, node.rhs, null); + switch (rl) { + .ref => return orelseCatchExpr( + mod, + scope, + rl, + node.lhs, + node.op_token, + .is_null_ptr, + .optional_payload_unsafe_ptr, + undefined, + node.rhs, + null, + ), + else => return orelseCatchExpr( + mod, + scope, + rl, + node.lhs, + node.op_token, + .is_null, + .optional_payload_unsafe, + undefined, + node.rhs, + null, + ), + } } fn orelseCatchExpr( @@ -1138,17 +1198,13 @@ fn orelseCatchExpr( op_token: ast.TokenIndex, cond_op: zir.Inst.Tag, unwrap_op: zir.Inst.Tag, + unwrap_code_op: zir.Inst.Tag, rhs: *ast.Node, payload_node: ?*ast.Node, ) InnerError!*zir.Inst { const tree = scope.tree(); const src = tree.token_locs[op_token].start; - const operand_ptr = try expr(mod, scope, .ref, lhs); - // TODO we could avoid an unnecessary copy if .iserr, .isnull took a pointer - const err_union = try addZIRUnOp(mod, scope, src, .deref, operand_ptr); - const cond = try addZIRUnOp(mod, scope, src, cond_op, err_union); - var block_scope: Scope.GenZIR = .{ .parent = scope, .decl = scope.ownerDecl().?, @@ -1157,14 +1213,8 @@ fn orelseCatchExpr( }; defer block_scope.instructions.deinit(mod.gpa); - const condbr = try addZIRInstSpecial(mod, &block_scope.base, src, zir.Inst.CondBr, .{ - .condition = cond, - .then_body = undefined, // populated below - .else_body = undefined, // populated below - }, .{}); - const block = try addZIRInstBlock(mod, scope, src, .block, .{ - .instructions = try block_scope.arena.dupe(*zir.Inst, block_scope.instructions.items), + .instructions = undefined, // populated below }); // Most result location types can be forwarded directly; however @@ -1175,9 +1225,18 @@ fn orelseCatchExpr( .discard, .none, .ty, .ptr, .ref => rl, .inferred_ptr, .bitcasted_ptr, .block_ptr => .{ .block_ptr = block }, }; + // This could be a pointer or value depending on the `rl` parameter. + const operand = try expr(mod, &block_scope.base, branch_rl, lhs); + const cond = try addZIRUnOp(mod, &block_scope.base, src, cond_op, operand); + + const condbr = try addZIRInstSpecial(mod, &block_scope.base, src, zir.Inst.CondBr, .{ + .condition = cond, + .then_body = undefined, // populated below + .else_body = undefined, // populated below + }, .{}); var then_scope: Scope.GenZIR = .{ - .parent = scope, + .parent = &block_scope.base, .decl = block_scope.decl, .arena = block_scope.arena, .instructions = .{}, @@ -1193,12 +1252,11 @@ fn orelseCatchExpr( if (mem.eql(u8, err_name, "_")) break :blk &then_scope.base; - const unwrapped_err_ptr = try addZIRUnOp(mod, &then_scope.base, src, .unwrap_err_code, operand_ptr); err_val_scope = .{ .parent = &then_scope.base, .gen_zir = &then_scope, .name = err_name, - .inst = try addZIRUnOp(mod, &then_scope.base, src, .deref, unwrapped_err_ptr), + .inst = try addZIRUnOp(mod, &then_scope.base, src, unwrap_code_op, operand), }; break :blk &err_val_scope.base; }; @@ -1209,22 +1267,26 @@ fn orelseCatchExpr( }, .{}); var else_scope: Scope.GenZIR = .{ - .parent = scope, + .parent = &block_scope.base, .decl = block_scope.decl, .arena = block_scope.arena, .instructions = .{}, }; defer else_scope.instructions.deinit(mod.gpa); - const unwrapped_payload = try addZIRUnOp(mod, &else_scope.base, src, unwrap_op, operand_ptr); + // This could be a pointer or value depending on `unwrap_op`. + const unwrapped_payload = try addZIRUnOp(mod, &else_scope.base, src, unwrap_op, operand); _ = try addZIRInst(mod, &else_scope.base, src, zir.Inst.Break, .{ .block = block, .operand = unwrapped_payload, }, .{}); + // All branches have been generated, add the instructions to the block. + block.positionals.body.instructions = try block_scope.arena.dupe(*zir.Inst, block_scope.instructions.items); + condbr.positionals.then_body = .{ .instructions = try then_scope.arena.dupe(*zir.Inst, then_scope.instructions.items) }; condbr.positionals.else_body = .{ .instructions = try else_scope.arena.dupe(*zir.Inst, else_scope.instructions.items) }; - return rlWrapPtr(mod, scope, rl, &block.base); + return &block.base; } /// Return whether the identifier names of two tokens are equal. Resolves @"" @@ -1253,6 +1315,7 @@ fn field(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.SimpleInfix const lhs = try expr(mod, scope, .ref, node.lhs); const field_name = try identifierStringInst(mod, scope, node.rhs.castTag(.Identifier).?); + // TODO remove all callsites to rlWrapPtr return rlWrapPtr(mod, scope, rl, try addZIRInst(mod, scope, src, zir.Inst.FieldPtr, .{ .object_ptr = lhs, .field_name = field_name }, .{})); } @@ -1263,6 +1326,7 @@ fn arrayAccess(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.Array const array_ptr = try expr(mod, scope, .ref, node.lhs); const index = try expr(mod, scope, .none, node.index_expr); + // TODO remove all callsites to rlWrapPtr return rlWrapPtr(mod, scope, rl, try addZIRInst(mod, scope, src, zir.Inst.ElemPtr, .{ .array_ptr = array_ptr, .index = index }, .{})); } @@ -1420,13 +1484,13 @@ const CondKind = union(enum) { const cond_ptr = try expr(mod, &block_scope.base, .ref, cond_node); self.* = .{ .optional = cond_ptr }; const result = try addZIRUnOp(mod, &block_scope.base, src, .deref, cond_ptr); - return try addZIRUnOp(mod, &block_scope.base, src, .isnonnull, result); + return try addZIRUnOp(mod, &block_scope.base, src, .is_non_null, result); }, .err_union => { const err_ptr = try expr(mod, &block_scope.base, .ref, cond_node); self.* = .{ .err_union = err_ptr }; const result = try addZIRUnOp(mod, &block_scope.base, src, .deref, err_ptr); - return try addZIRUnOp(mod, &block_scope.base, src, .iserr, result); + return try addZIRUnOp(mod, &block_scope.base, src, .is_err, result); }, } } @@ -1456,7 +1520,7 @@ const CondKind = union(enum) { fn elseSubScope(self: CondKind, mod: *Module, else_scope: *Scope.GenZIR, src: usize, payload_node: ?*ast.Node) !*Scope { if (self != .err_union) return &else_scope.base; - const payload_ptr = try addZIRUnOp(mod, &else_scope.base, src, .unwrap_err_unsafe, self.err_union.?); + const payload_ptr = try addZIRUnOp(mod, &else_scope.base, src, .err_union_payload_unsafe_ptr, self.err_union.?); const payload = payload_node.?.castTag(.Payload).?; const ident_node = payload.error_symbol.castTag(.Identifier).?; @@ -2264,6 +2328,7 @@ fn identifier(mod: *Module, scope: *Scope, rl: ResultLoc, ident: *ast.Node.OneTo .local_ptr => { const local_ptr = s.cast(Scope.LocalPtr).?; if (mem.eql(u8, local_ptr.name, ident_name)) { + // TODO remove all callsites to rlWrapPtr return rlWrapPtr(mod, scope, rl, local_ptr.ptr); } s = local_ptr.parent; @@ -3047,6 +3112,7 @@ fn rlWrapVoid(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node, resul /// TODO go over all the callsites and see where we can introduce "by-value" ZIR instructions /// to save ZIR memory. For example, see DeclVal vs DeclRef. +/// Do not add additional callsites to this function. fn rlWrapPtr(mod: *Module, scope: *Scope, rl: ResultLoc, ptr: *zir.Inst) InnerError!*zir.Inst { if (rl == .ref) return ptr; diff --git a/src/codegen.zig b/src/codegen.zig index 14572c2012..1ca2bb2abe 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -860,9 +860,12 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { .dbg_stmt => return self.genDbgStmt(inst.castTag(.dbg_stmt).?), .floatcast => return self.genFloatCast(inst.castTag(.floatcast).?), .intcast => return self.genIntCast(inst.castTag(.intcast).?), - .isnonnull => return self.genIsNonNull(inst.castTag(.isnonnull).?), - .isnull => return self.genIsNull(inst.castTag(.isnull).?), - .iserr => return self.genIsErr(inst.castTag(.iserr).?), + .is_non_null => return self.genIsNonNull(inst.castTag(.is_non_null).?), + .is_non_null_ptr => return self.genIsNonNullPtr(inst.castTag(.is_non_null_ptr).?), + .is_null => return self.genIsNull(inst.castTag(.is_null).?), + .is_null_ptr => return self.genIsNullPtr(inst.castTag(.is_null_ptr).?), + .is_err => return self.genIsErr(inst.castTag(.is_err).?), + .is_err_ptr => return self.genIsErrPtr(inst.castTag(.is_err_ptr).?), .load => return self.genLoad(inst.castTag(.load).?), .loop => return self.genLoop(inst.castTag(.loop).?), .not => return self.genNot(inst.castTag(.not).?), @@ -874,7 +877,8 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { .sub => return self.genSub(inst.castTag(.sub).?), .switchbr => return self.genSwitch(inst.castTag(.switchbr).?), .unreach => return MCValue{ .unreach = {} }, - .unwrap_optional => return self.genUnwrapOptional(inst.castTag(.unwrap_optional).?), + .optional_payload => return self.genOptionalPayload(inst.castTag(.optional_payload).?), + .optional_payload_ptr => return self.genOptionalPayloadPtr(inst.castTag(.optional_payload_ptr).?), .wrap_optional => return self.genWrapOptional(inst.castTag(.wrap_optional).?), .varptr => return self.genVarPtr(inst.castTag(.varptr).?), .xor => return self.genXor(inst.castTag(.xor).?), @@ -1118,12 +1122,21 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { } } - fn genUnwrapOptional(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + fn genOptionalPayload(self: *Self, inst: *ir.Inst.UnOp) !MCValue { // No side effects, so if it's unreferenced, do nothing. if (inst.base.isUnused()) return MCValue.dead; switch (arch) { - else => return self.fail(inst.base.src, "TODO implement unwrap optional for {}", .{self.target.cpu.arch}), + else => return self.fail(inst.base.src, "TODO implement .optional_payload for {}", .{self.target.cpu.arch}), + } + } + + fn genOptionalPayloadPtr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement .optional_payload_ptr for {}", .{self.target.cpu.arch}), } } @@ -2306,6 +2319,10 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { } } + fn genIsNullPtr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + return self.fail(inst.base.src, "TODO load the operand and call genIsNull", .{}); + } + fn genIsNonNull(self: *Self, inst: *ir.Inst.UnOp) !MCValue { // Here you can specialize this instruction if it makes sense to, otherwise the default // will call genIsNull and invert the result. @@ -2314,12 +2331,20 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { } } + fn genIsNonNullPtr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + return self.fail(inst.base.src, "TODO load the operand and call genIsNonNull", .{}); + } + fn genIsErr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { switch (arch) { else => return self.fail(inst.base.src, "TODO implement iserr for {}", .{self.target.cpu.arch}), } } + fn genIsErrPtr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + return self.fail(inst.base.src, "TODO load the operand and call genIsErr", .{}); + } + fn genLoop(self: *Self, inst: *ir.Inst.Loop) !MCValue { // A loop is a setup to be able to jump back to the beginning. const start_index = self.code.items.len; diff --git a/src/ir.zig b/src/ir.zig index e43397faba..89698bdd84 100644 --- a/src/ir.zig +++ b/src/ir.zig @@ -73,9 +73,18 @@ pub const Inst = struct { condbr, constant, dbg_stmt, - isnonnull, - isnull, - iserr, + // ?T => bool + is_null, + // ?T => bool (inverted logic) + is_non_null, + // *?T => bool + is_null_ptr, + // *?T => bool (inverted logic) + is_non_null_ptr, + // E!T => bool + is_err, + // *E!T => bool + is_err_ptr, booland, boolor, /// Read a value from a pointer. @@ -93,7 +102,10 @@ pub const Inst = struct { not, floatcast, intcast, - unwrap_optional, + // ?T => T + optional_payload, + // *?T => *T + optional_payload_ptr, wrap_optional, xor, switchbr, @@ -111,14 +123,18 @@ pub const Inst = struct { .ret, .bitcast, .not, - .isnonnull, - .isnull, - .iserr, + .is_non_null, + .is_non_null_ptr, + .is_null, + .is_null_ptr, + .is_err, + .is_err_ptr, .ptrtoint, .floatcast, .intcast, .load, - .unwrap_optional, + .optional_payload, + .optional_payload_ptr, .wrap_optional, => UnOp, diff --git a/src/test.zig b/src/test.zig index 1c9fb57f01..150b6496c1 100644 --- a/src/test.zig +++ b/src/test.zig @@ -696,7 +696,10 @@ pub const TestContext = struct { var all_errors = try comp.getAllErrorsAlloc(); defer all_errors.deinit(allocator); if (all_errors.list.len != 0) { - std.debug.print("\nErrors occurred updating the compilation:\n{s}\n", .{hr}); + std.debug.print( + "\nCase '{s}': unexpected errors at update_index={d}:\n{s}\n", + .{ case.name, update_index, hr }, + ); for (all_errors.list) |err_msg| { switch (err_msg) { .src => |src| { diff --git a/src/zir.zig b/src/zir.zig index 0e7b3a3520..be45538288 100644 --- a/src/zir.zig +++ b/src/zir.zig @@ -174,11 +174,17 @@ pub const Inst = struct { /// Make an integer type out of signedness and bit count. inttype, /// Return a boolean false if an optional is null. `x != null` - isnonnull, + is_non_null, /// Return a boolean true if an optional is null. `x == null` - isnull, + is_null, + /// Return a boolean false if an optional is null. `x.* != null` + is_non_null_ptr, + /// Return a boolean true if an optional is null. `x.* == null` + is_null_ptr, /// Return a boolean true if value is an error - iserr, + is_err, + /// Return a boolean true if dereferenced pointer is an error + is_err_ptr, /// A labeled block of code that loops forever. At the end of the body it is implied /// to repeat; no explicit "repeat" instruction terminates loop bodies. loop, @@ -278,16 +284,42 @@ pub const Inst = struct { optional_type, /// Create a union type. union_type, - /// Unwraps an optional value 'lhs.?' - unwrap_optional_safe, - /// Same as previous, but without safety checks. Used for orelse, if and while - unwrap_optional_unsafe, - /// Gets the payload of an error union - unwrap_err_safe, - /// Same as previous, but without safety checks. Used for orelse, if and while - unwrap_err_unsafe, - /// Gets the error code value of an error union - unwrap_err_code, + /// ?T => T with safety. + /// Given an optional value, returns the payload value, with a safety check that + /// the value is non-null. Used for `orelse`, `if` and `while`. + optional_payload_safe, + /// ?T => T without safety. + /// Given an optional value, returns the payload value. No safety checks. + optional_payload_unsafe, + /// *?T => *T with safety. + /// Given a pointer to an optional value, returns a pointer to the payload value, + /// with a safety check that the value is non-null. Used for `orelse`, `if` and `while`. + optional_payload_safe_ptr, + /// *?T => *T without safety. + /// Given a pointer to an optional value, returns a pointer to the payload value. + /// No safety checks. + optional_payload_unsafe_ptr, + /// E!T => T with safety. + /// Given an error union value, returns the payload value, with a safety check + /// that the value is not an error. Used for catch, if, and while. + err_union_payload_safe, + /// E!T => T without safety. + /// Given an error union value, returns the payload value. No safety checks. + err_union_payload_unsafe, + /// *E!T => *T with safety. + /// Given a pointer to an error union value, returns a pointer to the payload value, + /// with a safety check that the value is not an error. Used for catch, if, and while. + err_union_payload_safe_ptr, + /// *E!T => *T without safety. + /// Given a pointer to a error union value, returns a pointer to the payload value. + /// No safety checks. + err_union_payload_unsafe_ptr, + /// E!T => E without safety. + /// Given an error union value, returns the error code. No safety checks. + err_union_code, + /// *E!T => E without safety. + /// Given a pointer to an error union value, returns the error code. No safety checks. + err_union_code_ptr, /// Takes a *E!T and raises a compiler error if T != void ensure_err_payload_void, /// Create a enum literal, @@ -320,9 +352,12 @@ pub const Inst = struct { .compileerror, .deref, .@"return", - .isnull, - .isnonnull, - .iserr, + .is_null, + .is_non_null, + .is_null_ptr, + .is_non_null_ptr, + .is_err, + .is_err_ptr, .ptrtoint, .ensure_result_used, .ensure_result_non_error, @@ -341,11 +376,16 @@ pub const Inst = struct { .mut_slice_type, .const_slice_type, .optional_type, - .unwrap_optional_safe, - .unwrap_optional_unsafe, - .unwrap_err_safe, - .unwrap_err_unsafe, - .unwrap_err_code, + .optional_payload_safe, + .optional_payload_unsafe, + .optional_payload_safe_ptr, + .optional_payload_unsafe_ptr, + .err_union_payload_safe, + .err_union_payload_unsafe, + .err_union_payload_safe_ptr, + .err_union_payload_unsafe_ptr, + .err_union_code, + .err_union_code_ptr, .ensure_err_payload_void, .anyframe_type, .bitnot, @@ -495,9 +535,12 @@ pub const Inst = struct { .int, .intcast, .inttype, - .isnonnull, - .isnull, - .iserr, + .is_non_null, + .is_null, + .is_non_null_ptr, + .is_null_ptr, + .is_err, + .is_err_ptr, .mod_rem, .mul, .mulwrap, @@ -525,11 +568,16 @@ pub const Inst = struct { .typeof, .xor, .optional_type, - .unwrap_optional_safe, - .unwrap_optional_unsafe, - .unwrap_err_safe, - .unwrap_err_unsafe, - .unwrap_err_code, + .optional_payload_safe, + .optional_payload_unsafe, + .optional_payload_safe_ptr, + .optional_payload_unsafe_ptr, + .err_union_payload_safe, + .err_union_payload_unsafe, + .err_union_payload_safe_ptr, + .err_union_payload_unsafe_ptr, + .err_union_code, + .err_union_code_ptr, .ptr_type, .ensure_err_payload_void, .enum_literal, @@ -1540,14 +1588,18 @@ const DumpTzir = struct { .ret, .bitcast, .not, - .isnonnull, - .isnull, - .iserr, + .is_non_null, + .is_non_null_ptr, + .is_null, + .is_null_ptr, + .is_err, + .is_err_ptr, .ptrtoint, .floatcast, .intcast, .load, - .unwrap_optional, + .optional_payload, + .optional_payload_ptr, .wrap_optional, => { const un_op = inst.cast(ir.Inst.UnOp).?; @@ -1637,14 +1689,18 @@ const DumpTzir = struct { .ret, .bitcast, .not, - .isnonnull, - .isnull, - .iserr, + .is_non_null, + .is_null, + .is_non_null_ptr, + .is_null_ptr, + .is_err, + .is_err_ptr, .ptrtoint, .floatcast, .intcast, .load, - .unwrap_optional, + .optional_payload, + .optional_payload_ptr, .wrap_optional, => { const un_op = inst.cast(ir.Inst.UnOp).?; diff --git a/src/zir_sema.zig b/src/zir_sema.zig index 36eb5f4239..82772cac16 100644 --- a/src/zir_sema.zig +++ b/src/zir_sema.zig @@ -127,18 +127,26 @@ pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError! .cmp_gt => return analyzeInstCmp(mod, scope, old_inst.castTag(.cmp_gt).?, .gt), .cmp_neq => return analyzeInstCmp(mod, scope, old_inst.castTag(.cmp_neq).?, .neq), .condbr => return analyzeInstCondBr(mod, scope, old_inst.castTag(.condbr).?), - .isnull => return analyzeInstIsNonNull(mod, scope, old_inst.castTag(.isnull).?, true), - .isnonnull => return analyzeInstIsNonNull(mod, scope, old_inst.castTag(.isnonnull).?, false), - .iserr => return analyzeInstIsErr(mod, scope, old_inst.castTag(.iserr).?), + .is_null => return isNull(mod, scope, old_inst.castTag(.is_null).?, false), + .is_non_null => return isNull(mod, scope, old_inst.castTag(.is_non_null).?, true), + .is_null_ptr => return isNullPtr(mod, scope, old_inst.castTag(.is_null_ptr).?, false), + .is_non_null_ptr => return isNullPtr(mod, scope, old_inst.castTag(.is_non_null_ptr).?, true), + .is_err => return isErr(mod, scope, old_inst.castTag(.is_err).?), + .is_err_ptr => return isErrPtr(mod, scope, old_inst.castTag(.is_err_ptr).?), .boolnot => return analyzeInstBoolNot(mod, scope, old_inst.castTag(.boolnot).?), .typeof => return analyzeInstTypeOf(mod, scope, old_inst.castTag(.typeof).?), .typeof_peer => return analyzeInstTypeOfPeer(mod, scope, old_inst.castTag(.typeof_peer).?), .optional_type => return analyzeInstOptionalType(mod, scope, old_inst.castTag(.optional_type).?), - .unwrap_optional_safe => return analyzeInstUnwrapOptional(mod, scope, old_inst.castTag(.unwrap_optional_safe).?, true), - .unwrap_optional_unsafe => return analyzeInstUnwrapOptional(mod, scope, old_inst.castTag(.unwrap_optional_unsafe).?, false), - .unwrap_err_safe => return analyzeInstUnwrapErr(mod, scope, old_inst.castTag(.unwrap_err_safe).?, true), - .unwrap_err_unsafe => return analyzeInstUnwrapErr(mod, scope, old_inst.castTag(.unwrap_err_unsafe).?, false), - .unwrap_err_code => return analyzeInstUnwrapErrCode(mod, scope, old_inst.castTag(.unwrap_err_code).?), + .optional_payload_safe => return optionalPayload(mod, scope, old_inst.castTag(.optional_payload_safe).?, true), + .optional_payload_unsafe => return optionalPayload(mod, scope, old_inst.castTag(.optional_payload_unsafe).?, false), + .optional_payload_safe_ptr => return optionalPayloadPtr(mod, scope, old_inst.castTag(.optional_payload_safe_ptr).?, true), + .optional_payload_unsafe_ptr => return optionalPayloadPtr(mod, scope, old_inst.castTag(.optional_payload_unsafe_ptr).?, false), + .err_union_payload_safe => return errorUnionPayload(mod, scope, old_inst.castTag(.err_union_payload_safe).?, true), + .err_union_payload_unsafe => return errorUnionPayload(mod, scope, old_inst.castTag(.err_union_payload_unsafe).?, false), + .err_union_payload_safe_ptr => return errorUnionPayloadPtr(mod, scope, old_inst.castTag(.err_union_payload_safe_ptr).?, true), + .err_union_payload_unsafe_ptr => return errorUnionPayloadPtr(mod, scope, old_inst.castTag(.err_union_payload_unsafe_ptr).?, false), + .err_union_code => return errorUnionCode(mod, scope, old_inst.castTag(.err_union_code).?), + .err_union_code_ptr => return errorUnionCodePtr(mod, scope, old_inst.castTag(.err_union_code_ptr).?), .ensure_err_payload_void => return analyzeInstEnsureErrPayloadVoid(mod, scope, old_inst.castTag(.ensure_err_payload_void).?), .array_type => return analyzeInstArrayType(mod, scope, old_inst.castTag(.array_type).?), .array_type_sentinel => return analyzeInstArrayTypeSentinel(mod, scope, old_inst.castTag(.array_type_sentinel).?), @@ -1104,48 +1112,109 @@ fn analyzeInstEnumLiteral(mod: *Module, scope: *Scope, inst: *zir.Inst.EnumLiter }); } -fn analyzeInstUnwrapOptional(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp, safety_check: bool) InnerError!*Inst { +/// Pointer in, pointer out. +fn optionalPayloadPtr( + mod: *Module, + scope: *Scope, + unwrap: *zir.Inst.UnOp, + safety_check: bool, +) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - const operand = try resolveInst(mod, scope, unwrap.positionals.operand); - assert(operand.ty.zigTypeTag() == .Pointer); - const elem_type = operand.ty.elemType(); - if (elem_type.zigTypeTag() != .Optional) { - return mod.fail(scope, unwrap.base.src, "expected optional type, found {}", .{elem_type}); + const optional_ptr = try resolveInst(mod, scope, unwrap.positionals.operand); + assert(optional_ptr.ty.zigTypeTag() == .Pointer); + + const opt_type = optional_ptr.ty.elemType(); + if (opt_type.zigTypeTag() != .Optional) { + return mod.fail(scope, unwrap.base.src, "expected optional type, found {}", .{opt_type}); } - const child_type = try elem_type.optionalChildAlloc(scope.arena()); - const child_pointer = try mod.simplePtrType(scope, unwrap.base.src, child_type, operand.ty.isConstPtr(), .One); + const child_type = try opt_type.optionalChildAlloc(scope.arena()); + const child_pointer = try mod.simplePtrType(scope, unwrap.base.src, child_type, !optional_ptr.ty.isConstPtr(), .One); - if (operand.value()) |val| { + if (optional_ptr.value()) |pointer_val| { + const val = try pointer_val.pointerDeref(scope.arena()); if (val.isNull()) { return mod.fail(scope, unwrap.base.src, "unable to unwrap null", .{}); } + // The same Value represents the pointer to the optional and the payload. return mod.constInst(scope, unwrap.base.src, .{ .ty = child_pointer, + .val = pointer_val, + }); + } + + const b = try mod.requireRuntimeBlock(scope, unwrap.base.src); + if (safety_check and mod.wantSafety(scope)) { + const is_non_null = try mod.addUnOp(b, unwrap.base.src, Type.initTag(.bool), .is_non_null_ptr, optional_ptr); + try mod.addSafetyCheck(b, is_non_null, .unwrap_null); + } + return mod.addUnOp(b, unwrap.base.src, child_pointer, .optional_payload_ptr, optional_ptr); +} + +/// Value in, value out. +fn optionalPayload( + mod: *Module, + scope: *Scope, + unwrap: *zir.Inst.UnOp, + safety_check: bool, +) InnerError!*Inst { + const tracy = trace(@src()); + defer tracy.end(); + + const operand = try resolveInst(mod, scope, unwrap.positionals.operand); + const opt_type = operand.ty; + if (opt_type.zigTypeTag() != .Optional) { + return mod.fail(scope, unwrap.base.src, "expected optional type, found {}", .{opt_type}); + } + + const child_type = try opt_type.optionalChildAlloc(scope.arena()); + + if (operand.value()) |val| { + if (val.isNull()) { + return mod.fail(scope, unwrap.base.src, "unable to unwrap null", .{}); + } + return mod.constInst(scope, unwrap.base.src, .{ + .ty = child_type, .val = val, }); } const b = try mod.requireRuntimeBlock(scope, unwrap.base.src); if (safety_check and mod.wantSafety(scope)) { - const is_non_null = try mod.addUnOp(b, unwrap.base.src, Type.initTag(.bool), .isnonnull, operand); + const is_non_null = try mod.addUnOp(b, unwrap.base.src, Type.initTag(.bool), .is_non_null, operand); try mod.addSafetyCheck(b, is_non_null, .unwrap_null); } - return mod.addUnOp(b, unwrap.base.src, child_pointer, .unwrap_optional, operand); + return mod.addUnOp(b, unwrap.base.src, child_type, .optional_payload, operand); } -fn analyzeInstUnwrapErr(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp, safety_check: bool) InnerError!*Inst { +/// Value in, value out +fn errorUnionPayload(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp, safety_check: bool) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, unwrap.base.src, "TODO implement analyzeInstUnwrapErr", .{}); + return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.errorUnionPayload", .{}); } -fn analyzeInstUnwrapErrCode(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp) InnerError!*Inst { +/// Pointer in, pointer out +fn errorUnionPayloadPtr(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp, safety_check: bool) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, unwrap.base.src, "TODO implement analyzeInstUnwrapErrCode", .{}); + return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.errorUnionPayloadPtr", .{}); +} + +/// Value in, value out +fn errorUnionCode(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp) InnerError!*Inst { + const tracy = trace(@src()); + defer tracy.end(); + return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.errorUnionCode", .{}); +} + +/// Pointer in, value out +fn errorUnionCodePtr(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp) InnerError!*Inst { + const tracy = trace(@src()); + defer tracy.end(); + return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.errorUnionCodePtr", .{}); } fn analyzeInstEnsureErrPayloadVoid(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp) InnerError!*Inst { @@ -2074,20 +2143,36 @@ fn analyzeInstBoolOp(mod: *Module, scope: *Scope, inst: *zir.Inst.BinOp) InnerEr return mod.addBinOp(b, inst.base.src, bool_type, if (is_bool_or) .boolor else .booland, lhs, rhs); } -fn analyzeInstIsNonNull(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp, invert_logic: bool) InnerError!*Inst { +fn isNull(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp, invert_logic: bool) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); const operand = try resolveInst(mod, scope, inst.positionals.operand); return mod.analyzeIsNull(scope, inst.base.src, operand, invert_logic); } -fn analyzeInstIsErr(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst { +fn isNullPtr(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp, invert_logic: bool) InnerError!*Inst { + const tracy = trace(@src()); + defer tracy.end(); + const ptr = try resolveInst(mod, scope, inst.positionals.operand); + const loaded = try mod.analyzeDeref(scope, inst.base.src, ptr, ptr.src); + return mod.analyzeIsNull(scope, inst.base.src, loaded, invert_logic); +} + +fn isErr(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); const operand = try resolveInst(mod, scope, inst.positionals.operand); return mod.analyzeIsErr(scope, inst.base.src, operand); } +fn isErrPtr(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst { + const tracy = trace(@src()); + defer tracy.end(); + const ptr = try resolveInst(mod, scope, inst.positionals.operand); + const loaded = try mod.analyzeDeref(scope, inst.base.src, ptr, ptr.src); + return mod.analyzeIsErr(scope, inst.base.src, loaded); +} + fn analyzeInstCondBr(mod: *Module, scope: *Scope, inst: *zir.Inst.CondBr) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); diff --git a/test/stage2/test.zig b/test/stage2/test.zig index f2c0989b46..845b9b627d 100644 --- a/test/stage2/test.zig +++ b/test/stage2/test.zig @@ -1462,4 +1462,51 @@ pub fn addCases(ctx: *TestContext) !void { "", ); } + { + var case = ctx.exe("orelse at comptime", linux_x64); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const i: ?u64 = 0; + \\ const orelsed = i orelse 5; + \\ assert(orelsed == 0); + \\ exit(); + \\} + \\fn assert(b: bool) void { + \\ if (!b) unreachable; + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "", + ); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const i: ?u64 = null; + \\ const orelsed = i orelse 5; + \\ assert(orelsed == 5); + \\ exit(); + \\} + \\fn assert(b: bool) void { + \\ if (!b) unreachable; + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "", + ); + } } |
