aboutsummaryrefslogtreecommitdiff
path: root/src/hash_map.hpp
blob: 51ec352eda6bbed2883e3c9709a34fcb9489a3de (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
/*
 * Copyright (c) 2015 Andrew Kelley
 *
 * This file is part of zig, which is MIT licensed.
 * See http://opensource.org/licenses/MIT
 */

#ifndef ZIG_HASH_MAP_HPP
#define ZIG_HASH_MAP_HPP

#include "util.hpp"

#include <stdint.h>

template<typename K, typename V, uint32_t (*HashFunction)(K key), bool (*EqualFn)(K a, K b)>
class HashMap {
public:
    void init(int capacity) {
        init_capacity(capacity);
    }
    void deinit(void) {
        free(_entries);
    }

    struct Entry {
        bool used;
        int distance_from_start_index;
        K key;
        V value;
    };

    void clear() {
        for (int i = 0; i < _capacity; i += 1) {
            _entries[i].used = false;
        }
        _size = 0;
        _max_distance_from_start_index = 0;
        _modification_count += 1;
    }

    int size() const {
        return _size;
    }

    void put(const K &key, const V &value) {
        _modification_count += 1;
        internal_put(key, value);

        // if we get too full (60%), double the capacity
        if (_size * 5 >= _capacity * 3) {
            Entry *old_entries = _entries;
            int old_capacity = _capacity;
            init_capacity(_capacity * 2);
            // dump all of the old elements into the new table
            for (int i = 0; i < old_capacity; i += 1) {
                Entry *old_entry = &old_entries[i];
                if (old_entry->used)
                    internal_put(old_entry->key, old_entry->value);
            }
            free(old_entries);
        }
    }

    Entry *put_unique(const K &key, const V &value) {
        // TODO make this more efficient
        Entry *entry = internal_get(key);
        if (entry)
            return entry;
        put(key, value);
        return nullptr;
    }

    const V &get(const K &key) const {
        Entry *entry = internal_get(key);
        if (!entry)
            zig_panic("key not found");
        return entry->value;
    }

    Entry *maybe_get(const K &key) const {
        return internal_get(key);
    }

    void maybe_remove(const K &key) {
        if (maybe_get(key)) {
            remove(key);
        }
    }

    void remove(const K &key) {
        _modification_count += 1;
        int start_index = key_to_index(key);
        for (int roll_over = 0; roll_over <= _max_distance_from_start_index; roll_over += 1) {
            int index = (start_index + roll_over) % _capacity;
            Entry *entry = &_entries[index];

            if (!entry->used)
                zig_panic("key not found");

            if (!EqualFn(entry->key, key))
                continue;

            for (; roll_over < _capacity; roll_over += 1) {
                int next_index = (start_index + roll_over + 1) % _capacity;
                Entry *next_entry = &_entries[next_index];
                if (!next_entry->used || next_entry->distance_from_start_index == 0) {
                    entry->used = false;
                    _size -= 1;
                    return;
                }
                *entry = *next_entry;
                entry->distance_from_start_index -= 1;
                entry = next_entry;
            }
            zig_panic("shifting everything in the table");
        }
        zig_panic("key not found");
    }

    class Iterator {
    public:
        Entry *next() {
            if (_inital_modification_count != _table->_modification_count)
                zig_panic("concurrent modification");
            if (_count >= _table->size())
                return NULL;
            for (; _index < _table->_capacity; _index += 1) {
                Entry *entry = &_table->_entries[_index];
                if (entry->used) {
                    _index += 1;
                    _count += 1;
                    return entry;
                }
            }
            zig_panic("no next item");
        }
    private:
        const HashMap * _table;
        // how many items have we returned
        int _count = 0;
        // iterator through the entry array
        int _index = 0;
        // used to detect concurrent modification
        uint32_t _inital_modification_count;
        Iterator(const HashMap * table) :
                _table(table), _inital_modification_count(table->_modification_count) {
        }
        friend HashMap;
    };

    // you must not modify the underlying HashMap while this iterator is still in use
    Iterator entry_iterator() const {
        return Iterator(this);
    }

private:

    Entry *_entries;
    int _capacity;
    int _size;
    int _max_distance_from_start_index;
    // this is used to detect bugs where a hashtable is edited while an iterator is running.
    uint32_t _modification_count;

    void init_capacity(int capacity) {
        _capacity = capacity;
        _entries = allocate<Entry>(_capacity);
        _size = 0;
        _max_distance_from_start_index = 0;
        for (int i = 0; i < _capacity; i += 1) {
            _entries[i].used = false;
        }
    }

    void internal_put(K key, V value) {
        int start_index = key_to_index(key);
        for (int roll_over = 0, distance_from_start_index = 0;
                roll_over < _capacity; roll_over += 1, distance_from_start_index += 1)
        {
            int index = (start_index + roll_over) % _capacity;
            Entry *entry = &_entries[index];

            if (entry->used && !EqualFn(entry->key, key)) {
                if (entry->distance_from_start_index < distance_from_start_index) {
                    // robin hood to the rescue
                    Entry tmp = *entry;
                    if (distance_from_start_index > _max_distance_from_start_index)
                        _max_distance_from_start_index = distance_from_start_index;
                    *entry = {
                        true,
                        distance_from_start_index,
                        key,
                        value,
                    };
                    key = tmp.key;
                    value = tmp.value;
                    distance_from_start_index = tmp.distance_from_start_index;
                }
                continue;
            }

            if (!entry->used) {
                // adding an entry. otherwise overwriting old value with
                // same key
                _size += 1;
            }

            if (distance_from_start_index > _max_distance_from_start_index)
                _max_distance_from_start_index = distance_from_start_index;
            *entry = {
                true,
                distance_from_start_index,
                key,
                value,
            };
            return;
        }
        zig_panic("put into a full HashMap");
    }


    Entry *internal_get(const K &key) const {
        int start_index = key_to_index(key);
        for (int roll_over = 0; roll_over <= _max_distance_from_start_index; roll_over += 1) {
            int index = (start_index + roll_over) % _capacity;
            Entry *entry = &_entries[index];

            if (!entry->used)
                return NULL;

            if (EqualFn(entry->key, key))
                return entry;
        }
        return NULL;
    }

    int key_to_index(const K &key) const {
        return (int)(HashFunction(key) % ((uint32_t)_capacity));
    }
};

#endif