From 7c9979a02e830a4383995e66ff623a7d07cac091 Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Thu, 11 Aug 2022 22:45:15 +0300 Subject: stage2: generate a switch for `@errSetCast` safety --- src/codegen/llvm.zig | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) (limited to 'src/codegen/llvm.zig') diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 38f7d285ae..5da3e7e327 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -4247,6 +4247,7 @@ pub const FuncGen = struct { .prefetch => try self.airPrefetch(inst), .is_named_enum_value => try self.airIsNamedEnumValue(inst), + .error_set_has_value => try self.airErrorSetHasValue(inst), .reduce => try self.airReduce(inst, false), .reduce_optimized => try self.airReduce(inst, true), @@ -7983,6 +7984,53 @@ pub const FuncGen = struct { } } + fn airErrorSetHasValue(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { + if (self.liveness.isUnused(inst)) return null; + + const ty_op = self.air.instructions.items(.data)[inst].ty_op; + const operand = try self.resolveInst(ty_op.operand); + const error_set_ty = self.air.getRefType(ty_op.ty); + + const names = error_set_ty.errorSetNames(); + const valid_block = self.dg.context.appendBasicBlock(self.llvm_func, "Valid"); + const invalid_block = self.dg.context.appendBasicBlock(self.llvm_func, "Invalid"); + const end_block = self.context.appendBasicBlock(self.llvm_func, "End"); + const switch_instr = self.builder.buildSwitch(operand, invalid_block, @intCast(c_uint, names.len)); + + for (names) |name| { + const err_int = self.dg.module.global_error_set.get(name).?; + const this_tag_int_value = int: { + var tag_val_payload: Value.Payload.U64 = .{ + .base = .{ .tag = .int_u64 }, + .data = err_int, + }; + break :int try self.dg.lowerValue(.{ + .ty = Type.u16, + .val = Value.initPayload(&tag_val_payload.base), + }); + }; + switch_instr.addCase(this_tag_int_value, valid_block); + } + self.builder.positionBuilderAtEnd(valid_block); + _ = self.builder.buildBr(end_block); + + self.builder.positionBuilderAtEnd(invalid_block); + _ = self.builder.buildBr(end_block); + + self.builder.positionBuilderAtEnd(end_block); + + const llvm_type = self.dg.context.intType(1); + const incoming_values: [2]*const llvm.Value = .{ + llvm_type.constInt(1, .False), llvm_type.constInt(0, .False), + }; + const incoming_blocks: [2]*const llvm.BasicBlock = .{ + valid_block, invalid_block, + }; + const phi_node = self.builder.buildPhi(llvm_type, ""); + phi_node.addIncoming(&incoming_values, &incoming_blocks, 2); + return phi_node; + } + fn airIsNamedEnumValue(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { if (self.liveness.isUnused(inst)) return null; -- cgit v1.2.3 From 09f273136c74639d18695f63ba9c0fdc68ad678c Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Thu, 11 Aug 2022 23:06:50 +0300 Subject: stage2: check for zero in `@intToError` safety --- src/Sema.zig | 22 ++++++++++++---------- src/codegen/llvm.zig | 2 +- src/type.zig | 2 ++ test/cases/safety/zero casted to error.zig | 19 +++++++++++++++++++ 4 files changed, 34 insertions(+), 11 deletions(-) create mode 100644 test/cases/safety/zero casted to error.zig (limited to 'src/codegen/llvm.zig') diff --git a/src/Sema.zig b/src/Sema.zig index 3e62344a04..55b519d900 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -6804,11 +6804,10 @@ fn zirErrorToInt(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node }; const uncasted_operand = try sema.resolveInst(extra.operand); const operand = try sema.coerce(block, Type.anyerror, uncasted_operand, operand_src); - const result_ty = Type.u16; if (try sema.resolveMaybeUndefVal(block, src, operand)) |val| { if (val.isUndef()) { - return sema.addConstUndef(result_ty); + return sema.addConstUndef(Type.err_int); } switch (val.tag()) { .@"error" => { @@ -6817,14 +6816,14 @@ fn zirErrorToInt(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat .base = .{ .tag = .int_u64 }, .data = (try sema.mod.getErrorValue(val.castTag(.@"error").?.data.name)).value, }; - return sema.addConstant(result_ty, Value.initPayload(&payload.base)); + return sema.addConstant(Type.err_int, Value.initPayload(&payload.base)); }, // This is not a valid combination with the type `anyerror`. .the_only_possible_value => unreachable, // Assume it's already encoded as an integer. - else => return sema.addConstant(result_ty, val), + else => return sema.addConstant(Type.err_int, val), } } @@ -6833,14 +6832,14 @@ fn zirErrorToInt(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat if (!op_ty.isAnyError()) { const names = op_ty.errorSetNames(); switch (names.len) { - 0 => return sema.addConstant(result_ty, Value.zero), - 1 => return sema.addIntUnsigned(result_ty, sema.mod.global_error_set.get(names[0]).?), + 0 => return sema.addConstant(Type.err_int, Value.zero), + 1 => return sema.addIntUnsigned(Type.err_int, sema.mod.global_error_set.get(names[0]).?), else => {}, } } try sema.requireRuntimeBlock(block, src, operand_src); - return block.addBitCast(result_ty, operand); + return block.addBitCast(Type.err_int, operand); } fn zirIntToError(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref { @@ -6851,7 +6850,7 @@ fn zirIntToError(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat const src = LazySrcLoc.nodeOffset(extra.node); const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node }; const uncasted_operand = try sema.resolveInst(extra.operand); - const operand = try sema.coerce(block, Type.u16, uncasted_operand, operand_src); + const operand = try sema.coerce(block, Type.err_int, uncasted_operand, operand_src); const target = sema.mod.getTarget(); if (try sema.resolveDefinedValue(block, operand_src, operand)) |value| { @@ -6868,7 +6867,10 @@ fn zirIntToError(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat try sema.requireRuntimeBlock(block, src, operand_src); if (block.wantSafety()) { const is_lt_len = try block.addUnOp(.cmp_lt_errors_len, operand); - try sema.addSafetyCheck(block, is_lt_len, .invalid_error_code); + const zero_val = try sema.addConstant(Type.err_int, Value.zero); + const is_non_zero = try block.addBinOp(.cmp_neq, operand, zero_val); + const ok = try block.addBinOp(.bit_and, is_lt_len, is_non_zero); + try sema.addSafetyCheck(block, ok, .invalid_error_code); } return block.addInst(.{ .tag = .bitcast, @@ -17360,7 +17362,7 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat try sema.requireRuntimeBlock(block, src, operand_src); if (block.wantSafety() and !dest_ty.isAnyError() and sema.mod.comp.bin_file.options.use_llvm) { - const err_int_inst = try block.addBitCast(Type.u16, operand); + const err_int_inst = try block.addBitCast(Type.err_int, operand); const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst); try sema.addSafetyCheck(block, ok, .invalid_error_code); } diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 5da3e7e327..0586c99432 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -8005,7 +8005,7 @@ pub const FuncGen = struct { .data = err_int, }; break :int try self.dg.lowerValue(.{ - .ty = Type.u16, + .ty = Type.err_int, .val = Value.initPayload(&tag_val_payload.base), }); }; diff --git a/src/type.zig b/src/type.zig index 1b71f4e9b1..582ea230ef 100644 --- a/src/type.zig +++ b/src/type.zig @@ -6303,6 +6303,8 @@ pub const Type = extern union { pub const @"anyopaque" = initTag(.anyopaque); pub const @"null" = initTag(.@"null"); + pub const err_int = Type.u16; + pub fn ptr(arena: Allocator, mod: *Module, data: Payload.Pointer.Data) !Type { const target = mod.getTarget(); diff --git a/test/cases/safety/zero casted to error.zig b/test/cases/safety/zero casted to error.zig new file mode 100644 index 0000000000..3a2edf834a --- /dev/null +++ b/test/cases/safety/zero casted to error.zig @@ -0,0 +1,19 @@ +const std = @import("std"); + +pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn { + _ = stack_trace; + if (std.mem.eql(u8, message, "invalid error code")) { + std.process.exit(0); + } + std.process.exit(1); +} +pub fn main() !void { + bar(0) catch {}; + return error.TestFailed; +} +fn bar(x: u16) anyerror { + return @intToError(x); +} +// run +// backend=llvm +// target=native -- cgit v1.2.3