path: root/lib/std/hash_map.zig
author    Andrew Kelley <andrew@ziglang.org>  2019-09-26 01:54:45 -0400
committer GitHub <noreply@github.com>  2019-09-26 01:54:45 -0400
commit    68bb3945708c43109c48bda3664176307d45b62c (patch)
tree      afb9731e10cef9d192560b52cd9ae2cf179775c4 /lib/std/hash_map.zig
parent    6128bc728d1e1024a178c16c2149f5b1a167a013 (diff)
parent    4637e8f9699af9c3c6cf4df50ef5bb67c7a318a4 (diff)
Merge pull request #3315 from ziglang/mv-std-lib
Move std/ to lib/std/
Diffstat (limited to 'lib/std/hash_map.zig')
-rw-r--r--  lib/std/hash_map.zig  552
1 file changed, 552 insertions(+), 0 deletions(-)
diff --git a/lib/std/hash_map.zig b/lib/std/hash_map.zig
new file mode 100644
index 0000000000..4ffe88067b
--- /dev/null
+++ b/lib/std/hash_map.zig
@@ -0,0 +1,552 @@
+const std = @import("std.zig");
+const debug = std.debug;
+const assert = debug.assert;
+const testing = std.testing;
+const math = std.math;
+const mem = std.mem;
+const meta = std.meta;
+const autoHash = std.hash.autoHash;
+const Wyhash = std.hash.Wyhash;
+const Allocator = mem.Allocator;
+const builtin = @import("builtin");
+
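+// When runtime safety is enabled (any mode but ReleaseFast), each map carries
+// a modification counter so iterators can detect mutation mid-iteration;
+// otherwise the field is void and costs nothing.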
+const want_modification_safety = builtin.mode != builtin.Mode.ReleaseFast;
+const debug_u32 = if (want_modification_safety) u32 else void;
+
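+/// Builtin hashmap that derives hash and eql functions automatically from K.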
+pub fn AutoHashMap(comptime K: type, comptime V: type) type {
+ return HashMap(K, V, getAutoHashFn(K), getAutoEqlFn(K));
+}
+
+/// Builtin hashmap for strings as keys.
+pub fn StringHashMap(comptime V: type) type {
+ return HashMap([]const u8, V, hashString, eqlString);
+}
+
+pub fn eqlString(a: []const u8, b: []const u8) bool {
+ return mem.eql(u8, a, b);
+}
+
+pub fn hashString(s: []const u8) u32 {
+ return @truncate(u32, Wyhash.hash(0, s));
+}
+
+pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u32, comptime eql: fn (a: K, b: K) bool) type {
+ return struct {
+ entries: []Entry,
+ size: usize,
+ max_distance_from_start_index: usize,
+ allocator: *Allocator,
+ // this is used to detect bugs where a hashtable is edited while an iterator is running.
+ modification_count: debug_u32,
+
+ const Self = @This();
+
+ pub const KV = struct {
+ key: K,
+ value: V,
+ };
+
+ const Entry = struct {
+ used: bool,
+ distance_from_start_index: usize,
+ kv: KV,
+ };
+
+ pub const GetOrPutResult = struct {
+ kv: *KV,
+ found_existing: bool,
+ };
+
+ pub const Iterator = struct {
+ hm: *const Self,
+ // number of items returned so far
+ count: usize,
+ // current index into the entry array
+ index: usize,
+ // used to detect concurrent modification
+ initial_modification_count: debug_u32,
+
+ pub fn next(it: *Iterator) ?*KV {
+ if (want_modification_safety) {
+ assert(it.initial_modification_count == it.hm.modification_count); // concurrent modification
+ }
+ if (it.count >= it.hm.size) return null;
+ while (it.index < it.hm.entries.len) : (it.index += 1) {
+ const entry = &it.hm.entries[it.index];
+ if (entry.used) {
+ it.index += 1;
+ it.count += 1;
+ return &entry.kv;
+ }
+ }
+ unreachable; // no next item
+ }
+
+ // Reset the iterator to the initial index
+ pub fn reset(it: *Iterator) void {
+ it.count = 0;
+ it.index = 0;
+ // also reset the modification count so the iterator becomes valid again
+ it.initial_modification_count = it.hm.modification_count;
+ }
+ };
+
+ pub fn init(allocator: *Allocator) Self {
+ return Self{
+ .entries = [_]Entry{},
+ .allocator = allocator,
+ .size = 0,
+ .max_distance_from_start_index = 0,
+ .modification_count = if (want_modification_safety) 0 else {},
+ };
+ }
+
+ pub fn deinit(hm: Self) void {
+ hm.allocator.free(hm.entries);
+ }
+
+ pub fn clear(hm: *Self) void {
+ for (hm.entries) |*entry| {
+ entry.used = false;
+ }
+ hm.size = 0;
+ hm.max_distance_from_start_index = 0;
+ hm.incrementModificationCount();
+ }
+
+ pub fn count(self: Self) usize {
+ return self.size;
+ }
+
+ /// If the key exists, this function cannot fail.
+ /// If there is an existing item with `key`, then the result
+ /// kv pointer points to it, and found_existing is true.
+ /// Otherwise, puts a new item with undefined value, and
+ /// the kv pointer points to it. Caller should then initialize
+ /// the data.
+ pub fn getOrPut(self: *Self, key: K) !GetOrPutResult {
+ // TODO this implementation can be improved - we should only
+ // have to hash once and find the entry once.
+ if (self.get(key)) |kv| {
+ return GetOrPutResult{
+ .kv = kv,
+ .found_existing = true,
+ };
+ }
+ self.incrementModificationCount();
+ try self.autoCapacity();
+ const put_result = self.internalPut(key);
+ assert(put_result.old_kv == null);
+ return GetOrPutResult{
+ .kv = &put_result.new_entry.kv,
+ .found_existing = false,
+ };
+ }
+
+ pub fn getOrPutValue(self: *Self, key: K, value: V) !*KV {
+ const res = try self.getOrPut(key);
+ if (!res.found_existing)
+ res.kv.value = value;
+
+ return res.kv;
+ }
+
+ fn optimizedCapacity(expected_count: usize) usize {
+ // ensure that the hash map will be at most 60% full if
+ // expected_count items are put into it
+ var optimized_capacity = expected_count * 5 / 3;
+ // an overflow here would mean the amount of memory required would not
+ // be representable in the address space
+ return math.ceilPowerOfTwo(usize, optimized_capacity) catch unreachable;
+ }
+
+ /// Increases capacity so that the hash map will be at most
+ /// 60% full when expected_count items are put into it
+ pub fn ensureCapacity(self: *Self, expected_count: usize) !void {
+ const optimized_capacity = optimizedCapacity(expected_count);
+ return self.ensureCapacityExact(optimized_capacity);
+ }
+
+ /// Sets the capacity to the new capacity if the new
+ /// capacity is greater than the current capacity.
+ /// New capacity must be a power of two.
+ fn ensureCapacityExact(self: *Self, new_capacity: usize) !void {
+ // capacity must always be a power of two to allow for modulo
+ // optimization in the constrainIndex fn
+ assert(math.isPowerOfTwo(new_capacity));
+
+ if (new_capacity <= self.entries.len) {
+ return;
+ }
+
+ const old_entries = self.entries;
+ try self.initCapacity(new_capacity);
+ self.incrementModificationCount();
+ if (old_entries.len > 0) {
+ // dump all of the old elements into the new table
+ for (old_entries) |*old_entry| {
+ if (old_entry.used) {
+ self.internalPut(old_entry.kv.key).new_entry.kv.value = old_entry.kv.value;
+ }
+ }
+ self.allocator.free(old_entries);
+ }
+ }
+
+ /// Returns the kv pair that was already there, if any.
+ pub fn put(self: *Self, key: K, value: V) !?KV {
+ try self.autoCapacity();
+ return putAssumeCapacity(self, key, value);
+ }
+
+ /// Calls put() and asserts that no kv pair is clobbered.
+ pub fn putNoClobber(self: *Self, key: K, value: V) !void {
+ assert((try self.put(key, value)) == null);
+ }
+
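+ /// Like put(), but asserts that the map already has room for the new
+ /// entry; never allocates. Returns the kv pair that was already there.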
+ pub fn putAssumeCapacity(self: *Self, key: K, value: V) ?KV {
+ assert(self.count() < self.entries.len);
+ self.incrementModificationCount();
+
+ const put_result = self.internalPut(key);
+ put_result.new_entry.kv.value = value;
+ return put_result.old_kv;
+ }
+
+ pub fn get(hm: *const Self, key: K) ?*KV {
+ if (hm.entries.len == 0) {
+ return null;
+ }
+ return hm.internalGet(key);
+ }
+
+ pub fn getValue(hm: *const Self, key: K) ?V {
+ return if (hm.get(key)) |kv| kv.value else null;
+ }
+
+ pub fn contains(hm: *const Self, key: K) bool {
+ return hm.get(key) != null;
+ }
+
+ /// Returns the removed kv pair, if the key was present.
+ pub fn remove(hm: *Self, key: K) ?KV {
+ if (hm.entries.len == 0) return null;
+ hm.incrementModificationCount();
+ const start_index = hm.keyToIndex(key);
+ {
+ var roll_over: usize = 0;
+ while (roll_over <= hm.max_distance_from_start_index) : (roll_over += 1) {
+ const index = hm.constrainIndex(start_index + roll_over);
+ var entry = &hm.entries[index];
+
+ if (!entry.used) return null;
+
+ if (!eql(entry.kv.key, key)) continue;
+
+ const removed_kv = entry.kv;
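+ // backward shift: move each following entry one slot closer to its
+ // start index until we hit an empty slot or an entry that is
+ // already at its start index.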
+ while (roll_over < hm.entries.len) : (roll_over += 1) {
+ const next_index = hm.constrainIndex(start_index + roll_over + 1);
+ const next_entry = &hm.entries[next_index];
+ if (!next_entry.used or next_entry.distance_from_start_index == 0) {
+ entry.used = false;
+ hm.size -= 1;
+ return removed_kv;
+ }
+ entry.* = next_entry.*;
+ entry.distance_from_start_index -= 1;
+ entry = next_entry;
+ }
+ unreachable; // would have shifted every entry in the table
+ }
+ }
+ return null;
+ }
+
+ /// Calls remove(), asserts that a kv pair is removed, and discards it.
+ pub fn removeAssertDiscard(hm: *Self, key: K) void {
+ assert(hm.remove(key) != null);
+ }
+
+ pub fn iterator(hm: *const Self) Iterator {
+ return Iterator{
+ .hm = hm,
+ .count = 0,
+ .index = 0,
+ .initial_modification_count = hm.modification_count,
+ };
+ }
+
+ pub fn clone(self: Self) !Self {
+ var other = Self.init(self.allocator);
+ try other.initCapacity(self.entries.len);
+ var it = self.iterator();
+ while (it.next()) |entry| {
+ try other.putNoClobber(entry.key, entry.value);
+ }
+ return other;
+ }
+
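+ /// Grows the backing array before an insertion: allocates 16 entries
+ /// for an empty map and doubles the capacity once the map is 60% full.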
+ fn autoCapacity(self: *Self) !void {
+ if (self.entries.len == 0) {
+ return self.ensureCapacityExact(16);
+ }
+ // if we get too full (60%), double the capacity
+ if (self.size * 5 >= self.entries.len * 3) {
+ return self.ensureCapacityExact(self.entries.len * 2);
+ }
+ }
+
+ fn initCapacity(hm: *Self, capacity: usize) !void {
+ hm.entries = try hm.allocator.alloc(Entry, capacity);
+ hm.size = 0;
+ hm.max_distance_from_start_index = 0;
+ for (hm.entries) |*entry| {
+ entry.used = false;
+ }
+ }
+
+ fn incrementModificationCount(hm: *Self) void {
+ if (want_modification_safety) {
+ hm.modification_count +%= 1;
+ }
+ }
+
+ const InternalPutResult = struct {
+ new_entry: *Entry,
+ old_kv: ?KV,
+ };
+
+ /// Returns a pointer to the new entry.
+ /// Asserts that there is enough space for the new item.
+ fn internalPut(self: *Self, orig_key: K) InternalPutResult {
+ var key = orig_key;
+ var value: V = undefined;
+ const start_index = self.keyToIndex(key);
+ var roll_over: usize = 0;
+ var distance_from_start_index: usize = 0;
+ var got_result_entry = false;
+ var result = InternalPutResult{
+ .new_entry = undefined,
+ .old_kv = null,
+ };
+ while (roll_over < self.entries.len) : ({
+ roll_over += 1;
+ distance_from_start_index += 1;
+ }) {
+ const index = self.constrainIndex(start_index + roll_over);
+ const entry = &self.entries[index];
+
+ if (entry.used and !eql(entry.kv.key, key)) {
+ if (entry.distance_from_start_index < distance_from_start_index) {
+ // robin hood: we are further from our start index than the resident
+ // entry, so take its slot and displace the resident to keep probing
+ const tmp = entry.*;
+ self.max_distance_from_start_index = math.max(self.max_distance_from_start_index, distance_from_start_index);
+ if (!got_result_entry) {
+ got_result_entry = true;
+ result.new_entry = entry;
+ }
+ entry.* = Entry{
+ .used = true,
+ .distance_from_start_index = distance_from_start_index,
+ .kv = KV{
+ .key = key,
+ .value = value,
+ },
+ };
+ key = tmp.kv.key;
+ value = tmp.kv.value;
+ distance_from_start_index = tmp.distance_from_start_index;
+ }
+ continue;
+ }
+
+ if (entry.used) {
+ result.old_kv = entry.kv;
+ } else {
+ // adding a new entry; otherwise we are overwriting
+ // the value of an existing key
+ self.size += 1;
+ }
+
+ self.max_distance_from_start_index = math.max(distance_from_start_index, self.max_distance_from_start_index);
+ if (!got_result_entry) {
+ result.new_entry = entry;
+ }
+ entry.* = Entry{
+ .used = true,
+ .distance_from_start_index = distance_from_start_index,
+ .kv = KV{
+ .key = key,
+ .value = value,
+ },
+ };
+ return result;
+ }
+ unreachable; // put into a full map
+ }
+
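+ /// Probing can stop after max_distance_from_start_index slots, since
+ /// no used entry is ever stored further than that from its start index.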
+ fn internalGet(hm: Self, key: K) ?*KV {
+ const start_index = hm.keyToIndex(key);
+ {
+ var roll_over: usize = 0;
+ while (roll_over <= hm.max_distance_from_start_index) : (roll_over += 1) {
+ const index = hm.constrainIndex(start_index + roll_over);
+ const entry = &hm.entries[index];
+
+ if (!entry.used) return null;
+ if (eql(entry.kv.key, key)) return &entry.kv;
+ }
+ }
+ return null;
+ }
+
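+ /// Reduces a key's 32-bit hash to an index into the entry array.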
+ fn keyToIndex(hm: Self, key: K) usize {
+ return hm.constrainIndex(usize(hash(key)));
+ }
+
+ fn constrainIndex(hm: Self, i: usize) usize {
+ // this is an optimization for modulo of power of two integers;
+ // it requires hm.entries.len to always be a power of two
+ return i & (hm.entries.len - 1);
+ }
+ };
+}
+
+test "basic hash map usage" {
+ var map = AutoHashMap(i32, i32).init(std.heap.direct_allocator);
+ defer map.deinit();
+
+ testing.expect((try map.put(1, 11)) == null);
+ testing.expect((try map.put(2, 22)) == null);
+ testing.expect((try map.put(3, 33)) == null);
+ testing.expect((try map.put(4, 44)) == null);
+
+ try map.putNoClobber(5, 55);
+ testing.expect((try map.put(5, 66)).?.value == 55);
+ testing.expect((try map.put(5, 55)).?.value == 66);
+
+ const gop1 = try map.getOrPut(5);
+ testing.expect(gop1.found_existing == true);
+ testing.expect(gop1.kv.value == 55);
+ gop1.kv.value = 77;
+ testing.expect(map.get(5).?.value == 77);
+
+ const gop2 = try map.getOrPut(99);
+ testing.expect(gop2.found_existing == false);
+ gop2.kv.value = 42;
+ testing.expect(map.get(99).?.value == 42);
+
+ const gop3 = try map.getOrPutValue(5, 5);
+ testing.expect(gop3.value == 77);
+
+ const gop4 = try map.getOrPutValue(100, 41);
+ testing.expect(gop4.value == 41);
+
+ testing.expect(map.contains(2));
+ testing.expect(map.get(2).?.value == 22);
+ testing.expect(map.getValue(2).? == 22);
+
+ const rmv1 = map.remove(2);
+ testing.expect(rmv1.?.key == 2);
+ testing.expect(rmv1.?.value == 22);
+ testing.expect(map.remove(2) == null);
+ testing.expect(map.get(2) == null);
+ testing.expect(map.getValue(2) == null);
+
+ map.removeAssertDiscard(3);
+}
+
+test "iterator hash map" {
+ var reset_map = AutoHashMap(i32, i32).init(std.heap.direct_allocator);
+ defer reset_map.deinit();
+
+ try reset_map.putNoClobber(1, 11);
+ try reset_map.putNoClobber(2, 22);
+ try reset_map.putNoClobber(3, 33);
+
+ // TODO this test depends on the hashing algorithm, because it assumes the
+ // order of the elements in the hashmap. This should not be the case.
+ var keys = [_]i32{
+ 1,
+ 3,
+ 2,
+ };
+ var values = [_]i32{
+ 11,
+ 33,
+ 22,
+ };
+
+ var it = reset_map.iterator();
+ var count: usize = 0;
+ while (it.next()) |next| {
+ testing.expect(next.key == keys[count]);
+ testing.expect(next.value == values[count]);
+ count += 1;
+ }
+
+ testing.expect(count == 3);
+ testing.expect(it.next() == null);
+ it.reset();
+ count = 0;
+ while (it.next()) |next| {
+ testing.expect(next.key == keys[count]);
+ testing.expect(next.value == values[count]);
+ count += 1;
+ if (count == 2) break;
+ }
+
+ it.reset();
+ var entry = it.next().?;
+ testing.expect(entry.key == keys[0]);
+ testing.expect(entry.value == values[0]);
+}
+
+test "ensure capacity" {
+ var map = AutoHashMap(i32, i32).init(std.heap.direct_allocator);
+ defer map.deinit();
+
+ try map.ensureCapacity(20);
+ const initialCapacity = map.entries.len;
+ testing.expect(initialCapacity >= 20);
+ var i: i32 = 0;
+ while (i < 20) : (i += 1) {
+ testing.expect(map.putAssumeCapacity(i, i + 10) == null);
+ }
+ // shouldn't resize from putAssumeCapacity
+ testing.expect(initialCapacity == map.entries.len);
+}
+
+pub fn getHashPtrAddrFn(comptime K: type) (fn (K) u32) {
+ return struct {
+ fn hash(key: K) u32 {
+ return getAutoHashFn(usize)(@ptrToInt(key));
+ }
+ }.hash;
+}
+
+pub fn getTrivialEqlFn(comptime K: type) (fn (K, K) bool) {
+ return struct {
+ fn eql(a: K, b: K) bool {
+ return a == b;
+ }
+ }.eql;
+}
+
+pub fn getAutoHashFn(comptime K: type) (fn (K) u32) {
+ return struct {
+ fn hash(key: K) u32 {
+ var hasher = Wyhash.init(0);
+ autoHash(&hasher, key);
+ return @truncate(u32, hasher.final());
+ }
+ }.hash;
+}
+
+pub fn getAutoEqlFn(comptime K: type) (fn (K, K) bool) {
+ return struct {
+ fn eql(a: K, b: K) bool {
+ return meta.eql(a, b);
+ }
+ }.eql;
+}