path: root/std/hash_map.zig
author     Andrew Kelley <superjoe30@gmail.com>  2018-08-25 21:57:28 -0400
committer  Andrew Kelley <superjoe30@gmail.com>  2018-08-25 21:57:28 -0400
commit     7109035b78ee05302bbdaadc52013b430a030b69 (patch)
tree       ae6d7202dc75f2c799f5fbcad72ccf8b02a954a4 /std/hash_map.zig
parent     6cf248ec0824c746fc796905144c8077ccab99cf (diff)
parent     526338b00fbe1cac19f64832176af3bdf2108a56 (diff)
Merge remote-tracking branch 'origin/master' into llvm7
Diffstat (limited to 'std/hash_map.zig')
-rw-r--r--  std/hash_map.zig  330
1 file changed, 273 insertions, 57 deletions
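
The headline API change in this patch: entries are now exposed through a public KV struct, AutoHashMap(K, V) picks hash and equality functions automatically, put takes the value by value and returns the previous KV (if any), and the new getOrPut returns an entry pointer plus a found_existing flag. A minimal usage sketch, modeled on the updated "basic hash map usage" test at the bottom of the patch; it assumes AutoHashMap is in scope (inside std/hash_map.zig the test uses the local declaration), and the keys and values are illustrative:

    const std = @import("std");
    const assert = std.debug.assert;

    test "AutoHashMap getOrPut sketch" {
        var direct_allocator = std.heap.DirectAllocator.init();
        defer direct_allocator.deinit();

        var map = AutoHashMap(i32, i32).init(&direct_allocator.allocator);
        defer map.deinit();

        // put returns the KV that was already there, if any.
        assert((try map.put(1, 11)) == null);
        assert((try map.put(1, 22)).?.value == 11);

        // getOrPut either finds the existing entry or inserts one whose value
        // is undefined, for the caller to fill in.
        const gop = try map.getOrPut(2);
        if (!gop.found_existing) gop.kv.value = 33;
        assert(map.get(2).?.value == 33);
    }
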
diff --git a/std/hash_map.zig b/std/hash_map.zig
index cebd5272c0..9654d612a5 100644
--- a/std/hash_map.zig
+++ b/std/hash_map.zig
@@ -9,6 +9,10 @@ const builtin = @import("builtin");
const want_modification_safety = builtin.mode != builtin.Mode.ReleaseFast;
const debug_u32 = if (want_modification_safety) u32 else void;
+pub fn AutoHashMap(comptime K: type, comptime V: type) type {
+ return HashMap(K, V, getAutoHashFn(K), getAutoEqlFn(K));
+}
+
pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u32, comptime eql: fn (a: K, b: K) bool) type {
return struct {
entries: []Entry,
@@ -20,13 +24,22 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
const Self = this;
- pub const Entry = struct {
- used: bool,
- distance_from_start_index: usize,
+ pub const KV = struct {
key: K,
value: V,
};
+ const Entry = struct {
+ used: bool,
+ distance_from_start_index: usize,
+ kv: KV,
+ };
+
+ pub const GetOrPutResult = struct {
+ kv: *KV,
+ found_existing: bool,
+ };
+
pub const Iterator = struct {
hm: *const Self,
// how many items have we returned
@@ -36,7 +49,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
// used to detect concurrent modification
initial_modification_count: debug_u32,
- pub fn next(it: *Iterator) ?*Entry {
+ pub fn next(it: *Iterator) ?*KV {
if (want_modification_safety) {
assert(it.initial_modification_count == it.hm.modification_count); // concurrent modification
}
@@ -46,7 +59,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
if (entry.used) {
it.index += 1;
it.count += 1;
- return entry;
+ return &entry.kv;
}
}
unreachable; // no next item
@@ -71,7 +84,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
};
}
- pub fn deinit(hm: *const Self) void {
+ pub fn deinit(hm: Self) void {
hm.allocator.free(hm.entries);
}
@@ -84,34 +97,65 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
hm.incrementModificationCount();
}
- pub fn count(hm: *const Self) usize {
- return hm.size;
+ pub fn count(self: Self) usize {
+ return self.size;
}
- /// Returns the value that was already there.
- pub fn put(hm: *Self, key: K, value: *const V) !?V {
- if (hm.entries.len == 0) {
- try hm.initCapacity(16);
+ /// If key exists this function cannot fail.
+ /// If there is an existing item with `key`, then the result
+ /// kv pointer points to it, and found_existing is true.
+ /// Otherwise, puts a new item with undefined value, and
+ /// the kv pointer points to it. Caller should then initialize
+ /// the data.
+ pub fn getOrPut(self: *Self, key: K) !GetOrPutResult {
+ // TODO this implementation can be improved - we should only
+ // have to hash once and find the entry once.
+ if (self.get(key)) |kv| {
+ return GetOrPutResult{
+ .kv = kv,
+ .found_existing = true,
+ };
+ }
+ self.incrementModificationCount();
+ try self.ensureCapacity();
+ const put_result = self.internalPut(key);
+ assert(put_result.old_kv == null);
+ return GetOrPutResult{
+ .kv = &put_result.new_entry.kv,
+ .found_existing = false,
+ };
+ }
+
+ fn ensureCapacity(self: *Self) !void {
+ if (self.entries.len == 0) {
+ return self.initCapacity(16);
}
- hm.incrementModificationCount();
// if we get too full (60%), double the capacity
- if (hm.size * 5 >= hm.entries.len * 3) {
- const old_entries = hm.entries;
- try hm.initCapacity(hm.entries.len * 2);
+ if (self.size * 5 >= self.entries.len * 3) {
+ const old_entries = self.entries;
+ try self.initCapacity(self.entries.len * 2);
// dump all of the old elements into the new table
for (old_entries) |*old_entry| {
if (old_entry.used) {
- _ = hm.internalPut(old_entry.key, old_entry.value);
+ self.internalPut(old_entry.kv.key).new_entry.kv.value = old_entry.kv.value;
}
}
- hm.allocator.free(old_entries);
+ self.allocator.free(old_entries);
}
+ }
+
+ /// Returns the kv pair that was already there.
+ pub fn put(self: *Self, key: K, value: V) !?KV {
+ self.incrementModificationCount();
+ try self.ensureCapacity();
- return hm.internalPut(key, value);
+ const put_result = self.internalPut(key);
+ put_result.new_entry.kv.value = value;
+ return put_result.old_kv;
}
- pub fn get(hm: *const Self, key: K) ?*Entry {
+ pub fn get(hm: *const Self, key: K) ?*KV {
if (hm.entries.len == 0) {
return null;
}
@@ -122,7 +166,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
return hm.get(key) != null;
}
- pub fn remove(hm: *Self, key: K) ?*Entry {
+ pub fn remove(hm: *Self, key: K) ?*KV {
if (hm.entries.len == 0) return null;
hm.incrementModificationCount();
const start_index = hm.keyToIndex(key);
@@ -134,7 +178,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
if (!entry.used) return null;
- if (!eql(entry.key, key)) continue;
+ if (!eql(entry.kv.key, key)) continue;
while (roll_over < hm.entries.len) : (roll_over += 1) {
const next_index = (start_index + roll_over + 1) % hm.entries.len;
@@ -142,7 +186,7 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
if (!next_entry.used or next_entry.distance_from_start_index == 0) {
entry.used = false;
hm.size -= 1;
- return entry;
+ return &entry.kv;
}
entry.* = next_entry.*;
entry.distance_from_start_index -= 1;
@@ -163,6 +207,16 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
};
}
+ pub fn clone(self: Self) !Self {
+ var other = Self.init(self.allocator);
+ try other.initCapacity(self.entries.len);
+ var it = self.iterator();
+ while (it.next()) |entry| {
+ assert((try other.put(entry.key, entry.value)) == null);
+ }
+ return other;
+ }
+
fn initCapacity(hm: *Self, capacity: usize) !void {
hm.entries = try hm.allocator.alloc(Entry, capacity);
hm.size = 0;
@@ -178,60 +232,81 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
}
}
- /// Returns the value that was already there.
- fn internalPut(hm: *Self, orig_key: K, orig_value: *const V) ?V {
+ const InternalPutResult = struct {
+ new_entry: *Entry,
+ old_kv: ?KV,
+ };
+
+ /// Returns a pointer to the new entry.
+ /// Asserts that there is enough space for the new item.
+ fn internalPut(self: *Self, orig_key: K) InternalPutResult {
var key = orig_key;
- var value = orig_value.*;
- const start_index = hm.keyToIndex(key);
+ var value: V = undefined;
+ const start_index = self.keyToIndex(key);
var roll_over: usize = 0;
var distance_from_start_index: usize = 0;
- while (roll_over < hm.entries.len) : ({
+ var got_result_entry = false;
+ var result = InternalPutResult{
+ .new_entry = undefined,
+ .old_kv = null,
+ };
+ while (roll_over < self.entries.len) : ({
roll_over += 1;
distance_from_start_index += 1;
}) {
- const index = (start_index + roll_over) % hm.entries.len;
- const entry = &hm.entries[index];
+ const index = (start_index + roll_over) % self.entries.len;
+ const entry = &self.entries[index];
- if (entry.used and !eql(entry.key, key)) {
+ if (entry.used and !eql(entry.kv.key, key)) {
if (entry.distance_from_start_index < distance_from_start_index) {
// robin hood to the rescue
const tmp = entry.*;
- hm.max_distance_from_start_index = math.max(hm.max_distance_from_start_index, distance_from_start_index);
+ self.max_distance_from_start_index = math.max(self.max_distance_from_start_index, distance_from_start_index);
+ if (!got_result_entry) {
+ got_result_entry = true;
+ result.new_entry = entry;
+ }
entry.* = Entry{
.used = true,
.distance_from_start_index = distance_from_start_index,
- .key = key,
- .value = value,
+ .kv = KV{
+ .key = key,
+ .value = value,
+ },
};
- key = tmp.key;
- value = tmp.value;
+ key = tmp.kv.key;
+ value = tmp.kv.value;
distance_from_start_index = tmp.distance_from_start_index;
}
continue;
}
- var result: ?V = null;
if (entry.used) {
- result = entry.value;
+ result.old_kv = entry.kv;
} else {
// adding an entry. otherwise overwriting old value with
// same key
- hm.size += 1;
+ self.size += 1;
}
- hm.max_distance_from_start_index = math.max(distance_from_start_index, hm.max_distance_from_start_index);
+ self.max_distance_from_start_index = math.max(distance_from_start_index, self.max_distance_from_start_index);
+ if (!got_result_entry) {
+ result.new_entry = entry;
+ }
entry.* = Entry{
.used = true,
.distance_from_start_index = distance_from_start_index,
- .key = key,
- .value = value,
+ .kv = KV{
+ .key = key,
+ .value = value,
+ },
};
return result;
}
unreachable; // put into a full map
}
- fn internalGet(hm: *const Self, key: K) ?*Entry {
+ fn internalGet(hm: Self, key: K) ?*KV {
const start_index = hm.keyToIndex(key);
{
var roll_over: usize = 0;
@@ -240,13 +315,13 @@ pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u3
const entry = &hm.entries[index];
if (!entry.used) return null;
- if (eql(entry.key, key)) return entry;
+ if (eql(entry.kv.key, key)) return &entry.kv;
}
}
return null;
}
- fn keyToIndex(hm: *const Self, key: K) usize {
+ fn keyToIndex(hm: Self, key: K) usize {
return usize(hash(key)) % hm.entries.len;
}
};
@@ -256,7 +331,7 @@ test "basic hash map usage" {
var direct_allocator = std.heap.DirectAllocator.init();
defer direct_allocator.deinit();
- var map = HashMap(i32, i32, hash_i32, eql_i32).init(&direct_allocator.allocator);
+ var map = AutoHashMap(i32, i32).init(&direct_allocator.allocator);
defer map.deinit();
assert((try map.put(1, 11)) == null);
@@ -265,8 +340,19 @@ test "basic hash map usage" {
assert((try map.put(4, 44)) == null);
assert((try map.put(5, 55)) == null);
- assert((try map.put(5, 66)).? == 55);
- assert((try map.put(5, 55)).? == 66);
+ assert((try map.put(5, 66)).?.value == 55);
+ assert((try map.put(5, 55)).?.value == 66);
+
+ const gop1 = try map.getOrPut(5);
+ assert(gop1.found_existing == true);
+ assert(gop1.kv.value == 55);
+ gop1.kv.value = 77;
+ assert(map.get(5).?.value == 77);
+
+ const gop2 = try map.getOrPut(99);
+ assert(gop2.found_existing == false);
+ gop2.kv.value = 42;
+ assert(map.get(99).?.value == 42);
assert(map.contains(2));
assert(map.get(2).?.value == 22);
@@ -279,7 +365,7 @@ test "iterator hash map" {
var direct_allocator = std.heap.DirectAllocator.init();
defer direct_allocator.deinit();
- var reset_map = HashMap(i32, i32, hash_i32, eql_i32).init(&direct_allocator.allocator);
+ var reset_map = AutoHashMap(i32, i32).init(&direct_allocator.allocator);
defer reset_map.deinit();
assert((try reset_map.put(1, 11)) == null);
@@ -287,14 +373,14 @@ test "iterator hash map" {
assert((try reset_map.put(3, 33)) == null);
var keys = []i32{
- 1,
- 2,
3,
+ 2,
+ 1,
};
var values = []i32{
- 11,
- 22,
33,
+ 22,
+ 11,
};
var it = reset_map.iterator();
@@ -322,10 +408,140 @@ test "iterator hash map" {
assert(entry.value == values[0]);
}
-fn hash_i32(x: i32) u32 {
- return @bitCast(u32, x);
+pub fn getHashPtrAddrFn(comptime K: type) (fn (K) u32) {
+ return struct {
+ fn hash(key: K) u32 {
+ return getAutoHashFn(usize)(@ptrToInt(key));
+ }
+ }.hash;
}
-fn eql_i32(a: i32, b: i32) bool {
- return a == b;
+pub fn getTrivialEqlFn(comptime K: type) (fn (K, K) bool) {
+ return struct {
+ fn eql(a: K, b: K) bool {
+ return a == b;
+ }
+ }.eql;
+}
+
+pub fn getAutoHashFn(comptime K: type) (fn (K) u32) {
+ return struct {
+ fn hash(key: K) u32 {
+ comptime var rng = comptime std.rand.DefaultPrng.init(0);
+ return autoHash(key, &rng.random, u32);
+ }
+ }.hash;
+}
+
+pub fn getAutoEqlFn(comptime K: type) (fn (K, K) bool) {
+ return struct {
+ fn eql(a: K, b: K) bool {
+ return autoEql(a, b);
+ }
+ }.eql;
+}
+
+// TODO improve these hash functions
+pub fn autoHash(key: var, comptime rng: *std.rand.Random, comptime HashInt: type) HashInt {
+ switch (@typeInfo(@typeOf(key))) {
+ builtin.TypeId.NoReturn,
+ builtin.TypeId.Opaque,
+ builtin.TypeId.Undefined,
+ builtin.TypeId.ArgTuple,
+ => @compileError("cannot hash this type"),
+
+ builtin.TypeId.Void,
+ builtin.TypeId.Null,
+ => return 0,
+
+ builtin.TypeId.Int => |info| {
+ const unsigned_x = @bitCast(@IntType(false, info.bits), key);
+ if (info.bits <= HashInt.bit_count) {
+ return HashInt(unsigned_x) ^ comptime rng.scalar(HashInt);
+ } else {
+ return @truncate(HashInt, unsigned_x ^ comptime rng.scalar(@typeOf(unsigned_x)));
+ }
+ },
+
+ builtin.TypeId.Float => |info| {
+ return autoHash(@bitCast(@IntType(false, info.bits), key), rng);
+ },
+ builtin.TypeId.Bool => return autoHash(@boolToInt(key), rng),
+ builtin.TypeId.Enum => return autoHash(@enumToInt(key), rng),
+ builtin.TypeId.ErrorSet => return autoHash(@errorToInt(key), rng),
+ builtin.TypeId.Promise, builtin.TypeId.Fn => return autoHash(@ptrToInt(key), rng),
+
+ builtin.TypeId.Namespace,
+ builtin.TypeId.Block,
+ builtin.TypeId.BoundFn,
+ builtin.TypeId.ComptimeFloat,
+ builtin.TypeId.ComptimeInt,
+ builtin.TypeId.Type,
+ => return 0,
+
+ builtin.TypeId.Pointer => |info| switch (info.size) {
+ builtin.TypeInfo.Pointer.Size.One => @compileError("TODO auto hash for single item pointers"),
+ builtin.TypeInfo.Pointer.Size.Many => @compileError("TODO auto hash for many item pointers"),
+ builtin.TypeInfo.Pointer.Size.Slice => {
+ const interval = std.math.max(1, key.len / 256);
+ var i: usize = 0;
+ var h = comptime rng.scalar(HashInt);
+ while (i < key.len) : (i += interval) {
+ h ^= autoHash(key[i], rng, HashInt);
+ }
+ return h;
+ },
+ },
+
+ builtin.TypeId.Optional => @compileError("TODO auto hash for optionals"),
+ builtin.TypeId.Array => @compileError("TODO auto hash for arrays"),
+ builtin.TypeId.Struct => @compileError("TODO auto hash for structs"),
+ builtin.TypeId.Union => @compileError("TODO auto hash for unions"),
+ builtin.TypeId.ErrorUnion => @compileError("TODO auto hash for unions"),
+ }
+}
+
+pub fn autoEql(a: var, b: @typeOf(a)) bool {
+ switch (@typeInfo(@typeOf(a))) {
+ builtin.TypeId.NoReturn,
+ builtin.TypeId.Opaque,
+ builtin.TypeId.Undefined,
+ builtin.TypeId.ArgTuple,
+ => @compileError("cannot test equality of this type"),
+ builtin.TypeId.Void,
+ builtin.TypeId.Null,
+ => return true,
+ builtin.TypeId.Bool,
+ builtin.TypeId.Int,
+ builtin.TypeId.Float,
+ builtin.TypeId.ComptimeFloat,
+ builtin.TypeId.ComptimeInt,
+ builtin.TypeId.Namespace,
+ builtin.TypeId.Block,
+ builtin.TypeId.Promise,
+ builtin.TypeId.Enum,
+ builtin.TypeId.BoundFn,
+ builtin.TypeId.Fn,
+ builtin.TypeId.ErrorSet,
+ builtin.TypeId.Type,
+ => return a == b,
+
+ builtin.TypeId.Pointer => |info| switch (info.size) {
+ builtin.TypeInfo.Pointer.Size.One => @compileError("TODO auto eql for single item pointers"),
+ builtin.TypeInfo.Pointer.Size.Many => @compileError("TODO auto eql for many item pointers"),
+ builtin.TypeInfo.Pointer.Size.Slice => {
+ if (a.len != b.len) return false;
+ for (a) |a_item, i| {
+ if (!autoEql(a_item, b[i])) return false;
+ }
+ return true;
+ },
+ },
+
+ builtin.TypeId.Optional => @compileError("TODO auto eql for optionals"),
+ builtin.TypeId.Array => @compileError("TODO auto eql for arrays"),
+ builtin.TypeId.Struct => @compileError("TODO auto eql for structs"),
+ builtin.TypeId.Union => @compileError("TODO auto eql for unions"),
+ builtin.TypeId.ErrorUnion => @compileError("TODO auto eql for unions"),
+ }
}
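
The removed hash_i32/eql_i32 helpers are superseded by getAutoHashFn/getAutoEqlFn, which is exactly how AutoHashMap is defined at the top of the patch, so callers that need custom hashing can still instantiate HashMap directly. A sketch of the equivalence under this patch (hashU32, eqlU32, ManualMap, and AutoMap are illustrative names, not part of the commit):

    // Hand-written hash/eql, the same shape as the helpers this patch removes.
    fn hashU32(x: u32) u32 {
        return x; // an identity hash is valid here, if not a great distribution
    }
    fn eqlU32(a: u32, b: u32) bool {
        return a == b;
    }

    const ManualMap = HashMap(u32, []const u8, hashU32, eqlU32);

    // For most key types this is what AutoHashMap now derives automatically:
    // getAutoHashFn mixes the key with constants drawn from a comptime
    // DefaultPrng seeded with 0, and getAutoEqlFn falls back to autoEql.
    const AutoMap = AutoHashMap(u32, []const u8);

Note that autoHash samples long slices (roughly every len/256-th element) rather than hashing every item, one of the things the "TODO improve these hash functions" comment already flags.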