From 569870ca41e73c64d8dc9f1eccfef3529caf2266 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 8 Mar 2022 20:44:58 -0800 Subject: stage2: error_set_merged type equality This implements type equality for error sets. This is done through element-wise error set comparison. Inferred error sets are always distinct types and other error sets are always sorted. See #11022. --- src/Module.zig | 14 ++++++++++++- src/Sema.zig | 4 ++++ src/type.zig | 66 +++++++++++++++++++++++++++++++++++++++------------------- src/value.zig | 10 +++++++++ 4 files changed, 72 insertions(+), 22 deletions(-) (limited to 'src') diff --git a/src/Module.zig b/src/Module.zig index 93e4b87d5b..693cc3b5a0 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -824,7 +824,7 @@ pub const ErrorSet = struct { /// Offset from Decl node index, points to the error set AST node. node_offset: i32, /// The string bytes are stored in the owner Decl arena. - /// They are in the same order they appear in the AST. + /// These must be in sorted order. See sortNames. names: NameMap, pub const NameMap = std.StringArrayHashMapUnmanaged(void); @@ -836,6 +836,18 @@ pub const ErrorSet = struct { .lazy = .{ .node_offset = self.node_offset }, }; } + + /// sort the NameMap. This should be called whenever the map is modified. + /// alloc should be the allocator used for the NameMap data. + pub fn sortNames(names: *NameMap) void { + const Context = struct { + keys: [][]const u8, + pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool { + return std.mem.lessThan(u8, ctx.keys[a_index], ctx.keys[b_index]); + } + }; + names.sort(Context{ .keys = names.keys() }); + } }; pub const RequiresComptime = enum { no, yes, unknown, wip }; diff --git a/src/Sema.zig b/src/Sema.zig index 195a0ef274..f74fa1e0bf 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2212,6 +2212,10 @@ fn zirErrorSetDecl( return sema.fail(block, src, "duplicate error set field {s}", .{name}); } } + + // names must be sorted. + Module.ErrorSet.sortNames(&names); + error_set.* = .{ .owner_decl = new_decl, .node_offset = inst_data.src_node, diff --git a/src/type.zig b/src/type.zig index b0ff9e59e9..9126f1213b 100644 --- a/src/type.zig +++ b/src/type.zig @@ -564,27 +564,30 @@ pub const Type = extern union { => { if (b.zigTypeTag() != .ErrorSet) return false; - // TODO: revisit the language specification for how to evaluate equality - // for error set types. - - if (a.tag() == .anyerror and b.tag() == .anyerror) { - return true; + // inferred error sets are only equal if both are inferred + // and they originate from the exact same function. + if (a.castTag(.error_set_inferred)) |a_pl| { + if (b.castTag(.error_set_inferred)) |b_pl| { + return a_pl.data.func == b_pl.data.func; + } + return false; } - - if (a.tag() == .error_set and b.tag() == .error_set) { - return a.castTag(.error_set).?.data.owner_decl == b.castTag(.error_set).?.data.owner_decl; + if (b.tag() == .error_set_inferred) return false; + + // anyerror matches exactly. + const a_is_any = a.isAnyError(); + const b_is_any = b.isAnyError(); + if (a_is_any or b_is_any) return a_is_any and b_is_any; + + // two resolved sets match if their error set names match. + const a_set = a.errorSetNames(); + const b_set = b.errorSetNames(); + if (a_set.len != b_set.len) return false; + for (b_set) |b_val| { + if (!a.errorSetHasField(b_val)) return false; } - if (a.tag() == .error_set_inferred and b.tag() == .error_set_inferred) { - return a.castTag(.error_set_inferred).?.data == b.castTag(.error_set_inferred).?.data; - } - - if (a.tag() == .error_set_single and b.tag() == .error_set_single) { - const a_data = a.castTag(.error_set_single).?.data; - const b_data = b.castTag(.error_set_single).?.data; - return std.mem.eql(u8, a_data, b_data); - } - return false; + return true; }, .@"opaque" => { @@ -961,12 +964,30 @@ pub const Type = extern union { .error_set, .error_set_single, - .anyerror, - .error_set_inferred, .error_set_merged, => { + // all are treated like an "error set" for hashing + std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet); + std.hash.autoHash(hasher, Tag.error_set); + + const names = ty.errorSetNames(); + std.hash.autoHash(hasher, names.len); + assert(std.sort.isSorted([]const u8, names, u8, std.mem.lessThan)); + for (names) |name| hasher.update(name); + }, + + .anyerror => { + // anyerror is distinct from other error sets std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet); - // TODO implement this after revisiting Type.Eql for error sets + std.hash.autoHash(hasher, Tag.anyerror); + }, + + .error_set_inferred => { + // inferred error sets are compared using their data pointer + const data = ty.castTag(.error_set_inferred).?.data.func; + std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet); + std.hash.autoHash(hasher, Tag.error_set_inferred); + std.hash.autoHash(hasher, data); }, .@"opaque" => { @@ -4365,6 +4386,9 @@ pub const Type = extern union { try names.put(arena, name, {}); } + // names must be sorted + Module.ErrorSet.sortNames(&names); + return try Tag.error_set_merged.create(arena, names); } diff --git a/src/value.zig b/src/value.zig index 121e380bd9..502de64348 100644 --- a/src/value.zig +++ b/src/value.zig @@ -1870,6 +1870,16 @@ pub const Value = extern union { return eql(a_payload.container_ptr, b_payload.container_ptr, ty); }, + .@"error" => { + const a_name = a.castTag(.@"error").?.data.name; + const b_name = b.castTag(.@"error").?.data.name; + return std.mem.eql(u8, a_name, b_name); + }, + .eu_payload => { + const a_payload = a.castTag(.eu_payload).?.data; + const b_payload = b.castTag(.eu_payload).?.data; + return eql(a_payload, b_payload, ty.errorUnionPayload()); + }, .eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"), .opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"), .array => { -- cgit v1.2.3