symbolic
hash_set.h
1 
10 #ifndef SYMBOLIC_UTILS_HASH_SET_H_
11 #define SYMBOLIC_UTILS_HASH_SET_H_
12 
13 #include <algorithm> // std::max, std::swap
14 #include <iostream> // TODO: remove
15 #include <vector> // std::vector
16 
17 #include "symbolic/utils/unique_vector.h"
18 
19 namespace symbolic {
20 
// Number of buckets a HashSet is constructed with. HashSet::LowerBound() also
// uses it as the smallest bucket count the set will shrink to.
21 constexpr int HASH_SET_INITIAL_SIZE = 1;
22 
26 template <typename T>
27 class HashSet {
28  public:
29  class iterator;
30  class const_iterator;
31  // using iterator = const_iterator;
32 
33  HashSet() : buckets_(HASH_SET_INITIAL_SIZE){};
34 
35  HashSet(std::initializer_list<T> l) : buckets_(HASH_SET_INITIAL_SIZE) {
36  for (const T& element : l) {
37  insert(element);
38  }
39  }
40 
41  iterator begin() {
42  iterator it(buckets_, 0, 0);
43  it.FindNextElement();
44  return it;
45  }
46  iterator end() { return iterator(buckets_, buckets_.size(), 0); }
47 
48  const_iterator begin() const {
49  const_iterator it(buckets_, 0, 0);
50  it.FindNextElement();
51  return it;
52  }
53  const_iterator end() const {
54  return const_iterator(buckets_, buckets_.size(), 0);
55  }
56  // iterator rend() const { return iterator(buckets_, -1, -1); }
57 
58  bool empty() const { return size() == 0; }
59  size_t size() const { return size_; }
60 
61  size_t bucket_count() const { return buckets_.size(); }
62 
63  template <typename T_query>
64  bool contains(const T_query& element) const {
65  return GetBucket(element).contains(element);
66  }
67 
68  template <typename T_query>
69  bool insert(const T_query& element) {
70  const bool inserted = GetBucket(element).insert(element);
71  if (inserted) {
72  size_++;
73  if (size() > buckets_.size()) Rehash(UpperBound());
74  }
75  return inserted;
76  }
77 
78  bool insert(T&& element) {
79  const bool inserted = GetBucket(element).insert(std::move(element));
80  if (inserted) {
81  size_++;
82  if (size() > buckets_.size()) Rehash(UpperBound());
83  }
84  return inserted;
85  }
86 
87  template <typename T_query>
88  bool erase(const T_query& element) {
89  const bool erased = GetBucket(element).erase(element);
90  if (erased) {
91  size_--;
92  if (size() <= LowerBound()) Rehash(LowerBound());
93  }
94  return erased;
95  }
96 
97  friend bool operator==(const HashSet<T>& lhs, const HashSet<T>& rhs) {
98  return lhs.buckets_ == rhs.buckets_;
99  }
100  friend bool operator!=(const HashSet<T>& lhs, const HashSet<T>& rhs) {
101  return !(lhs == rhs);
102  }
103 
104  friend bool operator<(const HashSet<T>& lhs, const HashSet<T>& rhs) {
105  return lhs.buckets_ < rhs.buckets_;
106  }
107 
108  private:
109  size_t UpperBound() const { return 2 * buckets_.size() + 1; }
110  size_t LowerBound() const {
111  return std::max(HASH_SET_INITIAL_SIZE, (static_cast<int>(buckets_.size()) - 1) / 2);
112  }
113 
114  template <typename T_query>
115  UniqueVector<T>& GetBucket(const T_query& element) {
116  const size_t idx_bucket = std::hash<T_query>{}(element) % buckets_.size();
117  return buckets_[idx_bucket];
118  }
119  template <typename T_query>
120  const UniqueVector<T>& GetBucket(const T_query& element) const {
121  const size_t idx_bucket = std::hash<T_query>{}(element) % buckets_.size();
122  return buckets_[idx_bucket];
123  }
124 
125  void Rehash(size_t num_buckets) {
126  if (num_buckets == buckets_.size()) return;
127 
128  // Create new buckets.
129  std::vector<UniqueVector<T>> old_buckets(num_buckets);
130  std::swap(buckets_, old_buckets);
131 
132  // Iterate over old buckets.
133  for (UniqueVector<T>& bucket : old_buckets) {
134  // Move elements from old bucket.
135  for (T& element : bucket) {
136  GetBucket(element).insert(std::move(element));
137  }
138  }
139  }
140 
141  std::vector<UniqueVector<T>> buckets_;
142  size_t size_ = 0;
143 
144  public:
146  public:
147  // Iterator traits
148  using iterator_category = std::bidirectional_iterator_tag;
149  using value_type = T;
150  using difference_type = ptrdiff_t;
151  using pointer = const T*;
152  using reference = const T&;
153 
154  // Constructor
155  const_iterator(const std::vector<UniqueVector<T>>& buckets,
156  const int idx_bucket, const int idx_in_bucket)
157  : buckets_(&buckets),
158  idx_bucket_(idx_bucket),
159  idx_in_bucket_(idx_in_bucket) {}
160 
161  // Forward iterator
162  const_iterator& operator++() {
163  idx_in_bucket_++;
164  FindNextElement();
165  return *this;
166  }
167 
168  const_iterator operator++(int) {
169  const_iterator it = *this;
170  operator++();
171  return it;
172  }
173 
174  reference operator*() const {
175  return (*buckets_)[idx_bucket_][idx_in_bucket_];
176  }
177 
178  pointer operator->() const {
179  return &(*buckets_)[idx_bucket_][idx_in_bucket_];
180  }
181 
182  bool operator==(const const_iterator& rhs) const {
183  return idx_bucket_ == rhs.idx_bucket_ &&
184  idx_in_bucket_ == rhs.idx_in_bucket_;
185  }
186 
187  bool operator!=(const const_iterator& rhs) const { return !(*this == rhs); }
188 
189  // Bidirectional iterator
190  const_iterator& operator--() {
191  idx_in_bucket_--;
192  FindPreviousElement();
193  return *this;
194  }
195 
196  const_iterator operator--(int) {
197  const_iterator it = *this;
198  operator--();
199  return it;
200  }
201 
202  protected:
203  friend HashSet<T>;
204 
205  void FindNextElement() {
206  // Find next occupied bucket.
207  if (idx_bucket_ >= buckets_->size()) {
208  idx_in_bucket_ = 0;
209  return;
210  }
211  const UniqueVector<T>* bucket = &(*buckets_)[idx_bucket_];
212  while (idx_in_bucket_ >= bucket->size()) {
213  idx_bucket_++;
214  idx_in_bucket_ = 0;
215  if (idx_bucket_ == buckets_->size()) return;
216  bucket = &(*buckets_)[idx_bucket_];
217  }
218  }
219  void FindPreviousElement() {
220  // Find previous occupied bucket.
221  while (idx_in_bucket_ < 0) {
222  idx_bucket_--;
223  if (idx_bucket_ < 0) return;
224  const UniqueVector<T>& bucket = (*buckets_)[idx_bucket_];
225  idx_in_bucket_ = bucket.size() - 1;
226  }
227  }
228 
229  const std::vector<UniqueVector<T>>* buckets_ = nullptr;
230  int idx_bucket_ = 0;
231  int idx_in_bucket_ = 0;
232  };
233  class iterator {
234  public:
235  // Iterator traits
236  using iterator_category = std::bidirectional_iterator_tag;
237  using value_type = T;
238  using difference_type = ptrdiff_t;
239  using pointer = T*;
240  using reference = T&;
241 
242  // Constructor
243  iterator(std::vector<UniqueVector<T>>& buckets, const int idx_bucket,
244  const int idx_in_bucket)
245  : buckets_(&buckets),
246  idx_bucket_(idx_bucket),
247  idx_in_bucket_(idx_in_bucket) {}
248 
249  // Forward iterator
250  iterator& operator++() {
251  idx_in_bucket_++;
252  FindNextElement();
253  return *this;
254  }
255 
256  iterator operator++(int) {
257  iterator it = *this;
258  operator++();
259  return it;
260  }
261 
262  reference operator*() const {
263  return (*buckets_)[idx_bucket_][idx_in_bucket_];
264  }
265 
266  pointer operator->() const {
267  return &(*buckets_)[idx_bucket_][idx_in_bucket_];
268  }
269 
270  bool operator==(const iterator& rhs) const {
271  return idx_bucket_ == rhs.idx_bucket_ &&
272  idx_in_bucket_ == rhs.idx_in_bucket_;
273  }
274 
275  bool operator!=(const iterator& rhs) const { return !(*this == rhs); }
276 
277  // Bidirectional iterator
278  iterator& operator--() {
279  idx_in_bucket_--;
280  FindPreviousElement();
281  return *this;
282  }
283 
284  iterator operator--(int) {
285  iterator it = *this;
286  operator--();
287  return it;
288  }
289 
290  protected:
291  friend HashSet<T>;
292 
293  void FindNextElement() {
294  // Find next occupied bucket.
295  if (idx_bucket_ >= buckets_->size()) {
296  idx_in_bucket_ = 0;
297  return;
298  }
299  const UniqueVector<T>* bucket = &(*buckets_)[idx_bucket_];
300  while (idx_in_bucket_ >= bucket->size()) {
301  idx_bucket_++;
302  idx_in_bucket_ = 0;
303  if (idx_bucket_ == buckets_->size()) return;
304  bucket = &(*buckets_)[idx_bucket_];
305  }
306  }
307 
308  void FindPreviousElement() {
309  // Find previous occupied bucket.
310  while (idx_in_bucket_ < 0) {
311  idx_bucket_--;
312  if (idx_bucket_ < 0) return;
313  const UniqueVector<T>& bucket = (*buckets_)[idx_bucket_];
314  idx_in_bucket_ = bucket.size() - 1;
315  }
316  }
317 
318  std::vector<UniqueVector<T>>* buckets_ = nullptr;
319  int idx_bucket_ = 0;
320  int idx_in_bucket_ = 0;
321  };
322 };
323 
324 } // namespace symbolic
325 
326 #endif // SYMBOLIC_UTILS_HASH_SET_H_
symbolic
Definition: action.cc:329
symbolic::HashSet::const_iterator
Definition: hash_set.h:145
symbolic::HashSet
Definition: hash_set.h:27
symbolic::HashSet::iterator
Definition: hash_set.h:233
symbolic::UniqueVector
Definition: unique_vector.h:23