From d8c3c11c6c5288118b3529c6a99810a4c4add9e9 Mon Sep 17 00:00:00 2001 From: Jiacai Liu Date: Mon, 30 Jan 2023 06:00:14 +0800 Subject: std: add expectEqualDeep (#13995) --- lib/std/testing.zig | 246 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) (limited to 'lib/std/testing.zig') diff --git a/lib/std/testing.zig b/lib/std/testing.zig index 53877967c9..895a9a0973 100644 --- a/lib/std/testing.zig +++ b/lib/std/testing.zig @@ -670,6 +670,252 @@ pub fn expectStringEndsWith(actual: []const u8, expected_ends_with: []const u8) return error.TestExpectedEndsWith; } +/// This function is intended to be used only in tests. When the two values are not +/// deeply equal, prints diagnostics to stderr to show exactly how they are not equal, +/// then returns a test failure error. +/// `actual` is casted to the type of `expected`. +/// +/// Deeply equal is defined as follows: +/// Primitive types are deeply equal if they are equal using `==` operator. +/// Struct values are deeply equal if their corresponding fields are deeply equal. +/// Container types(like Array/Slice/Vector) deeply equal when their corresponding elements are deeply equal. +/// Pointer values are deeply equal if values they point to are deeply equal. +/// +/// Note: Self-referential structs are not supported (e.g. things like std.SinglyLinkedList) +pub fn expectEqualDeep(expected: anytype, actual: @TypeOf(expected)) !void { + switch (@typeInfo(@TypeOf(actual))) { + .NoReturn, + .Opaque, + .Frame, + .AnyFrame, + => @compileError("value of type " ++ @typeName(@TypeOf(actual)) ++ " encountered"), + + .Undefined, + .Null, + .Void, + => return, + + .Type => { + if (actual != expected) { + std.debug.print("expected type {s}, found type {s}\n", .{ @typeName(expected), @typeName(actual) }); + return error.TestExpectedEqual; + } + }, + + .Bool, + .Int, + .Float, + .ComptimeFloat, + .ComptimeInt, + .EnumLiteral, + .Enum, + .Fn, + .ErrorSet, + => { + if (actual != expected) { + std.debug.print("expected {}, found {}\n", .{ expected, actual }); + return error.TestExpectedEqual; + } + }, + + .Pointer => |pointer| { + switch (pointer.size) { + // We have no idea what is behind those pointers, so the best we can do is `==` check. + .C, .Many => { + if (actual != expected) { + std.debug.print("expected {*}, found {*}\n", .{ expected, actual }); + return error.TestExpectedEqual; + } + }, + .One => { + // Length of those pointers are runtime value, so the best we can do is `==` check. + switch (@typeInfo(pointer.child)) { + .Fn, .Opaque => { + if (actual != expected) { + std.debug.print("expected {*}, found {*}\n", .{ expected, actual }); + return error.TestExpectedEqual; + } + }, + else => try expectEqualDeep(expected.*, actual.*), + } + }, + .Slice => { + if (expected.len != actual.len) { + std.debug.print("Slice len not the same, expected {d}, found {d}\n", .{ expected.len, actual.len }); + return error.TestExpectedEqual; + } + var i: usize = 0; + while (i < expected.len) : (i += 1) { + expectEqualDeep(expected[i], actual[i]) catch |e| { + std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{ + i, expected[i], actual[i], + }); + return e; + }; + } + }, + } + }, + + .Array => |_| { + if (expected.len != actual.len) { + std.debug.print("Array len not the same, expected {d}, found {d}\n", .{ expected.len, actual.len }); + return error.TestExpectedEqual; + } + var i: usize = 0; + while (i < expected.len) : (i += 1) { + expectEqualDeep(expected[i], actual[i]) catch |e| { + std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{ + i, expected[i], actual[i], + }); + return e; + }; + } + }, + + .Vector => |info| { + if (info.len != @typeInfo(@TypeOf(actual)).Vector.len) { + std.debug.print("Vector len not the same, expected {d}, found {d}\n", .{ info.len, @typeInfo(@TypeOf(actual)).Vector.len }); + return error.TestExpectedEqual; + } + var i: usize = 0; + while (i < info.len) : (i += 1) { + expectEqualDeep(expected[i], actual[i]) catch |e| { + std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{ + i, expected[i], actual[i], + }); + return e; + }; + } + }, + + .Struct => |structType| { + inline for (structType.fields) |field| { + expectEqualDeep(@field(expected, field.name), @field(actual, field.name)) catch |e| { + std.debug.print("Field {s} incorrect. expected {any}, found {any}\n", .{ field.name, @field(expected, field.name), @field(actual, field.name) }); + return e; + }; + } + }, + + .Union => |union_info| { + if (union_info.tag_type == null) { + @compileError("Unable to compare untagged union values"); + } + + const Tag = std.meta.Tag(@TypeOf(expected)); + + const expectedTag = @as(Tag, expected); + const actualTag = @as(Tag, actual); + + try expectEqual(expectedTag, actualTag); + + // we only reach this loop if the tags are equal + switch (expected) { + inline else => |val, tag| { + try expectEqualDeep(val, @field(actual, @tagName(tag))); + }, + } + }, + + .Optional => { + if (expected) |expected_payload| { + if (actual) |actual_payload| { + try expectEqualDeep(expected_payload, actual_payload); + } else { + std.debug.print("expected {any}, found null\n", .{expected_payload}); + return error.TestExpectedEqual; + } + } else { + if (actual) |actual_payload| { + std.debug.print("expected null, found {any}\n", .{actual_payload}); + return error.TestExpectedEqual; + } + } + }, + + .ErrorUnion => { + if (expected) |expected_payload| { + if (actual) |actual_payload| { + try expectEqualDeep(expected_payload, actual_payload); + } else |actual_err| { + std.debug.print("expected {any}, found {any}\n", .{ expected_payload, actual_err }); + return error.TestExpectedEqual; + } + } else |expected_err| { + if (actual) |actual_payload| { + std.debug.print("expected {any}, found {any}\n", .{ expected_err, actual_payload }); + return error.TestExpectedEqual; + } else |actual_err| { + try expectEqualDeep(expected_err, actual_err); + } + } + }, + } +} + +test "expectEqualDeep primitive type" { + try expectEqualDeep(1, 1); + try expectEqualDeep(true, true); + try expectEqualDeep(1.5, 1.5); + try expectEqualDeep(u8, u8); + try expectEqualDeep(error.Bad, error.Bad); + + // optional + { + const foo: ?u32 = 1; + const bar: ?u32 = 1; + try expectEqualDeep(foo, bar); + try expectEqualDeep(?u32, ?u32); + } + // function type + { + const fnType = struct { + fn foo() void { + unreachable; + } + }.foo; + try expectEqualDeep(fnType, fnType); + } +} + +test "expectEqualDeep pointer" { + const a = 1; + const b = 1; + try expectEqualDeep(&a, &b); +} + +test "expectEqualDeep composite type" { + try expectEqualDeep("abc", "abc"); + const s1: []const u8 = "abc"; + const s2 = "abcd"; + const s3: []const u8 = s2[0..3]; + try expectEqualDeep(s1, s3); + + const TestStruct = struct { s: []const u8 }; + try expectEqualDeep(TestStruct{ .s = "abc" }, TestStruct{ .s = "abc" }); + try expectEqualDeep([_][]const u8{ "a", "b", "c" }, [_][]const u8{ "a", "b", "c" }); + + // vector + try expectEqualDeep(@splat(4, @as(u32, 4)), @splat(4, @as(u32, 4))); + + // nested array + { + const a = [2][2]f32{ + [_]f32{ 1.0, 0.0 }, + [_]f32{ 0.0, 1.0 }, + }; + + const b = [2][2]f32{ + [_]f32{ 1.0, 0.0 }, + [_]f32{ 0.0, 1.0 }, + }; + + try expectEqualDeep(a, b); + try expectEqualDeep(&a, &b); + } +} + fn printIndicatorLine(source: []const u8, indicator_index: usize) void { const line_begin_index = if (std.mem.lastIndexOfScalar(u8, source[0..indicator_index], '\n')) |line_begin| line_begin + 1 -- cgit v1.2.3