diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2022-08-17 12:55:08 -0700 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2022-08-17 13:02:57 -0700 |
| commit | c764640e92c9e4d32b89650ac774bebf1498be92 (patch) | |
| tree | 5b17ee88da0d6499007a36a3a4b701c5d8f44a81 /src | |
| parent | a12abc6d6c8b89a09befdcbd9019247ccc3bd641 (diff) | |
| download | zig-c764640e92c9e4d32b89650ac774bebf1498be92.tar.gz zig-c764640e92c9e4d32b89650ac774bebf1498be92.zip | |
Sema: fix generics with struct literal coerced to tagged union
The `Value.eql` function has to test for value equality *as-if* the lhs
value parameter is coerced into the type of the rhs. For tagged unions,
there was a problematic case when the lhs was an anonymous struct,
because in such case the value is empty_struct_value and the type
contains all the value information. But the only type available in the
function was the rhs type.
So the fix involved making `Value.eqlAdvanced` also accept the lhs type,
and then enhancing the logic to handle the case of the `.anon_struct` tag.
closes #12418
Tests run locally:
* test-behavior
* test-cases
Diffstat (limited to 'src')
| -rw-r--r-- | src/Sema.zig | 56 | ||||
| -rw-r--r-- | src/value.zig | 77 |
2 files changed, 91 insertions, 42 deletions
diff --git a/src/Sema.zig b/src/Sema.zig index d7d6994bcd..f1d140520c 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1495,7 +1495,8 @@ pub fn resolveInst(sema: *Sema, zir_ref: Zir.Inst.Ref) !Air.Inst.Ref { // Finally, the last section of indexes refers to the map of ZIR=>AIR. const inst = sema.inst_map.get(@intCast(u32, i)).?; - if (sema.typeOf(inst).tag() == .generic_poison) return error.GenericPoison; + const ty = sema.typeOf(inst); + if (ty.tag() == .generic_poison) return error.GenericPoison; return inst; } @@ -5570,11 +5571,15 @@ const GenericCallAdapter = struct { generic_fn: *Module.Fn, precomputed_hash: u64, func_ty_info: Type.Payload.Function.Data, - /// Unlike comptime_args, the Type here is not always present. - /// .generic_poison is used to communicate non-anytype parameters. - comptime_tvs: []const TypedValue, + args: []const Arg, module: *Module, + const Arg = struct { + ty: Type, + val: Value, + is_anytype: bool, + }; + pub fn eql(ctx: @This(), adapted_key: void, other_key: *Module.Fn) bool { _ = adapted_key; // The generic function Decl is guaranteed to be the first dependency @@ -5585,10 +5590,10 @@ const GenericCallAdapter = struct { const other_comptime_args = other_key.comptime_args.?; for (other_comptime_args[0..ctx.func_ty_info.param_types.len]) |other_arg, i| { - const this_arg = ctx.comptime_tvs[i]; + const this_arg = ctx.args[i]; const this_is_comptime = this_arg.val.tag() != .generic_poison; const other_is_comptime = other_arg.val.tag() != .generic_poison; - const this_is_anytype = this_arg.ty.tag() != .generic_poison; + const this_is_anytype = this_arg.is_anytype; const other_is_anytype = other_key.isAnytypeParam(ctx.module, @intCast(u32, i)); if (other_is_anytype != this_is_anytype) return false; @@ -5607,7 +5612,17 @@ const GenericCallAdapter = struct { } } else if (this_is_comptime) { // Both are comptime parameters but not anytype parameters. - if (!this_arg.val.eql(other_arg.val, other_arg.ty, ctx.module)) { + // We assert no error is possible here because any lazy values must be resolved + // before inserting into the generic function hash map. + const is_eql = Value.eqlAdvanced( + this_arg.val, + this_arg.ty, + other_arg.val, + other_arg.ty, + ctx.module, + null, + ) catch unreachable; + if (!is_eql) { return false; } } @@ -6258,8 +6273,7 @@ fn instantiateGenericCall( var hasher = std.hash.Wyhash.init(0); std.hash.autoHash(&hasher, @ptrToInt(module_fn)); - const comptime_tvs = try sema.arena.alloc(TypedValue, func_ty_info.param_types.len); - + const generic_args = try sema.arena.alloc(GenericCallAdapter.Arg, func_ty_info.param_types.len); { var i: usize = 0; for (fn_info.param_body) |inst| { @@ -6283,8 +6297,9 @@ fn instantiateGenericCall( else => continue, } + const arg_ty = sema.typeOf(uncasted_args[i]); + if (is_comptime) { - const arg_ty = sema.typeOf(uncasted_args[i]); const arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, uncasted_args[i]) catch |err| switch (err) { error.NeededSourceLocation => { const decl = sema.mod.declPtr(block.src_decl); @@ -6297,27 +6312,30 @@ fn instantiateGenericCall( arg_val.hash(arg_ty, &hasher, mod); if (is_anytype) { arg_ty.hashWithHasher(&hasher, mod); - comptime_tvs[i] = .{ + generic_args[i] = .{ .ty = arg_ty, .val = arg_val, + .is_anytype = true, }; } else { - comptime_tvs[i] = .{ - .ty = Type.initTag(.generic_poison), + generic_args[i] = .{ + .ty = arg_ty, .val = arg_val, + .is_anytype = false, }; } } else if (is_anytype) { - const arg_ty = sema.typeOf(uncasted_args[i]); arg_ty.hashWithHasher(&hasher, mod); - comptime_tvs[i] = .{ + generic_args[i] = .{ .ty = arg_ty, .val = Value.initTag(.generic_poison), + .is_anytype = true, }; } else { - comptime_tvs[i] = .{ - .ty = Type.initTag(.generic_poison), + generic_args[i] = .{ + .ty = arg_ty, .val = Value.initTag(.generic_poison), + .is_anytype = false, }; } @@ -6331,7 +6349,7 @@ fn instantiateGenericCall( .generic_fn = module_fn, .precomputed_hash = precomputed_hash, .func_ty_info = func_ty_info, - .comptime_tvs = comptime_tvs, + .args = generic_args, .module = mod, }; const gop = try mod.monomorphed_funcs.getOrPutAdapted(gpa, {}, adapter); @@ -30124,7 +30142,7 @@ fn valuesEqual( rhs: Value, ty: Type, ) CompileError!bool { - return Value.eqlAdvanced(lhs, rhs, ty, sema.mod, sema.kit(block, src)); + return Value.eqlAdvanced(lhs, ty, rhs, ty, sema.mod, sema.kit(block, src)); } /// Asserts the values are comparable vectors of type `ty`. diff --git a/src/value.zig b/src/value.zig index 677a459afe..9909cab5ce 100644 --- a/src/value.zig +++ b/src/value.zig @@ -2004,6 +2004,10 @@ pub const Value = extern union { return (try orderAgainstZeroAdvanced(lhs, sema_kit)).compare(op); } + pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool { + return eqlAdvanced(a, ty, b, ty, mod, null) catch unreachable; + } + /// This function is used by hash maps and so treats floating-point NaNs as equal /// to each other, and not equal to other floating-point values. /// Similarly, it treats `undef` as a distinct value from all other values. @@ -2012,13 +2016,10 @@ pub const Value = extern union { /// for `a`. This function must act *as if* `a` has been coerced to `ty`. This complication /// is required in order to make generic function instantiation efficient - specifically /// the insertion into the monomorphized function table. - pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool { - return eqlAdvanced(a, b, ty, mod, null) catch unreachable; - } - /// If `null` is provided for `sema_kit` then it is guaranteed no error will be returned. pub fn eqlAdvanced( a: Value, + a_ty: Type, b: Value, ty: Type, mod: *Module, @@ -2044,33 +2045,34 @@ pub const Value = extern union { const a_payload = a.castTag(.opt_payload).?.data; const b_payload = b.castTag(.opt_payload).?.data; var buffer: Type.Payload.ElemType = undefined; - return eqlAdvanced(a_payload, b_payload, ty.optionalChild(&buffer), mod, sema_kit); + const payload_ty = ty.optionalChild(&buffer); + return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit); }, .slice => { const a_payload = a.castTag(.slice).?.data; const b_payload = b.castTag(.slice).?.data; - if (!(try eqlAdvanced(a_payload.len, b_payload.len, Type.usize, mod, sema_kit))) { + if (!(try eqlAdvanced(a_payload.len, Type.usize, b_payload.len, Type.usize, mod, sema_kit))) { return false; } var ptr_buf: Type.SlicePtrFieldTypeBuffer = undefined; const ptr_ty = ty.slicePtrFieldType(&ptr_buf); - return eqlAdvanced(a_payload.ptr, b_payload.ptr, ptr_ty, mod, sema_kit); + return eqlAdvanced(a_payload.ptr, ptr_ty, b_payload.ptr, ptr_ty, mod, sema_kit); }, .elem_ptr => { const a_payload = a.castTag(.elem_ptr).?.data; const b_payload = b.castTag(.elem_ptr).?.data; if (a_payload.index != b_payload.index) return false; - return eqlAdvanced(a_payload.array_ptr, b_payload.array_ptr, ty, mod, sema_kit); + return eqlAdvanced(a_payload.array_ptr, ty, b_payload.array_ptr, ty, mod, sema_kit); }, .field_ptr => { const a_payload = a.castTag(.field_ptr).?.data; const b_payload = b.castTag(.field_ptr).?.data; if (a_payload.field_index != b_payload.field_index) return false; - return eqlAdvanced(a_payload.container_ptr, b_payload.container_ptr, ty, mod, sema_kit); + return eqlAdvanced(a_payload.container_ptr, ty, b_payload.container_ptr, ty, mod, sema_kit); }, .@"error" => { const a_name = a.castTag(.@"error").?.data.name; @@ -2080,7 +2082,8 @@ pub const Value = extern union { .eu_payload => { const a_payload = a.castTag(.eu_payload).?.data; const b_payload = b.castTag(.eu_payload).?.data; - return eqlAdvanced(a_payload, b_payload, ty.errorUnionPayload(), mod, sema_kit); + const payload_ty = ty.errorUnionPayload(); + return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit); }, .eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"), .opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"), @@ -2098,7 +2101,7 @@ pub const Value = extern union { const types = ty.tupleFields().types; assert(types.len == a_field_vals.len); for (types) |field_ty, i| { - if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field_ty, mod, sema_kit))) { + if (!(try eqlAdvanced(a_field_vals[i], field_ty, b_field_vals[i], field_ty, mod, sema_kit))) { return false; } } @@ -2109,7 +2112,7 @@ pub const Value = extern union { const fields = ty.structFields().values(); assert(fields.len == a_field_vals.len); for (fields) |field, i| { - if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field.ty, mod, sema_kit))) { + if (!(try eqlAdvanced(a_field_vals[i], field.ty, b_field_vals[i], field.ty, mod, sema_kit))) { return false; } } @@ -2120,7 +2123,7 @@ pub const Value = extern union { for (a_field_vals) |a_elem, i| { const b_elem = b_field_vals[i]; - if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) { + if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) { return false; } } @@ -2132,7 +2135,7 @@ pub const Value = extern union { switch (ty.containerLayout()) { .Packed, .Extern => { const tag_ty = ty.unionTagTypeHypothetical(); - if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) { + if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) { // In this case, we must disregard mismatching tags and compare // based on the in-memory bytes of the payloads. @panic("TODO comptime comparison of extern union values with mismatching tags"); @@ -2140,13 +2143,13 @@ pub const Value = extern union { }, .Auto => { const tag_ty = ty.unionTagTypeHypothetical(); - if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) { + if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) { return false; } }, } const active_field_ty = ty.unionFieldType(a_union.tag, mod); - return a_union.val.eqlAdvanced(b_union.val, active_field_ty, mod, sema_kit); + return eqlAdvanced(a_union.val, active_field_ty, b_union.val, active_field_ty, mod, sema_kit); }, else => {}, } else if (a_tag == .null_value or b_tag == .null_value) { @@ -2180,7 +2183,7 @@ pub const Value = extern union { const b_val = b.enumToInt(ty, &buf_b); var buf_ty: Type.Payload.Bits = undefined; const int_ty = ty.intTagType(&buf_ty); - return eqlAdvanced(a_val, b_val, int_ty, mod, sema_kit); + return eqlAdvanced(a_val, int_ty, b_val, int_ty, mod, sema_kit); }, .Array, .Vector => { const len = ty.arrayLen(); @@ -2191,17 +2194,44 @@ pub const Value = extern union { while (i < len) : (i += 1) { const a_elem = elemValueBuffer(a, mod, i, &a_buf); const b_elem = elemValueBuffer(b, mod, i, &b_buf); - if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) { + if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) { return false; } } return true; }, .Struct => { - // A tuple can be represented with .empty_struct_value, - // the_one_possible_value, .aggregate in which case we could - // end up here and the values are equal if the type has zero fields. - return ty.isTupleOrAnonStruct() and ty.structFieldCount() != 0; + // A struct can be represented with one of: + // .empty_struct_value, + // .the_one_possible_value, + // .aggregate, + // Note that we already checked above for matching tags, e.g. both .aggregate. + return ty.onePossibleValue() != null; + }, + .Union => { + // Here we have to check for value equality, as-if `a` has been coerced to `ty`. + if (ty.onePossibleValue() != null) { + return true; + } + if (a_ty.castTag(.anon_struct)) |payload| { + const tuple = payload.data; + if (tuple.values.len != 1) { + return false; + } + const field_name = tuple.names[0]; + const union_obj = ty.cast(Type.Payload.Union).?.data; + const field_index = union_obj.fields.getIndex(field_name) orelse return false; + const tag_and_val = b.castTag(.@"union").?.data; + var field_tag_buf: Value.Payload.U32 = .{ + .base = .{ .tag = .enum_field_index }, + .data = @intCast(u32, field_index), + }; + const field_tag = Value.initPayload(&field_tag_buf.base); + const tag_matches = tag_and_val.tag.eql(field_tag, union_obj.tag_ty, mod); + if (!tag_matches) return false; + return eqlAdvanced(tag_and_val.val, union_obj.tag_ty, tuple.values[0], tuple.types[0], mod, sema_kit); + } + return false; }, .Float => { switch (ty.floatBits(target)) { @@ -2230,7 +2260,8 @@ pub const Value = extern union { .base = .{ .tag = .opt_payload }, .data = a, }; - return eqlAdvanced(Value.initPayload(&buffer.base), b, ty, mod, sema_kit); + const opt_val = Value.initPayload(&buffer.base); + return eqlAdvanced(opt_val, ty, b, ty, mod, sema_kit); } }, else => {}, |
