diff options
| author | Andrew Kelley <superjoe30@gmail.com> | 2018-08-25 21:57:28 -0400 |
|---|---|---|
| committer | Andrew Kelley <superjoe30@gmail.com> | 2018-08-25 21:57:28 -0400 |
| commit | 7109035b78ee05302bbdaadc52013b430a030b69 (patch) | |
| tree | ae6d7202dc75f2c799f5fbcad72ccf8b02a954a4 /std/hash_map.zig | |
| parent | 6cf248ec0824c746fc796905144c8077ccab99cf (diff) | |
| parent | 526338b00fbe1cac19f64832176af3bdf2108a56 (diff) | |
| download | zig-7109035b78ee05302bbdaadc52013b430a030b69.tar.gz zig-7109035b78ee05302bbdaadc52013b430a030b69.zip | |
Merge remote-tracking branch 'origin/master' into llvm7
Diffstat (limited to 'std/hash_map.zig')
| -rw-r--r-- | std/hash_map.zig | 330 |
1 files changed, 273 insertions, 57 deletions
diff --git a/std/hash_map.zig b/std/hash_map.zig index cebd5272c0..9654d612a5 100644 --- a/std/hash_map.zig +++ b/std/hash_map.zig @@ -9,6 +9,10 @@ const builtin = @import("builtin"); const want_modification_safety = builtin.mode != builtin.Mode.ReleaseFast; const debug_u32 = if (want_modification_safety) u32 else void; +pub fn AutoHashMap(comptime K: type, comptime V: type) type { + return HashMap(K, V, getAutoHashFn(K), getAutoEqlFn(K)); +} + pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u32, comptime eql: fn (a: K, b: K) bool) type { return struct { entries: []Entry, @@ -20,13 +24,22 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 const Self = this; - pub const Entry = struct { - used: bool, - distance_from_start_index: usize, + pub const KV = struct { key: K, value: V, }; + const Entry = struct { + used: bool, + distance_from_start_index: usize, + kv: KV, + }; + + pub const GetOrPutResult = struct { + kv: *KV, + found_existing: bool, + }; + pub const Iterator = struct { hm: *const Self, // how many items have we returned @@ -36,7 +49,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 // used to detect concurrent modification initial_modification_count: debug_u32, - pub fn next(it: *Iterator) ?*Entry { + pub fn next(it: *Iterator) ?*KV { if (want_modification_safety) { assert(it.initial_modification_count == it.hm.modification_count); // concurrent modification } @@ -46,7 +59,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 if (entry.used) { it.index += 1; it.count += 1; - return entry; + return &entry.kv; } } unreachable; // no next item @@ -71,7 +84,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 }; } - pub fn deinit(hm: *const Self) void { + pub fn deinit(hm: Self) void { hm.allocator.free(hm.entries); } @@ -84,34 +97,65 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 hm.incrementModificationCount(); } - pub fn count(hm: *const Self) usize { - return hm.size; + pub fn count(self: Self) usize { + return self.size; } - /// Returns the value that was already there. - pub fn put(hm: *Self, key: K, value: *const V) !?V { - if (hm.entries.len == 0) { - try hm.initCapacity(16); + /// If key exists this function cannot fail. + /// If there is an existing item with `key`, then the result + /// kv pointer points to it, and found_existing is true. + /// Otherwise, puts a new item with undefined value, and + /// the kv pointer points to it. Caller should then initialize + /// the data. + pub fn getOrPut(self: *Self, key: K) !GetOrPutResult { + // TODO this implementation can be improved - we should only + // have to hash once and find the entry once. + if (self.get(key)) |kv| { + return GetOrPutResult{ + .kv = kv, + .found_existing = true, + }; + } + self.incrementModificationCount(); + try self.ensureCapacity(); + const put_result = self.internalPut(key); + assert(put_result.old_kv == null); + return GetOrPutResult{ + .kv = &put_result.new_entry.kv, + .found_existing = false, + }; + } + + fn ensureCapacity(self: *Self) !void { + if (self.entries.len == 0) { + return self.initCapacity(16); } - hm.incrementModificationCount(); // if we get too full (60%), double the capacity - if (hm.size * 5 >= hm.entries.len * 3) { - const old_entries = hm.entries; - try hm.initCapacity(hm.entries.len * 2); + if (self.size * 5 >= self.entries.len * 3) { + const old_entries = self.entries; + try self.initCapacity(self.entries.len * 2); // dump all of the old elements into the new table for (old_entries) |*old_entry| { if (old_entry.used) { - _ = hm.internalPut(old_entry.key, old_entry.value); + self.internalPut(old_entry.kv.key).new_entry.kv.value = old_entry.kv.value; } } - hm.allocator.free(old_entries); + self.allocator.free(old_entries); } + } + + /// Returns the kv pair that was already there. + pub fn put(self: *Self, key: K, value: V) !?KV { + self.incrementModificationCount(); + try self.ensureCapacity(); - return hm.internalPut(key, value); + const put_result = self.internalPut(key); + put_result.new_entry.kv.value = value; + return put_result.old_kv; } - pub fn get(hm: *const Self, key: K) ?*Entry { + pub fn get(hm: *const Self, key: K) ?*KV { if (hm.entries.len == 0) { return null; } @@ -122,7 +166,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 return hm.get(key) != null; } - pub fn remove(hm: *Self, key: K) ?*Entry { + pub fn remove(hm: *Self, key: K) ?*KV { if (hm.entries.len == 0) return null; hm.incrementModificationCount(); const start_index = hm.keyToIndex(key); @@ -134,7 +178,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 if (!entry.used) return null; - if (!eql(entry.key, key)) continue; + if (!eql(entry.kv.key, key)) continue; while (roll_over < hm.entries.len) : (roll_over += 1) { const next_index = (start_index + roll_over + 1) % hm.entries.len; @@ -142,7 +186,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 if (!next_entry.used or next_entry.distance_from_start_index == 0) { entry.used = false; hm.size -= 1; - return entry; + return &entry.kv; } entry.* = next_entry.*; entry.distance_from_start_index -= 1; @@ -163,6 +207,16 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 }; } + pub fn clone(self: Self) !Self { + var other = Self.init(self.allocator); + try other.initCapacity(self.entries.len); + var it = self.iterator(); + while (it.next()) |entry| { + assert((try other.put(entry.key, entry.value)) == null); + } + return other; + } + fn initCapacity(hm: *Self, capacity: usize) !void { hm.entries = try hm.allocator.alloc(Entry, capacity); hm.size = 0; @@ -178,60 +232,81 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 } } - /// Returns the value that was already there. - fn internalPut(hm: *Self, orig_key: K, orig_value: *const V) ?V { + const InternalPutResult = struct { + new_entry: *Entry, + old_kv: ?KV, + }; + + /// Returns a pointer to the new entry. + /// Asserts that there is enough space for the new item. + fn internalPut(self: *Self, orig_key: K) InternalPutResult { var key = orig_key; - var value = orig_value.*; - const start_index = hm.keyToIndex(key); + var value: V = undefined; + const start_index = self.keyToIndex(key); var roll_over: usize = 0; var distance_from_start_index: usize = 0; - while (roll_over < hm.entries.len) : ({ + var got_result_entry = false; + var result = InternalPutResult{ + .new_entry = undefined, + .old_kv = null, + }; + while (roll_over < self.entries.len) : ({ roll_over += 1; distance_from_start_index += 1; }) { - const index = (start_index + roll_over) % hm.entries.len; - const entry = &hm.entries[index]; + const index = (start_index + roll_over) % self.entries.len; + const entry = &self.entries[index]; - if (entry.used and !eql(entry.key, key)) { + if (entry.used and !eql(entry.kv.key, key)) { if (entry.distance_from_start_index < distance_from_start_index) { // robin hood to the rescue const tmp = entry.*; - hm.max_distance_from_start_index = math.max(hm.max_distance_from_start_index, distance_from_start_index); + self.max_distance_from_start_index = math.max(self.max_distance_from_start_index, distance_from_start_index); + if (!got_result_entry) { + got_result_entry = true; + result.new_entry = entry; + } entry.* = Entry{ .used = true, .distance_from_start_index = distance_from_start_index, - .key = key, - .value = value, + .kv = KV{ + .key = key, + .value = value, + }, }; - key = tmp.key; - value = tmp.value; + key = tmp.kv.key; + value = tmp.kv.value; distance_from_start_index = tmp.distance_from_start_index; } continue; } - var result: ?V = null; if (entry.used) { - result = entry.value; + result.old_kv = entry.kv; } else { // adding an entry. otherwise overwriting old value with // same key - hm.size += 1; + self.size += 1; } - hm.max_distance_from_start_index = math.max(distance_from_start_index, hm.max_distance_from_start_index); + self.max_distance_from_start_index = math.max(distance_from_start_index, self.max_distance_from_start_index); + if (!got_result_entry) { + result.new_entry = entry; + } entry.* = Entry{ .used = true, .distance_from_start_index = distance_from_start_index, - .key = key, - .value = value, + .kv = KV{ + .key = key, + .value = value, + }, }; return result; } unreachable; // put into a full map } - fn internalGet(hm: *const Self, key: K) ?*Entry { + fn internalGet(hm: Self, key: K) ?*KV { const start_index = hm.keyToIndex(key); { var roll_over: usize = 0; @@ -240,13 +315,13 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3 const entry = &hm.entries[index]; if (!entry.used) return null; - if (eql(entry.key, key)) return entry; + if (eql(entry.kv.key, key)) return &entry.kv; } } return null; } - fn keyToIndex(hm: *const Self, key: K) usize { + fn keyToIndex(hm: Self, key: K) usize { return usize(hash(key)) % hm.entries.len; } }; @@ -256,7 +331,7 @@ test "basic hash map usage" { var direct_allocator = std.heap.DirectAllocator.init(); defer direct_allocator.deinit(); - var map = HashMap(i32, i32, hash_i32, eql_i32).init(&direct_allocator.allocator); + var map = AutoHashMap(i32, i32).init(&direct_allocator.allocator); defer map.deinit(); assert((try map.put(1, 11)) == null); @@ -265,8 +340,19 @@ test "basic hash map usage" { assert((try map.put(4, 44)) == null); assert((try map.put(5, 55)) == null); - assert((try map.put(5, 66)).? == 55); - assert((try map.put(5, 55)).? == 66); + assert((try map.put(5, 66)).?.value == 55); + assert((try map.put(5, 55)).?.value == 66); + + const gop1 = try map.getOrPut(5); + assert(gop1.found_existing == true); + assert(gop1.kv.value == 55); + gop1.kv.value = 77; + assert(map.get(5).?.value == 77); + + const gop2 = try map.getOrPut(99); + assert(gop2.found_existing == false); + gop2.kv.value = 42; + assert(map.get(99).?.value == 42); assert(map.contains(2)); assert(map.get(2).?.value == 22); @@ -279,7 +365,7 @@ test "iterator hash map" { var direct_allocator = std.heap.DirectAllocator.init(); defer direct_allocator.deinit(); - var reset_map = HashMap(i32, i32, hash_i32, eql_i32).init(&direct_allocator.allocator); + var reset_map = AutoHashMap(i32, i32).init(&direct_allocator.allocator); defer reset_map.deinit(); assert((try reset_map.put(1, 11)) == null); @@ -287,14 +373,14 @@ test "iterator hash map" { assert((try reset_map.put(3, 33)) == null); var keys = []i32{ - 1, - 2, 3, + 2, + 1, }; var values = []i32{ - 11, - 22, 33, + 22, + 11, }; var it = reset_map.iterator(); @@ -322,10 +408,140 @@ test "iterator hash map" { assert(entry.value == values[0]); } -fn hash_i32(x: i32) u32 { - return @bitCast(u32, x); +pub fn getHashPtrAddrFn(comptime K: type) (fn (K) u32) { + return struct { + fn hash(key: K) u32 { + return getAutoHashFn(usize)(@ptrToInt(key)); + } + }.hash; } -fn eql_i32(a: i32, b: i32) bool { - return a == b; +pub fn getTrivialEqlFn(comptime K: type) (fn (K, K) bool) { + return struct { + fn eql(a: K, b: K) bool { + return a == b; + } + }.eql; +} + +pub fn getAutoHashFn(comptime K: type) (fn (K) u32) { + return struct { + fn hash(key: K) u32 { + comptime var rng = comptime std.rand.DefaultPrng.init(0); + return autoHash(key, &rng.random, u32); + } + }.hash; +} + +pub fn getAutoEqlFn(comptime K: type) (fn (K, K) bool) { + return struct { + fn eql(a: K, b: K) bool { + return autoEql(a, b); + } + }.eql; +} + +// TODO improve these hash functions +pub fn autoHash(key: var, comptime rng: *std.rand.Random, comptime HashInt: type) HashInt { + switch (@typeInfo(@typeOf(key))) { + builtin.TypeId.NoReturn, + builtin.TypeId.Opaque, + builtin.TypeId.Undefined, + builtin.TypeId.ArgTuple, + => @compileError("cannot hash this type"), + + builtin.TypeId.Void, + builtin.TypeId.Null, + => return 0, + + builtin.TypeId.Int => |info| { + const unsigned_x = @bitCast(@IntType(false, info.bits), key); + if (info.bits <= HashInt.bit_count) { + return HashInt(unsigned_x) ^ comptime rng.scalar(HashInt); + } else { + return @truncate(HashInt, unsigned_x ^ comptime rng.scalar(@typeOf(unsigned_x))); + } + }, + + builtin.TypeId.Float => |info| { + return autoHash(@bitCast(@IntType(false, info.bits), key), rng); + }, + builtin.TypeId.Bool => return autoHash(@boolToInt(key), rng), + builtin.TypeId.Enum => return autoHash(@enumToInt(key), rng), + builtin.TypeId.ErrorSet => return autoHash(@errorToInt(key), rng), + builtin.TypeId.Promise, builtin.TypeId.Fn => return autoHash(@ptrToInt(key), rng), + + builtin.TypeId.Namespace, + builtin.TypeId.Block, + builtin.TypeId.BoundFn, + builtin.TypeId.ComptimeFloat, + builtin.TypeId.ComptimeInt, + builtin.TypeId.Type, + => return 0, + + builtin.TypeId.Pointer => |info| switch (info.size) { + builtin.TypeInfo.Pointer.Size.One => @compileError("TODO auto hash for single item pointers"), + builtin.TypeInfo.Pointer.Size.Many => @compileError("TODO auto hash for many item pointers"), + builtin.TypeInfo.Pointer.Size.Slice => { + const interval = std.math.max(1, key.len / 256); + var i: usize = 0; + var h = comptime rng.scalar(HashInt); + while (i < key.len) : (i += interval) { + h ^= autoHash(key[i], rng, HashInt); + } + return h; + }, + }, + + builtin.TypeId.Optional => @compileError("TODO auto hash for optionals"), + builtin.TypeId.Array => @compileError("TODO auto hash for arrays"), + builtin.TypeId.Struct => @compileError("TODO auto hash for structs"), + builtin.TypeId.Union => @compileError("TODO auto hash for unions"), + builtin.TypeId.ErrorUnion => @compileError("TODO auto hash for unions"), + } +} + +pub fn autoEql(a: var, b: @typeOf(a)) bool { + switch (@typeInfo(@typeOf(a))) { + builtin.TypeId.NoReturn, + builtin.TypeId.Opaque, + builtin.TypeId.Undefined, + builtin.TypeId.ArgTuple, + => @compileError("cannot test equality of this type"), + builtin.TypeId.Void, + builtin.TypeId.Null, + => return true, + builtin.TypeId.Bool, + builtin.TypeId.Int, + builtin.TypeId.Float, + builtin.TypeId.ComptimeFloat, + builtin.TypeId.ComptimeInt, + builtin.TypeId.Namespace, + builtin.TypeId.Block, + builtin.TypeId.Promise, + builtin.TypeId.Enum, + builtin.TypeId.BoundFn, + builtin.TypeId.Fn, + builtin.TypeId.ErrorSet, + builtin.TypeId.Type, + => return a == b, + + builtin.TypeId.Pointer => |info| switch (info.size) { + builtin.TypeInfo.Pointer.Size.One => @compileError("TODO auto eql for single item pointers"), + builtin.TypeInfo.Pointer.Size.Many => @compileError("TODO auto eql for many item pointers"), + builtin.TypeInfo.Pointer.Size.Slice => { + if (a.len != b.len) return false; + for (a) |a_item, i| { + if (!autoEql(a_item, b[i])) return false; + } + return true; + }, + }, + + builtin.TypeId.Optional => @compileError("TODO auto eql for optionals"), + builtin.TypeId.Array => @compileError("TODO auto eql for arrays"), + builtin.TypeId.Struct => @compileError("TODO auto eql for structs"), + builtin.TypeId.Union => @compileError("TODO auto eql for unions"), + builtin.TypeId.ErrorUnion => @compileError("TODO auto eql for unions"), + } } |
