From 3264abe3d8f658e1b7275d2be80e43eddfc098dc Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 22 May 2022 21:51:44 -0700 Subject: stage2: fixes for error union semantics * Sema: avoid unnecessary safety checks when an error set is empty. * Sema: make zirErrorToInt handle comptime errors that are represented as integers. * Sema: make empty error sets properly integrate with typeHasOnePossibleValue. * Type: correct the ABI alignment and size of error unions which have both zero-bit error set and zero-bit payload. The previous code did not account for the fact that we still need to store a bit for whether there is an error. * LLVM: lower error unions possibly with the payload first or with the error code first, depending on alignment. Previously it always put the error code first and used a padding array. * LLVM: lower functions which have an empty error set as the return type the same as anyerror, so that they can be used where fn()anyerror function pointers are expected. In such functions, Zig will lower ret to returning zero instead of void. As a result, one more behavior test is passing. --- src/Sema.zig | 63 ++++++++++++---- src/codegen/llvm.zig | 200 +++++++++++++++++++++++++++++++++++---------------- src/type.zig | 197 ++++++++++++++++++++++++++++++++++++++------------ 3 files changed, 340 insertions(+), 120 deletions(-) (limited to 'src') diff --git a/src/Sema.zig b/src/Sema.zig index d3fca6d2b2..b718912a38 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -5899,12 +5899,22 @@ fn zirErrorToInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError! if (val.isUndef()) { return sema.addConstUndef(result_ty); } - const payload = try sema.arena.create(Value.Payload.U64); - payload.* = .{ - .base = .{ .tag = .int_u64 }, - .data = (try sema.mod.getErrorValue(val.castTag(.@"error").?.data.name)).value, - }; - return sema.addConstant(result_ty, Value.initPayload(&payload.base)); + switch (val.tag()) { + .@"error" => { + const payload = try sema.arena.create(Value.Payload.U64); + payload.* = .{ + .base = .{ .tag = .int_u64 }, + .data = (try sema.mod.getErrorValue(val.castTag(.@"error").?.data.name)).value, + }; + return sema.addConstant(result_ty, Value.initPayload(&payload.base)); + }, + + // This is not a valid combination with the type `anyerror`. + .the_only_possible_value => unreachable, + + // Assume it's already encoded as an integer. + else => return sema.addConstant(result_ty, val), + } } try sema.requireRuntimeBlock(block, src); @@ -6261,19 +6271,24 @@ fn zirErrUnionPayload( }); } + const result_ty = operand_ty.errorUnionPayload(); if (try sema.resolveDefinedValue(block, src, operand)) |val| { if (val.getError()) |name| { return sema.fail(block, src, "caught unexpected error '{s}'", .{name}); } const data = val.castTag(.eu_payload).?.data; - const result_ty = operand_ty.errorUnionPayload(); return sema.addConstant(result_ty, data); } + try sema.requireRuntimeBlock(block, src); - if (safety_check and block.wantSafety()) { + + // If the error set has no fields then no safety check is needed. + if (safety_check and block.wantSafety() and + operand_ty.errorUnionSet().errorSetCardinality() != .zero) + { try sema.panicUnwrapError(block, src, operand, .unwrap_errunion_err, .is_non_err); } - const result_ty = operand_ty.errorUnionPayload(); + return block.addTyOp(.unwrap_errunion_payload, result_ty, operand); } @@ -6311,7 +6326,8 @@ fn analyzeErrUnionPayloadPtr( }); } - const payload_ty = operand_ty.elemType().errorUnionPayload(); + const err_union_ty = operand_ty.elemType(); + const payload_ty = err_union_ty.errorUnionPayload(); const operand_pointer_ty = try Type.ptr(sema.arena, sema.mod, .{ .pointee_type = payload_ty, .mutable = !operand_ty.isConstPtr(), @@ -6351,9 +6367,14 @@ fn analyzeErrUnionPayloadPtr( } try sema.requireRuntimeBlock(block, src); - if (safety_check and block.wantSafety()) { + + // If the error set has no fields then no safety check is needed. + if (safety_check and block.wantSafety() and + err_union_ty.errorUnionSet().errorSetCardinality() != .zero) + { try sema.panicUnwrapError(block, src, operand, .unwrap_errunion_err_ptr, .is_non_err_ptr); } + const air_tag: Air.Inst.Tag = if (initializing) .errunion_payload_ptr_set else @@ -23301,10 +23322,7 @@ pub fn typeHasOnePossibleValue( .enum_literal, .anyerror_void_error_union, .error_union, - .error_set, - .error_set_single, .error_set_inferred, - .error_set_merged, .@"opaque", .var_args_param, .manyptr_u8, @@ -23333,6 +23351,23 @@ pub fn typeHasOnePossibleValue( .bound_fn, => return null, + .error_set_single => { + const name = ty.castTag(.error_set_single).?.data; + return try Value.Tag.@"error".create(sema.arena, .{ .name = name }); + }, + .error_set => { + const err_set_obj = ty.castTag(.error_set).?.data; + const names = err_set_obj.names.keys(); + if (names.len > 1) return null; + return try Value.Tag.@"error".create(sema.arena, .{ .name = names[0] }); + }, + .error_set_merged => { + const name_map = ty.castTag(.error_set_merged).?.data; + const names = name_map.keys(); + if (names.len > 1) return null; + return try Value.Tag.@"error".create(sema.arena, .{ .name = names[0] }); + }, + .@"struct" => { const resolved_ty = try sema.resolveTypeFields(block, src, ty); const s = resolved_ty.castTag(.@"struct").?.data; diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index ef33f39f55..95d12dff3a 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -2451,20 +2451,22 @@ pub const DeclGen = struct { .ErrorUnion => { const error_type = t.errorUnionSet(); const payload_type = t.errorUnionPayload(); - const llvm_error_type = try dg.llvmType(error_type); + if (error_type.errorSetCardinality() == .zero) { + return dg.llvmType(payload_type); + } if (!payload_type.hasRuntimeBitsIgnoreComptime()) { - return llvm_error_type; + return try dg.llvmType(Type.anyerror); } + const llvm_error_type = try dg.llvmType(error_type); const llvm_payload_type = try dg.llvmType(payload_type); const payload_align = payload_type.abiAlignment(target); - const error_size = error_type.abiSize(target); - if (payload_align > error_size) { - const pad_type = dg.context.intType(8).arrayType(@intCast(u32, payload_align - error_size)); - const fields: [3]*const llvm.Type = .{ llvm_error_type, pad_type, llvm_payload_type }; + const error_align = Type.anyerror.abiAlignment(target); + if (error_align > payload_align) { + const fields: [2]*const llvm.Type = .{ llvm_error_type, llvm_payload_type }; return dg.context.structType(&fields, fields.len, .False); } else { - const fields: [2]*const llvm.Type = .{ llvm_error_type, llvm_payload_type }; + const fields: [2]*const llvm.Type = .{ llvm_payload_type, llvm_error_type }; return dg.context.structType(&fields, fields.len, .False); } }, @@ -3103,6 +3105,10 @@ pub const DeclGen = struct { .ErrorUnion => { const error_type = tv.ty.errorUnionSet(); const payload_type = tv.ty.errorUnionPayload(); + if (error_type.errorSetCardinality() == .zero) { + const payload_val = tv.val.castTag(.eu_payload).?.data; + return dg.genTypedValue(.{ .ty = payload_type, .val = payload_val }); + } const is_pl = tv.val.errorUnionIsPayload(); if (!payload_type.hasRuntimeBitsIgnoreComptime()) { @@ -3110,28 +3116,24 @@ pub const DeclGen = struct { const err_val = if (!is_pl) tv.val else Value.initTag(.zero); return dg.genTypedValue(.{ .ty = error_type, .val = err_val }); } - var len: u8 = 2; - var fields: [3]*const llvm.Value = .{ - try dg.genTypedValue(.{ - .ty = error_type, - .val = if (is_pl) Value.initTag(.zero) else tv.val, - }), - try dg.genTypedValue(.{ - .ty = payload_type, - .val = if (tv.val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef), - }), - undefined, - }; const payload_align = payload_type.abiAlignment(target); - const error_size = error_type.abiSize(target); - if (payload_align > error_size) { - fields[2] = fields[1]; - const pad_type = dg.context.intType(8).arrayType(@intCast(u32, payload_align - error_size)); - fields[1] = pad_type.getUndef(); - len += 1; + const error_align = Type.anyerror.abiAlignment(target); + const llvm_error_value = try dg.genTypedValue(.{ + .ty = error_type, + .val = if (is_pl) Value.initTag(.zero) else tv.val, + }); + const llvm_payload_value = try dg.genTypedValue(.{ + .ty = payload_type, + .val = if (tv.val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef), + }); + if (error_align > payload_align) { + const fields: [2]*const llvm.Value = .{ llvm_error_value, llvm_payload_value }; + return dg.context.constStruct(&fields, fields.len, .False); + } else { + const fields: [2]*const llvm.Value = .{ llvm_payload_value, llvm_error_value }; + return dg.context.constStruct(&fields, fields.len, .False); } - return dg.context.constStruct(&fields, len, .False); }, .Struct => { const llvm_struct_ty = try dg.llvmType(tv.ty); @@ -4338,11 +4340,19 @@ pub const FuncGen = struct { _ = self.builder.buildRetVoid(); return null; } + const fn_info = self.dg.decl.ty.fnInfo(); if (!ret_ty.hasRuntimeBitsIgnoreComptime()) { - _ = self.builder.buildRetVoid(); + if (fn_info.return_type.isError()) { + // Functions with an empty error set are emitted with an error code + // return type and return zero so they can be function pointers coerced + // to functions that return anyerror. + const err_int = try self.dg.llvmType(Type.anyerror); + _ = self.builder.buildRet(err_int.constInt(0, .False)); + } else { + _ = self.builder.buildRetVoid(); + } return null; } - const fn_info = self.dg.decl.ty.fnInfo(); const abi_ret_ty = try lowerFnRetTy(self.dg, fn_info); const operand = try self.resolveInst(un_op); const llvm_ret_ty = operand.typeOf(); @@ -4369,13 +4379,25 @@ pub const FuncGen = struct { const un_op = self.air.instructions.items(.data)[inst].un_op; const ptr_ty = self.air.typeOf(un_op); const ret_ty = ptr_ty.childType(); - if (!ret_ty.hasRuntimeBitsIgnoreComptime() or self.ret_ptr != null) { + const fn_info = self.dg.decl.ty.fnInfo(); + if (!ret_ty.hasRuntimeBitsIgnoreComptime()) { + if (fn_info.return_type.isError()) { + // Functions with an empty error set are emitted with an error code + // return type and return zero so they can be function pointers coerced + // to functions that return anyerror. + const err_int = try self.dg.llvmType(Type.anyerror); + _ = self.builder.buildRet(err_int.constInt(0, .False)); + } else { + _ = self.builder.buildRetVoid(); + } + return null; + } + if (self.ret_ptr != null) { _ = self.builder.buildRetVoid(); return null; } const ptr = try self.resolveInst(un_op); const target = self.dg.module.getTarget(); - const fn_info = self.dg.decl.ty.fnInfo(); const abi_ret_ty = try lowerFnRetTy(self.dg, fn_info); const llvm_ret_ty = try self.dg.llvmType(ret_ty); const casted_ptr = if (abi_ret_ty == llvm_ret_ty) ptr else p: { @@ -5433,18 +5455,30 @@ pub const FuncGen = struct { const err_set_ty = try self.dg.llvmType(Type.initTag(.anyerror)); const zero = err_set_ty.constNull(); + if (err_union_ty.errorUnionSet().errorSetCardinality() == .zero) { + const llvm_i1 = self.context.intType(1); + switch (op) { + .EQ => return llvm_i1.constInt(1, .False), // 0 == 0 + .NE => return llvm_i1.constInt(0, .False), // 0 != 0 + else => unreachable, + } + } + if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { const loaded = if (operand_is_ptr) self.builder.buildLoad(operand, "") else operand; return self.builder.buildICmp(op, loaded, zero, ""); } + const target = self.dg.module.getTarget(); + const err_field_index = errUnionErrorOffset(payload_ty, target); + if (operand_is_ptr or isByRef(err_union_ty)) { - const err_field_ptr = self.builder.buildStructGEP(operand, 0, ""); + const err_field_ptr = self.builder.buildStructGEP(operand, err_field_index, ""); const loaded = self.builder.buildLoad(err_field_ptr, ""); return self.builder.buildICmp(op, loaded, zero, ""); } - const loaded = self.builder.buildExtractValue(operand, 0, ""); + const loaded = self.builder.buildExtractValue(operand, err_field_index, ""); return self.builder.buildICmp(op, loaded, zero, ""); } @@ -5544,11 +5578,17 @@ pub const FuncGen = struct { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const operand = try self.resolveInst(ty_op.operand); - const result_ty = self.air.getRefType(ty_op.ty); + const operand_ty = self.air.typeOf(ty_op.operand); + const error_union_ty = if (operand_is_ptr) operand_ty.childType() else operand_ty; + // If the error set has no fields, then the payload and the error + // union are the same value. + if (error_union_ty.errorUnionSet().errorSetCardinality() == .zero) { + return operand; + } + const result_ty = self.air.typeOfIndex(inst); const payload_ty = if (operand_is_ptr) result_ty.childType() else result_ty; - const target = self.dg.module.getTarget(); - const offset: u8 = if (payload_ty.abiAlignment(target) > Type.anyerror.abiSize(target)) 2 else 1; + const offset = errUnionPayloadOffset(payload_ty, target); if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { if (!operand_is_ptr) return null; @@ -5574,54 +5614,70 @@ pub const FuncGen = struct { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const operand = try self.resolveInst(ty_op.operand); const operand_ty = self.air.typeOf(ty_op.operand); - const err_set_ty = if (operand_is_ptr) operand_ty.childType() else operand_ty; + const err_union_ty = if (operand_is_ptr) operand_ty.childType() else operand_ty; + if (err_union_ty.errorUnionSet().errorSetCardinality() == .zero) { + const err_llvm_ty = try self.dg.llvmType(Type.anyerror); + if (operand_is_ptr) { + return self.builder.buildBitCast(operand, err_llvm_ty.pointerType(0), ""); + } else { + return err_llvm_ty.constInt(0, .False); + } + } - const payload_ty = err_set_ty.errorUnionPayload(); + const payload_ty = err_union_ty.errorUnionPayload(); if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { if (!operand_is_ptr) return operand; return self.builder.buildLoad(operand, ""); } - if (operand_is_ptr or isByRef(err_set_ty)) { - const err_field_ptr = self.builder.buildStructGEP(operand, 0, ""); + const target = self.dg.module.getTarget(); + const offset = errUnionErrorOffset(payload_ty, target); + + if (operand_is_ptr or isByRef(err_union_ty)) { + const err_field_ptr = self.builder.buildStructGEP(operand, offset, ""); return self.builder.buildLoad(err_field_ptr, ""); } - return self.builder.buildExtractValue(operand, 0, ""); + return self.builder.buildExtractValue(operand, offset, ""); } fn airErrUnionPayloadPtrSet(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const operand = try self.resolveInst(ty_op.operand); - const error_set_ty = self.air.typeOf(ty_op.operand).childType(); + const error_union_ty = self.air.typeOf(ty_op.operand).childType(); - const error_ty = error_set_ty.errorUnionSet(); - const payload_ty = error_set_ty.errorUnionPayload(); + const error_ty = error_union_ty.errorUnionSet(); + if (error_ty.errorSetCardinality() == .zero) { + // TODO: write undefined bytes through the pointer here + return operand; + } + const payload_ty = error_union_ty.errorUnionPayload(); const non_error_val = try self.dg.genTypedValue(.{ .ty = error_ty, .val = Value.zero }); if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { - // We have a pointer to a i1. We need to set it to 1 and then return the same pointer. _ = self.builder.buildStore(non_error_val, operand); return operand; } const index_type = self.context.intType(32); + const target = self.dg.module.getTarget(); { + const error_offset = errUnionErrorOffset(payload_ty, target); // First set the non-error value. const indices: [2]*const llvm.Value = .{ index_type.constNull(), // dereference the pointer - index_type.constNull(), // first field is the payload + index_type.constInt(error_offset, .False), }; const non_null_ptr = self.builder.buildInBoundsGEP(operand, &indices, indices.len, ""); - _ = self.builder.buildStore(non_error_val, non_null_ptr); + const store_inst = self.builder.buildStore(non_error_val, non_null_ptr); + store_inst.setAlignment(Type.anyerror.abiAlignment(target)); } // Then return the payload pointer (only if it is used). if (self.liveness.isUnused(inst)) return null; - const target = self.dg.module.getTarget(); - const payload_offset: u8 = if (payload_ty.abiAlignment(target) > Type.anyerror.abiSize(target)) 2 else 1; + const payload_offset = errUnionPayloadOffset(payload_ty, target); const indices: [2]*const llvm.Value = .{ index_type.constNull(), // dereference the pointer - index_type.constInt(payload_offset, .False), // second field is the payload + index_type.constInt(payload_offset, .False), }; return self.builder.buildInBoundsGEP(operand, &indices, indices.len, ""); } @@ -5669,21 +5725,26 @@ pub const FuncGen = struct { if (self.liveness.isUnused(inst)) return null; const ty_op = self.air.instructions.items(.data)[inst].ty_op; - const payload_ty = self.air.typeOf(ty_op.operand); + const inst_ty = self.air.typeOfIndex(inst); const operand = try self.resolveInst(ty_op.operand); + if (inst_ty.errorUnionSet().errorSetCardinality() == .zero) { + return operand; + } + const payload_ty = self.air.typeOf(ty_op.operand); if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { return operand; } - const inst_ty = self.air.typeOfIndex(inst); - const ok_err_code = self.context.intType(16).constNull(); + const ok_err_code = (try self.dg.llvmType(Type.anyerror)).constNull(); const err_un_llvm_ty = try self.dg.llvmType(inst_ty); const target = self.dg.module.getTarget(); - const payload_offset: u8 = if (payload_ty.abiAlignment(target) > Type.anyerror.abiSize(target)) 2 else 1; + const payload_offset = errUnionPayloadOffset(payload_ty, target); + const error_offset = errUnionErrorOffset(payload_ty, target); if (isByRef(inst_ty)) { const result_ptr = self.buildAlloca(err_un_llvm_ty); - const err_ptr = self.builder.buildStructGEP(result_ptr, 0, ""); - _ = self.builder.buildStore(ok_err_code, err_ptr); + const err_ptr = self.builder.buildStructGEP(result_ptr, error_offset, ""); + const store_inst = self.builder.buildStore(ok_err_code, err_ptr); + store_inst.setAlignment(Type.anyerror.abiAlignment(target)); const payload_ptr = self.builder.buildStructGEP(result_ptr, payload_offset, ""); var ptr_ty_payload: Type.Payload.ElemType = .{ .base = .{ .tag = .single_mut_pointer }, @@ -5694,7 +5755,7 @@ pub const FuncGen = struct { return result_ptr; } - const partial = self.builder.buildInsertValue(err_un_llvm_ty.getUndef(), ok_err_code, 0, ""); + const partial = self.builder.buildInsertValue(err_un_llvm_ty.getUndef(), ok_err_code, error_offset, ""); return self.builder.buildInsertValue(partial, operand, payload_offset, ""); } @@ -5711,11 +5772,13 @@ pub const FuncGen = struct { const err_un_llvm_ty = try self.dg.llvmType(err_un_ty); const target = self.dg.module.getTarget(); - const payload_offset: u8 = if (payload_ty.abiAlignment(target) > Type.anyerror.abiSize(target)) 2 else 1; + const payload_offset = errUnionPayloadOffset(payload_ty, target); + const error_offset = errUnionErrorOffset(payload_ty, target); if (isByRef(err_un_ty)) { const result_ptr = self.buildAlloca(err_un_llvm_ty); - const err_ptr = self.builder.buildStructGEP(result_ptr, 0, ""); - _ = self.builder.buildStore(operand, err_ptr); + const err_ptr = self.builder.buildStructGEP(result_ptr, error_offset, ""); + const store_inst = self.builder.buildStore(operand, err_ptr); + store_inst.setAlignment(Type.anyerror.abiAlignment(target)); const payload_ptr = self.builder.buildStructGEP(result_ptr, payload_offset, ""); var ptr_ty_payload: Type.Payload.ElemType = .{ .base = .{ .tag = .single_mut_pointer }, @@ -5728,7 +5791,7 @@ pub const FuncGen = struct { return result_ptr; } - const partial = self.builder.buildInsertValue(err_un_llvm_ty.getUndef(), operand, 0, ""); + const partial = self.builder.buildInsertValue(err_un_llvm_ty.getUndef(), operand, error_offset, ""); // TODO set payload bytes to undef return partial; } @@ -8546,7 +8609,14 @@ fn firstParamSRet(fn_info: Type.Payload.Function.Data, target: std.Target) bool /// be effectively bitcasted to the actual return type. fn lowerFnRetTy(dg: *DeclGen, fn_info: Type.Payload.Function.Data) !*const llvm.Type { if (!fn_info.return_type.hasRuntimeBitsIgnoreComptime()) { - return dg.context.voidType(); + // If the return type is an error set or an error union, then we make this + // anyerror return type instead, so that it can be coerced into a function + // pointer type which has anyerror as the return type. + if (fn_info.return_type.isError()) { + return dg.llvmType(Type.anyerror); + } else { + return dg.context.voidType(); + } } const target = dg.module.getTarget(); switch (fn_info.cc) { @@ -8991,3 +9061,11 @@ fn buildAllocaInner( return builder.buildAlloca(llvm_ty, ""); } + +fn errUnionPayloadOffset(payload_ty: Type, target: std.Target) u1 { + return @boolToInt(Type.anyerror.abiAlignment(target) > payload_ty.abiAlignment(target)); +} + +fn errUnionErrorOffset(payload_ty: Type, target: std.Target) u1 { + return @boolToInt(Type.anyerror.abiAlignment(target) <= payload_ty.abiAlignment(target)); +} diff --git a/src/type.zig b/src/type.zig index ea65cc8916..4b8a41915f 100644 --- a/src/type.zig +++ b/src/type.zig @@ -2317,10 +2317,7 @@ pub const Type = extern union { .const_slice_u8_sentinel_0, .array_u8_sentinel_0, .anyerror_void_error_union, - .error_set, - .error_set_single, .error_set_inferred, - .error_set_merged, .manyptr_u8, .manyptr_const_u8, .manyptr_const_u8_sentinel_0, @@ -2361,8 +2358,20 @@ pub const Type = extern union { .fn_void_no_args, .fn_naked_noreturn_no_args, .fn_ccc_void_no_args, + .error_set_single, => return false, + .error_set => { + const err_set_obj = ty.castTag(.error_set).?.data; + const names = err_set_obj.names.keys(); + return names.len > 1; + }, + .error_set_merged => { + const name_map = ty.castTag(.error_set_merged).?.data; + const names = name_map.keys(); + return names.len > 1; + }, + // These types have more than one possible value, so the result is the same as // asking whether they are comptime-only types. .anyframe_T, @@ -2388,6 +2397,21 @@ pub const Type = extern union { } }, + .error_union => { + // This code needs to be kept in sync with the equivalent switch prong + // in abiSizeAdvanced. + const data = ty.castTag(.error_union).?.data; + if (data.error_set.errorSetCardinality() == .zero) { + return hasRuntimeBitsAdvanced(data.payload, ignore_comptime_only, sema_kit); + } else if (ignore_comptime_only) { + return true; + } else if (sema_kit) |sk| { + return !(try sk.sema.typeRequiresComptime(sk.block, sk.src, ty)); + } else { + return !comptimeOnly(ty); + } + }, + .@"struct" => { const struct_obj = ty.castTag(.@"struct").?.data; if (sema_kit) |sk| { @@ -2467,12 +2491,6 @@ pub const Type = extern union { .int_signed, .int_unsigned => return ty.cast(Payload.Bits).?.data != 0, - .error_union => { - const payload = ty.castTag(.error_union).?.data; - return (try payload.error_set.hasRuntimeBitsAdvanced(ignore_comptime_only, sema_kit)) or - (try payload.payload.hasRuntimeBitsAdvanced(ignore_comptime_only, sema_kit)); - }, - .tuple, .anon_struct => { const tuple = ty.tupleFields(); for (tuple.types) |field_ty, i| { @@ -2852,13 +2870,30 @@ pub const Type = extern union { else => unreachable, }, - .error_set, - .error_set_single, + // TODO revisit this when we have the concept of the error tag type .anyerror_void_error_union, .anyerror, .error_set_inferred, - .error_set_merged, - => return AbiAlignmentAdvanced{ .scalar = 2 }, // TODO revisit this when we have the concept of the error tag type + => return AbiAlignmentAdvanced{ .scalar = 2 }, + + .error_set => { + const err_set_obj = ty.castTag(.error_set).?.data; + const names = err_set_obj.names.keys(); + if (names.len <= 1) { + return AbiAlignmentAdvanced{ .scalar = 0 }; + } else { + return AbiAlignmentAdvanced{ .scalar = 2 }; + } + }, + .error_set_merged => { + const name_map = ty.castTag(.error_set_merged).?.data; + const names = name_map.keys(); + if (names.len <= 1) { + return AbiAlignmentAdvanced{ .scalar = 0 }; + } else { + return AbiAlignmentAdvanced{ .scalar = 2 }; + } + }, .array, .array_sentinel => return ty.elemType().abiAlignmentAdvanced(target, strat), @@ -2900,31 +2935,29 @@ pub const Type = extern union { }, .error_union => { + // This code needs to be kept in sync with the equivalent switch prong + // in abiSizeAdvanced. const data = ty.castTag(.error_union).?.data; + if (data.error_set.errorSetCardinality() == .zero) { + return abiAlignmentAdvanced(data.payload, target, strat); + } + const code_align = abiAlignment(Type.anyerror, target); switch (strat) { .eager, .sema_kit => { - if (!(try data.error_set.hasRuntimeBitsAdvanced(false, sema_kit))) { - return data.payload.abiAlignmentAdvanced(target, strat); - } else if (!(try data.payload.hasRuntimeBitsAdvanced(false, sema_kit))) { - return data.error_set.abiAlignmentAdvanced(target, strat); + if (!(try data.payload.hasRuntimeBitsAdvanced(false, sema_kit))) { + return AbiAlignmentAdvanced{ .scalar = code_align }; } return AbiAlignmentAdvanced{ .scalar = @maximum( + code_align, (try data.payload.abiAlignmentAdvanced(target, strat)).scalar, - (try data.error_set.abiAlignmentAdvanced(target, strat)).scalar, ) }; }, .lazy => |arena| { switch (try data.payload.abiAlignmentAdvanced(target, strat)) { .scalar => |payload_align| { - if (payload_align == 0) { - return data.error_set.abiAlignmentAdvanced(target, strat); - } - switch (try data.error_set.abiAlignmentAdvanced(target, strat)) { - .scalar => |err_set_align| { - return AbiAlignmentAdvanced{ .scalar = @maximum(payload_align, err_set_align) }; - }, - .val => {}, - } + return AbiAlignmentAdvanced{ + .scalar = @maximum(code_align, payload_align), + }; }, .val => {}, } @@ -3018,6 +3051,7 @@ pub const Type = extern union { .@"undefined", .enum_literal, .type_info, + .error_set_single, => return AbiAlignmentAdvanced{ .scalar = 0 }, .noreturn, @@ -3136,6 +3170,7 @@ pub const Type = extern union { .empty_struct_literal, .empty_struct, .void, + .error_set_single, => return AbiSizeAdvanced{ .scalar = 0 }, .@"struct", .tuple, .anon_struct => switch (ty.containerLayout()) { @@ -3291,14 +3326,30 @@ pub const Type = extern union { }, // TODO revisit this when we have the concept of the error tag type - .error_set, - .error_set_single, .anyerror_void_error_union, .anyerror, .error_set_inferred, - .error_set_merged, => return AbiSizeAdvanced{ .scalar = 2 }, + .error_set => { + const err_set_obj = ty.castTag(.error_set).?.data; + const names = err_set_obj.names.keys(); + if (names.len <= 1) { + return AbiSizeAdvanced{ .scalar = 0 }; + } else { + return AbiSizeAdvanced{ .scalar = 2 }; + } + }, + .error_set_merged => { + const name_map = ty.castTag(.error_set_merged).?.data; + const names = name_map.keys(); + if (names.len <= 1) { + return AbiSizeAdvanced{ .scalar = 0 }; + } else { + return AbiSizeAdvanced{ .scalar = 2 }; + } + }, + .i16, .u16 => return AbiSizeAdvanced{ .scalar = intAbiSize(16, target) }, .i32, .u32 => return AbiSizeAdvanced{ .scalar = intAbiSize(32, target) }, .i64, .u64 => return AbiSizeAdvanced{ .scalar = intAbiSize(64, target) }, @@ -3325,24 +3376,42 @@ pub const Type = extern union { }, .error_union => { + // This code needs to be kept in sync with the equivalent switch prong + // in abiAlignmentAdvanced. const data = ty.castTag(.error_union).?.data; - if (!data.error_set.hasRuntimeBits() and !data.payload.hasRuntimeBits()) { - return AbiSizeAdvanced{ .scalar = 0 }; - } else if (!data.error_set.hasRuntimeBits()) { - return AbiSizeAdvanced{ .scalar = data.payload.abiSize(target) }; - } else if (!data.payload.hasRuntimeBits()) { - return AbiSizeAdvanced{ .scalar = data.error_set.abiSize(target) }; + // Here we need to care whether or not the error set is *empty* or whether + // it only has *one possible value*. In the former case, it means there + // cannot possibly be an error, meaning the ABI size is equivalent to the + // payload ABI size. In the latter case, we need to account for the "tag" + // because even if both the payload type and the error set type of an + // error union have no runtime bits, an error union still has + // 1 bit of data which is whether or not the value is an error. + // Zig still uses the error code encoding at runtime, even when only 1 bit + // would suffice. This prevents coercions from needing to branch. + if (data.error_set.errorSetCardinality() == .zero) { + return abiSizeAdvanced(data.payload, target, strat); + } + const code_size = abiSize(Type.anyerror, target); + if (!data.payload.hasRuntimeBits()) { + // Same as anyerror. + return AbiSizeAdvanced{ .scalar = code_size }; } - const code_align = abiAlignment(data.error_set, target); + const code_align = abiAlignment(Type.anyerror, target); const payload_align = abiAlignment(data.payload, target); - const big_align = @maximum(code_align, payload_align); const payload_size = abiSize(data.payload, target); var size: u64 = 0; - size += abiSize(data.error_set, target); - size = std.mem.alignForwardGeneric(u64, size, payload_align); - size += payload_size; - size = std.mem.alignForwardGeneric(u64, size, big_align); + if (code_align > payload_align) { + size += code_size; + size = std.mem.alignForwardGeneric(u64, size, payload_align); + size += payload_size; + size = std.mem.alignForwardGeneric(u64, size, code_align); + } else { + size += payload_size; + size = std.mem.alignForwardGeneric(u64, size, code_align); + size += code_size; + size = std.mem.alignForwardGeneric(u64, size, payload_align); + } return AbiSizeAdvanced{ .scalar = size }; }, } @@ -4166,6 +4235,35 @@ pub const Type = extern union { }; } + const ErrorSetCardinality = enum { zero, one, many }; + + pub fn errorSetCardinality(ty: Type) ErrorSetCardinality { + switch (ty.tag()) { + .anyerror => return .many, + .error_set_inferred => return .many, + .error_set_single => return .one, + .error_set => { + const err_set_obj = ty.castTag(.error_set).?.data; + const names = err_set_obj.names.keys(); + switch (names.len) { + 0 => return .zero, + 1 => return .one, + else => return .many, + } + }, + .error_set_merged => { + const name_map = ty.castTag(.error_set_merged).?.data; + const names = name_map.keys(); + switch (names.len) { + 0 => return .zero, + 1 => return .one, + else => return .many, + } + }, + else => unreachable, + } + } + /// Returns true if it is an error set that includes anyerror, false otherwise. /// Note that the result may be a false negative if the type did not get error set /// resolution prior to this call. @@ -4664,10 +4762,7 @@ pub const Type = extern union { .enum_literal, .anyerror_void_error_union, .error_union, - .error_set, - .error_set_single, .error_set_inferred, - .error_set_merged, .@"opaque", .var_args_param, .manyptr_u8, @@ -4696,6 +4791,18 @@ pub const Type = extern union { .bound_fn, => return null, + .error_set_single => return Value.initTag(.the_only_possible_value), + .error_set => { + const err_set_obj = ty.castTag(.error_set).?.data; + if (err_set_obj.names.count() > 1) return null; + return Value.initTag(.the_only_possible_value); + }, + .error_set_merged => { + const name_map = ty.castTag(.error_set_merged).?.data; + if (name_map.count() > 1) return null; + return Value.initTag(.the_only_possible_value); + }, + .@"struct" => { const s = ty.castTag(.@"struct").?.data; assert(s.haveFieldTypes()); -- cgit v1.2.3