diff options
| author | Travis Staloch <1562827+travisstaloch@users.noreply.github.com> | 2024-04-20 23:14:39 -0700 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2024-04-22 15:31:41 -0700 |
| commit | 8af59d1f98266bd70b3afb44d196bbd151cedf22 (patch) | |
| tree | 64b0c48f2b2d222629acbd5698c1f5310fdc708f /lib/std/static_string_map.zig | |
| parent | fefdbca6e62145a20777789961262f15c2bf6cbe (diff) | |
| download | zig-8af59d1f98266bd70b3afb44d196bbd151cedf22.tar.gz zig-8af59d1f98266bd70b3afb44d196bbd151cedf22.zip | |
ComptimeStringMap: return a regular struct and optimize
this patch renames ComptimeStringMap to StaticStringMap, makes it
accept only a single type parameter, and return a known struct type
instead of an anonymous struct. initial motivation for these changes
was to reduce the 'very long type names' issue described here
https://github.com/ziglang/zig/pull/19682.
this breaks the previous API. users will now need to write:
`const map = std.StaticStringMap(T).initComptime(kvs_list);`
* move `kvs_list` param from type param to an `initComptime()` param
* new public methods
* `keys()`, `values()` helpers
* `init(allocator)`, `deinit(allocator)` for runtime data
* `getLongestPrefix(str)`, `getLongestPrefixIndex(str)` - i'm not sure
these belong but have left in for now incase they are deemed useful
* performance notes:
* i posted some benchmarking results here:
https://github.com/travisstaloch/comptime-string-map-revised/issues/1
* i noticed a speedup reducing the size of the struct from 48 to 32
bytes and thus use u32s instead of usize for all length fields
* i noticed speedup storing KVs as a struct of arrays
* latest benchmark shows these wall_time improvements for
debug/safe/small/fast builds: -6.6% / -10.2% / -19.1% / -8.9%. full
output in link above.
Diffstat (limited to 'lib/std/static_string_map.zig')
| -rw-r--r-- | lib/std/static_string_map.zig | 502 |
1 files changed, 502 insertions, 0 deletions
diff --git a/lib/std/static_string_map.zig b/lib/std/static_string_map.zig new file mode 100644 index 0000000000..1136bde587 --- /dev/null +++ b/lib/std/static_string_map.zig @@ -0,0 +1,502 @@ +const std = @import("std.zig"); +const mem = std.mem; + +/// Static string map optimized for small sets of disparate string keys. +/// Works by separating the keys by length at initialization and only checking +/// strings of equal length at runtime. +pub fn StaticStringMap(comptime V: type) type { + return StaticStringMapWithEql(V, defaultEql); +} + +/// Like `std.mem.eql`, but takes advantage of the fact that the lengths +/// of `a` and `b` are known to be equal. +pub fn defaultEql(a: []const u8, b: []const u8) bool { + if (a.ptr == b.ptr) return true; + for (a, b) |a_elem, b_elem| { + if (a_elem != b_elem) return false; + } + return true; +} + +/// Like `std.ascii.eqlIgnoreCase` but takes advantage of the fact that +/// the lengths of `a` and `b` are known to be equal. +pub fn eqlAsciiIgnoreCase(a: []const u8, b: []const u8) bool { + if (a.ptr == b.ptr) return true; + for (a, b) |a_c, b_c| { + if (std.ascii.toLower(a_c) != std.ascii.toLower(b_c)) return false; + } + return true; +} + +/// StaticStringMap, but accepts an equality function (`eql`). +/// The `eql` function is only called to determine the equality +/// of equal length strings. Any strings that are not equal length +/// are never compared using the `eql` function. +pub fn StaticStringMapWithEql( + comptime V: type, + comptime eql: fn (a: []const u8, b: []const u8) bool, +) type { + return struct { + kvs: *const KVs = &empty_kvs, + len_indexes: [*]const u32 = &empty_len_indexes, + len_indexes_len: u32 = 0, + min_len: u32 = std.math.maxInt(u32), + max_len: u32 = 0, + + pub const KV = struct { + key: []const u8, + value: V, + }; + + const Self = @This(); + const KVs = struct { + keys: [*]const []const u8, + values: [*]const V, + len: u32, + }; + const empty_kvs = KVs{ + .keys = &empty_keys, + .values = &empty_vals, + .len = 0, + }; + const empty_len_indexes = [0]u32{}; + const empty_keys = [0][]const u8{}; + const empty_vals = [0]V{}; + + /// Returns a map backed by static, comptime allocated memory. + /// + /// `kvs_list` must be either a list of `struct { []const u8, V }` + /// (key-value pair) tuples, or a list of `struct { []const u8 }` + /// (only keys) tuples if `V` is `void`. + pub inline fn initComptime(comptime kvs_list: anytype) Self { + comptime { + @setEvalBranchQuota(30 * kvs_list.len); + var self = Self{}; + if (kvs_list.len == 0) + return self; + + var sorted_keys: [kvs_list.len][]const u8 = undefined; + var sorted_vals: [kvs_list.len]V = undefined; + + self.initSortedKVs(kvs_list, &sorted_keys, &sorted_vals); + const final_keys = sorted_keys; + const final_vals = sorted_vals; + self.kvs = &.{ + .keys = &final_keys, + .values = &final_vals, + .len = @intCast(kvs_list.len), + }; + + var len_indexes: [self.max_len + 1]u32 = undefined; + self.initLenIndexes(&len_indexes); + const final_len_indexes = len_indexes; + self.len_indexes = &final_len_indexes; + self.len_indexes_len = @intCast(len_indexes.len); + return self; + } + } + + /// Returns a map backed by memory allocated with `allocator`. + /// + /// Handles `kvs_list` the same way as `initComptime()`. + pub fn init(kvs_list: anytype, allocator: mem.Allocator) !Self { + var self = Self{}; + if (kvs_list.len == 0) + return self; + + const sorted_keys = try allocator.alloc([]const u8, kvs_list.len); + errdefer allocator.free(sorted_keys); + const sorted_vals = try allocator.alloc(V, kvs_list.len); + errdefer allocator.free(sorted_vals); + const kvs = try allocator.create(KVs); + errdefer allocator.destroy(kvs); + + self.initSortedKVs(kvs_list, sorted_keys, sorted_vals); + kvs.* = .{ + .keys = sorted_keys.ptr, + .values = sorted_vals.ptr, + .len = kvs_list.len, + }; + self.kvs = kvs; + + const len_indexes = try allocator.alloc(u32, self.max_len + 1); + self.initLenIndexes(len_indexes); + self.len_indexes = len_indexes.ptr; + self.len_indexes_len = @intCast(len_indexes.len); + return self; + } + + /// this method should only be used with init() and not with initComptime(). + pub fn deinit(self: Self, allocator: mem.Allocator) void { + allocator.free(self.len_indexes[0..self.len_indexes_len]); + allocator.free(self.kvs.keys[0..self.kvs.len]); + allocator.free(self.kvs.values[0..self.kvs.len]); + allocator.destroy(self.kvs); + } + + const SortContext = struct { + keys: [][]const u8, + vals: []V, + + pub fn lessThan(ctx: @This(), a: usize, b: usize) bool { + return ctx.keys[a].len < ctx.keys[b].len; + } + + pub fn swap(ctx: @This(), a: usize, b: usize) void { + std.mem.swap([]const u8, &ctx.keys[a], &ctx.keys[b]); + std.mem.swap(V, &ctx.vals[a], &ctx.vals[b]); + } + }; + + fn initSortedKVs( + self: *Self, + kvs_list: anytype, + sorted_keys: [][]const u8, + sorted_vals: []V, + ) void { + for (kvs_list, 0..) |kv, i| { + sorted_keys[i] = kv.@"0"; + sorted_vals[i] = if (V == void) {} else kv.@"1"; + self.min_len = @intCast(@min(self.min_len, kv.@"0".len)); + self.max_len = @intCast(@max(self.max_len, kv.@"0".len)); + } + mem.sortUnstableContext(0, sorted_keys.len, SortContext{ + .keys = sorted_keys, + .vals = sorted_vals, + }); + } + + fn initLenIndexes(self: Self, len_indexes: []u32) void { + var len: usize = 0; + var i: u32 = 0; + while (len <= self.max_len) : (len += 1) { + // find the first keyword len == len + while (len > self.kvs.keys[i].len) { + i += 1; + } + len_indexes[len] = i; + } + } + + /// Checks if the map has a value for the key. + pub fn has(self: Self, str: []const u8) bool { + return self.get(str) != null; + } + + /// Returns the value for the key if any, else null. + pub fn get(self: Self, str: []const u8) ?V { + if (self.kvs.len == 0) + return null; + + return self.kvs.values[self.getIndex(str) orelse return null]; + } + + pub fn getIndex(self: Self, str: []const u8) ?usize { + const kvs = self.kvs.*; + if (kvs.len == 0) + return null; + + if (str.len < self.min_len or str.len > self.max_len) + return null; + + var i = self.len_indexes[str.len]; + while (true) { + const key = kvs.keys[i]; + if (key.len != str.len) + return null; + if (eql(key, str)) + return i; + i += 1; + if (i >= kvs.len) + return null; + } + } + + /// Returns the longest key, value pair where key is a prefix of `str` + /// else null. + pub fn getLongestPrefix(self: Self, str: []const u8) ?KV { + if (self.kvs.len == 0) + return null; + const i = self.getLongestPrefixIndex(str) orelse return null; + const kvs = self.kvs.*; + return .{ + .key = kvs.keys[i], + .value = kvs.values[i], + }; + } + + pub fn getLongestPrefixIndex(self: Self, str: []const u8) ?usize { + if (self.kvs.len == 0) + return null; + + if (str.len < self.min_len) + return null; + + var len = @min(self.max_len, str.len); + while (len >= self.min_len) : (len -= 1) { + if (self.getIndex(str[0..len])) |i| + return i; + } + return null; + } + + pub fn keys(self: Self) []const []const u8 { + const kvs = self.kvs.*; + return kvs.keys[0..kvs.len]; + } + + pub fn values(self: Self) []const V { + const kvs = self.kvs.*; + return kvs.values[0..kvs.len]; + } + }; +} + +const TestEnum = enum { A, B, C, D, E }; +const TestMap = StaticStringMap(TestEnum); +const TestKV = struct { []const u8, TestEnum }; +const TestMapVoid = StaticStringMap(void); +const TestKVVoid = struct { []const u8 }; +const TestMapWithEql = StaticStringMapWithEql(TestEnum, eqlAsciiIgnoreCase); +const testing = std.testing; +const test_alloc = testing.allocator; + +test "list literal of list literals" { + const slice = [_]TestKV{ + .{ "these", .D }, + .{ "have", .A }, + .{ "nothing", .B }, + .{ "incommon", .C }, + .{ "samelen", .E }, + }; + const map = TestMap.initComptime(slice); + try testMap(map); + // Default comparison is case sensitive + try testing.expect(null == map.get("NOTHING")); + + // runtime init(), deinit() + const map_rt = try TestMap.init(slice, test_alloc); + defer map_rt.deinit(test_alloc); + try testMap(map_rt); + // Default comparison is case sensitive + try testing.expect(null == map_rt.get("NOTHING")); +} + +test "array of structs" { + const slice = [_]TestKV{ + .{ "these", .D }, + .{ "have", .A }, + .{ "nothing", .B }, + .{ "incommon", .C }, + .{ "samelen", .E }, + }; + + try testMap(TestMap.initComptime(slice)); +} + +test "slice of structs" { + const slice = [_]TestKV{ + .{ "these", .D }, + .{ "have", .A }, + .{ "nothing", .B }, + .{ "incommon", .C }, + .{ "samelen", .E }, + }; + + try testMap(TestMap.initComptime(slice)); +} + +fn testMap(map: anytype) !void { + try testing.expectEqual(TestEnum.A, map.get("have").?); + try testing.expectEqual(TestEnum.B, map.get("nothing").?); + try testing.expect(null == map.get("missing")); + try testing.expectEqual(TestEnum.D, map.get("these").?); + try testing.expectEqual(TestEnum.E, map.get("samelen").?); + + try testing.expect(!map.has("missing")); + try testing.expect(map.has("these")); + + try testing.expect(null == map.get("")); + try testing.expect(null == map.get("averylongstringthathasnomatches")); +} + +test "void value type, slice of structs" { + const slice = [_]TestKVVoid{ + .{"these"}, + .{"have"}, + .{"nothing"}, + .{"incommon"}, + .{"samelen"}, + }; + const map = TestMapVoid.initComptime(slice); + try testSet(map); + // Default comparison is case sensitive + try testing.expect(null == map.get("NOTHING")); +} + +test "void value type, list literal of list literals" { + const slice = [_]TestKVVoid{ + .{"these"}, + .{"have"}, + .{"nothing"}, + .{"incommon"}, + .{"samelen"}, + }; + + try testSet(TestMapVoid.initComptime(slice)); +} + +fn testSet(map: TestMapVoid) !void { + try testing.expectEqual({}, map.get("have").?); + try testing.expectEqual({}, map.get("nothing").?); + try testing.expect(null == map.get("missing")); + try testing.expectEqual({}, map.get("these").?); + try testing.expectEqual({}, map.get("samelen").?); + + try testing.expect(!map.has("missing")); + try testing.expect(map.has("these")); + + try testing.expect(null == map.get("")); + try testing.expect(null == map.get("averylongstringthathasnomatches")); +} + +fn testStaticStringMapWithEql(map: TestMapWithEql) !void { + try testMap(map); + try testing.expectEqual(TestEnum.A, map.get("HAVE").?); + try testing.expectEqual(TestEnum.E, map.get("SameLen").?); + try testing.expect(null == map.get("SameLength")); + try testing.expect(map.has("ThESe")); +} + +test "StaticStringMapWithEql" { + const slice = [_]TestKV{ + .{ "these", .D }, + .{ "have", .A }, + .{ "nothing", .B }, + .{ "incommon", .C }, + .{ "samelen", .E }, + }; + + try testStaticStringMapWithEql(TestMapWithEql.initComptime(slice)); +} + +test "empty" { + const m1 = StaticStringMap(usize).initComptime(.{}); + try testing.expect(null == m1.get("anything")); + + const m2 = StaticStringMapWithEql(usize, eqlAsciiIgnoreCase).initComptime(.{}); + try testing.expect(null == m2.get("anything")); + + const m3 = try StaticStringMap(usize).init(.{}, test_alloc); + try testing.expect(null == m3.get("anything")); + + const m4 = try StaticStringMapWithEql(usize, eqlAsciiIgnoreCase).init(.{}, test_alloc); + try testing.expect(null == m4.get("anything")); +} + +test "redundant entries" { + const slice = [_]TestKV{ + .{ "redundant", .D }, + .{ "theNeedle", .A }, + .{ "redundant", .B }, + .{ "re" ++ "dundant", .C }, + .{ "redun" ++ "dant", .E }, + }; + const map = TestMap.initComptime(slice); + + // No promises about which one you get: + try testing.expect(null != map.get("redundant")); + + // Default map is not case sensitive: + try testing.expect(null == map.get("REDUNDANT")); + + try testing.expectEqual(TestEnum.A, map.get("theNeedle").?); +} + +test "redundant insensitive" { + const slice = [_]TestKV{ + .{ "redundant", .D }, + .{ "theNeedle", .A }, + .{ "redundanT", .B }, + .{ "RE" ++ "dundant", .C }, + .{ "redun" ++ "DANT", .E }, + }; + + const map = TestMapWithEql.initComptime(slice); + + // No promises about which result you'll get ... + try testing.expect(null != map.get("REDUNDANT")); + try testing.expect(null != map.get("ReDuNdAnT")); + try testing.expectEqual(TestEnum.A, map.get("theNeedle").?); +} + +test "comptime-only value" { + const map = StaticStringMap(type).initComptime(.{ + .{ "a", struct { + pub const foo = 1; + } }, + .{ "b", struct { + pub const foo = 2; + } }, + .{ "c", struct { + pub const foo = 3; + } }, + }); + + try testing.expect(map.get("a").?.foo == 1); + try testing.expect(map.get("b").?.foo == 2); + try testing.expect(map.get("c").?.foo == 3); + try testing.expect(map.get("d") == null); +} + +test "getLongestPrefix" { + const slice = [_]TestKV{ + .{ "a", .A }, + .{ "aa", .B }, + .{ "aaa", .C }, + .{ "aaaa", .D }, + }; + + const map = TestMap.initComptime(slice); + + try testing.expectEqual(null, map.getLongestPrefix("")); + try testing.expectEqual(null, map.getLongestPrefix("bar")); + try testing.expectEqualStrings("aaaa", map.getLongestPrefix("aaaabar").?.key); + try testing.expectEqualStrings("aaa", map.getLongestPrefix("aaabar").?.key); +} + +test "getLongestPrefix2" { + const slice = [_]struct { []const u8, u8 }{ + .{ "one", 1 }, + .{ "two", 2 }, + .{ "three", 3 }, + .{ "four", 4 }, + .{ "five", 5 }, + .{ "six", 6 }, + .{ "seven", 7 }, + .{ "eight", 8 }, + .{ "nine", 9 }, + }; + const map = StaticStringMap(u8).initComptime(slice); + + try testing.expectEqual(1, map.get("one")); + try testing.expectEqual(null, map.get("o")); + try testing.expectEqual(null, map.get("onexxx")); + try testing.expectEqual(9, map.get("nine")); + try testing.expectEqual(null, map.get("n")); + try testing.expectEqual(null, map.get("ninexxx")); + try testing.expectEqual(null, map.get("xxx")); + + try testing.expectEqual(1, map.getLongestPrefix("one").?.value); + try testing.expectEqual(1, map.getLongestPrefix("onexxx").?.value); + try testing.expectEqual(null, map.getLongestPrefix("o")); + try testing.expectEqual(null, map.getLongestPrefix("on")); + try testing.expectEqual(9, map.getLongestPrefix("nine").?.value); + try testing.expectEqual(9, map.getLongestPrefix("ninexxx").?.value); + try testing.expectEqual(null, map.getLongestPrefix("n")); + try testing.expectEqual(null, map.getLongestPrefix("xxx")); +} + +test "long kvs_list doesn't exceed @setEvalBranchQuota" { + _ = TestMapVoid.initComptime([1]TestKVVoid{.{"x"}} ** 1_000); +} |
