1. 前言
在 STL 中,
unordered_set和unordered_map的底层都是 哈希表。它们的区别仅仅在于:
容器 存储单元 是否允许重复 键是否可变 unordered_set 单个值 否 否 unordered_map 键值对 否 键不可变,值可变 本文将展示:如何用同一份哈希表代码,通过模板参数差异,分别实现这两种容器。
2. 核心设计思想:模板参数萃取
这是 STL 最精妙的设计之一 👇
cpptemplate< class K, // 关键码类型 class T, // 节点中真正存储的数据 class KeyOfT, // 从 T 中提取 K 的方法 class Hash // 哈希函数 > class HashTable;
容器 T KeyOfT set K 返回自身 map pair<const K, V> 返回 first ✅ 一套代码,两种形态
3. 底层哈希表回顾(简化版)
cpptemplate<class K, class T, class KeyOfT, class Hash> class HashTable { public: using Node = HashNode<T>; using Iterator = __HashIterator<K, T, KeyOfT, Hash>; pair<Iterator, bool> Insert(const T& data) { /* ... */ } Iterator Find(const K& key) { /* ... */ } bool Erase(const K& key) { /* ... */ } private: vector<Node*> _tables; size_t _n = 0; };⚠️ 注意:哈希表只关心 T 和 如何从 T 取 K,不关心它是 set 还是 map。
4. 封装 unordered_set
4.1 设计要点
存储的是 单个值
值即键
值不可修改
4.2 代码实现
cppnamespace wxx { template<class K, class Hash = HashFunc<K>> class unordered_set { struct SetKeyOfT { const K& operator()(const K& key) const { return key; } }; public: using iterator = typename HashTable<K, K, SetKeyOfT, Hash>::Iterator; pair<iterator, bool> insert(const K& key) { return _ht.Insert(key); } iterator find(const K& key) { return _ht.Find(key); } bool erase(const K& key) { return _ht.Erase(key); } iterator begin() { return _ht.Begin(); } iterator end() { return _ht.End(); } private: HashTable<K, K, SetKeyOfT, Hash> _ht; }; } // namespace bit✅ set 的本质:一个"只有键"的哈希表
5. 封装 unordered_map
5.1 设计要点
存储的是
pair<const K, V>键来自
pair.first键不可修改,值可修改
5.2 代码实现
cppnamespace wxx { template<class K, class V, class Hash = HashFunc<K>> class unordered_map { struct MapKeyOfT { const K& operator()(const pair<const K, V>& kv) const { return kv.first; } }; public: using iterator = typename HashTable<K, pair<const K, V>, MapKeyOfT, Hash>::Iterator; pair<iterator, bool> insert(const pair<const K, V>& kv) { return _ht.Insert(kv); } iterator find(const K& key) { return _ht.Find(key); } bool erase(const K& key) { return _ht.Erase(key); } V& operator[](const K& key) { auto ret = _ht.Insert({key, V()}); return ret.first->second; } iterator begin() { return _ht.Begin(); } iterator end() { return _ht.End(); } private: HashTable<K, pair<const K, V>, MapKeyOfT, Hash> _ht; }; } // namespace wxx✅ map 的本质:一个"键值对"的哈希表
6. set 与 map 的差异对比
对比项 unordered_set unordered_map 存储类型 K pair<const K, V> KeyOfT 返回自身 返回 first operator\[\] ❌ ✅ 值是否可改 ❌ ✅(仅 second) 底层哈希表 同一套 同一套
7. 完整代码
HashTable.h
cpp#pragma once #include<vector> #include<string> #include<iostream> using namespace std; template<class K> struct HashFunc { size_t operator()(const K& key) { return (size_t)key; } }; template<> struct HashFunc<string> { // BKDR size_t operator()(const string& str) { size_t hash = 0; for (auto ch : str) { hash += ch; hash *= 131; } return hash; } }; inline unsigned long __stl_next_prime(unsigned long n) { // Note: assumes long is at least 32 bits. static const int __stl_num_primes = 28; static const unsigned long __stl_prime_list[__stl_num_primes] = { 53, 97, 193, 389, 769, 1543, 3079, 6151, 12289, 24593, 49157, 98317, 196613, 393241, 786433, 1572869, 3145739, 6291469, 12582917, 25165843, 50331653, 100663319, 201326611, 402653189, 805306457, 1610612741, 3221225473, 4294967291 }; const unsigned long* first = __stl_prime_list; const unsigned long* last = __stl_prime_list + __stl_num_primes; const unsigned long* pos = lower_bound(first, last, n); return pos == last ? *(last - 1) : *pos; } template<class T> struct HashNode { T _data; HashNode<T>* _next; HashNode(const T& data) :_data(data) , _next(nullptr) { } }; // 前置声明 template<class K, class T, class KeyOfT, class Hash> class HashTable; template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash> struct HTIterator { typedef HashNode<T> Node; typedef HashTable<K, T, KeyOfT, Hash> HT; typedef HTIterator<K, T, Ref, Ptr, KeyOfT, Hash> Self; Node* _node; const HT* _ht; HTIterator(Node* node, const HT* ht) :_node(node) , _ht(ht) { } Ref operator*() { return _node->_data; } Ptr operator->() { return &_node->_data; } Self& operator++() { if (_node->_next) // 当前还有节点 { _node = _node->_next; } else // 当前桶为空,找下一个不为空的桶的第一个 { size_t hashi = Hash()(KeyOfT()(_node->_data)) % _ht->_tables.size(); ++hashi; while (hashi != _ht->_tables.size()) { if (_ht->_tables[hashi]) { _node = _ht->_tables[hashi]; break; } hashi++; } // 最后一个桶的最后一个节点已经遍历结束,走到end()去,nullptr充当end() if (hashi == _ht->_tables.size()) { _node = nullptr; } } return *this; } bool operator!=(const Self& s) const { return _node != s._node; } bool operator==(const Self& s) const { return _node == s._node; } }; template<class K, class T, class KeyOfT, class Hash> class HashTable { // 友元声明 template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash> friend struct HTIterator; typedef HashNode<T> Node; public: typedef HTIterator<K, T, T&, T*, KeyOfT, Hash> Iterator; typedef HTIterator<K, T, const T&, const T*, KeyOfT, Hash> ConstIterator; Iterator Begin() { for (size_t i = 0; i < _tables.size(); i++) { if (_tables[i]) { return Iterator(_tables[i], this); } } return End(); } Iterator End() { return Iterator(nullptr, this); } ConstIterator Begin() const { for (size_t i = 0; i < _tables.size(); i++) { if (_tables[i]) { return ConstIterator(_tables[i], this); } } return End(); } ConstIterator End() const { return ConstIterator(nullptr, this); } HashTable() :_tables(__stl_next_prime(1), nullptr) , _n(0) { } ~HashTable() { for (size_t i = 0; i < _tables.size(); i++) { Node* cur = _tables[i]; // 当前桶的节点重新映射挂到新表 while (cur) { Node* next = cur->_next; delete cur; cur = next; } _tables[i] = nullptr; } } pair<Iterator, bool> Insert(const T& data) { KeyOfT kot; auto it = Find(kot(data)); if (it != End()) return { it, false }; Hash hs; // 负载因子==1扩容 if (_n == _tables.size()) { //HashTable<K, V> newHT; //newHT._tables.resize(_tables.size()*2); //// 遍历旧表将所有值映射到新表 //for (auto cur : _tables) //{ // while (cur) // { // newHT.Insert(cur->_kv); // cur = cur->_next; // } //} //_tables.swap(newHT._tables); vector<Node*> newtables(__stl_next_prime(_tables.size() + 1)); for (size_t i = 0; i < _tables.size(); i++) { Node* cur = _tables[i]; // 当前桶的节点重新映射挂到新表 while (cur) { Node* next = cur->_next; // 插入到新表 size_t hashi = hs(kot(cur->_data)) % newtables.size(); cur->_next = newtables[hashi]; newtables[hashi] = cur; cur = next; } _tables[i] = nullptr; } _tables.swap(newtables); } size_t hashi = hs(kot(data)) % _tables.size(); // 头插 Node* newNode = new Node(data); newNode->_next = _tables[hashi]; _tables[hashi] = newNode; ++_n; return { Iterator(newNode, this), true }; } Iterator Find(const K& key) { KeyOfT kot; Hash hs; size_t hashi = hs(key) % _tables.size(); Node* cur = _tables[hashi]; while (cur) { if (kot(cur->_data) == key) return { cur, this }; cur = cur->_next; } return End(); } bool Erase(const K& key) { KeyOfT kot; Hash hs; size_t hashi = hs(key) % _tables.size(); Node* prev = nullptr; Node* cur = _tables[hashi]; while (cur) { if (kot(cur->_data) == key) { if (prev == nullptr) { _tables[hashi] = cur->_next; } else { prev->_next = cur->_next; } delete cur; return true; } prev = cur; cur = cur->_next; } return false; } private: //vector<list<pair<K, V>>> _tables; vector<Node*> _tables; size_t _n = 0; // 实际存储的数据个数 };
Unordered_Set.h
cpp#include"HashTable.h" namespace bit { template<class K, class Hash = HashFunc<K>> class unordered_set { struct SetKeyOfT { const K& operator()(const K& key) { return key; } }; public: typedef typename HashTable<K, const K, SetKeyOfT, Hash>::Iterator iterator; typedef typename HashTable<K, const K, SetKeyOfT, Hash>::ConstIterator const_iterator; iterator begin() { return _t.Begin(); } iterator end() { return _t.End(); } const_iterator begin() const { return _t.Begin(); } const_iterator end() const { return _t.End(); } pair<iterator, bool> insert(const K& k) { return _t.Insert(k); } bool erase(const K& key) { return _t.Erase(key); } iterator find(const K& key) { return _t.Find(key); } private: HashTable<K, const K, SetKeyOfT, Hash> _t; }; void Func(const unordered_set<int>& s) { auto it1 = s.begin(); while (it1 != s.end()) { // *it1 = 1; cout << *it1 << " "; ++it1; } cout << endl; } struct Date { int _year; int _month; int _day; bool operator==(const Date& d) const { return _year == d._year && _month == d._month && _day == d._day; } }; struct DateHashFunc { // BKDR size_t operator()(const Date& d) { //2025 1 9 //2025 9 1 //2025 2 8 size_t hash = 0; hash += d._year; hash *= 131; hash += d._month; hash *= 131; hash += d._day; hash *= 131; return hash; } }; };
Unordered_Map.h
cpp#include<vector> #include<string> using namespace std; template<class K> struct HashFunc { size_t operator()(const K& key) { return (size_t)key; } }; template<> struct HashFunc<string> { // BKDR size_t operator()(const string& str) { size_t hash = 0; for (auto ch : str) { hash += ch; hash *= 131; } return hash; } }; inline unsigned long __stl_next_prime(unsigned long n) { // Note: assumes long is at least 32 bits. static const int __stl_num_primes = 28; static const unsigned long __stl_prime_list[__stl_num_primes] = { 53, 97, 193, 389, 769, 1543, 3079, 6151, 12289, 24593, 49157, 98317, 196613, 393241, 786433, 1572869, 3145739, 6291469, 12582917, 25165843, 50331653, 100663319, 201326611, 402653189, 805306457, 1610612741, 3221225473, 4294967291 }; const unsigned long* first = __stl_prime_list; const unsigned long* last = __stl_prime_list + __stl_num_primes; const unsigned long* pos = lower_bound(first, last, n); return pos == last ? *(last - 1) : *pos; } template<class T> struct HashNode { T _data; HashNode<T>* _next; HashNode(const T& data) :_data(data) , _next(nullptr) { } }; // 前置声明 template<class K, class T, class KeyOfT, class Hash> class HashTable; template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash> struct HTIterator { typedef HashNode<T> Node; typedef HashTable<K, T, KeyOfT, Hash> HT; typedef HTIterator<K, T, Ref, Ptr, KeyOfT, Hash> Self; Node* _node; const HT* _ht; HTIterator(Node* node, const HT* ht) :_node(node) , _ht(ht) { } Ref operator*() { return _node->_data; } Ptr operator->() { return &_node->_data; } Self& operator++() { if (_node->_next) // 当前还有节点 { _node = _node->_next; } else // 当前桶为空,找下一个不为空的桶的第一个 { size_t hashi = Hash()(KeyOfT()(_node->_data)) % _ht->_tables.size(); ++hashi; while (hashi != _ht->_tables.size()) { if (_ht->_tables[hashi]) { _node = _ht->_tables[hashi]; break; } hashi++; } // 最后一个桶的最后一个节点已经遍历结束,走到end()去,nullptr充当end() if (hashi == _ht->_tables.size()) { _node = nullptr; } } return *this; } bool operator!=(const Self& s) const { return _node != s._node; } bool operator==(const Self& s) const { return _node == s._node; } }; template<class K, class T, class KeyOfT, class Hash> class HashTable { // 友元声明 template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash> friend struct HTIterator; typedef HashNode<T> Node; public: typedef HTIterator<K, T, T&, T*, KeyOfT, Hash> Iterator; typedef HTIterator<K, T, const T&, const T*, KeyOfT, Hash> ConstIterator; Iterator Begin() { for (size_t i = 0; i < _tables.size(); i++) { if (_tables[i]) { return Iterator(_tables[i], this); } } return End(); } Iterator End() { return Iterator(nullptr, this); } ConstIterator Begin() const { for (size_t i = 0; i < _tables.size(); i++) { if (_tables[i]) { return ConstIterator(_tables[i], this); } } return End(); } ConstIterator End() const { return ConstIterator(nullptr, this); } HashTable() :_tables(__stl_next_prime(1), nullptr) , _n(0) { } ~HashTable() { for (size_t i = 0; i < _tables.size(); i++) { Node* cur = _tables[i]; // 当前桶的节点重新映射挂到新表 while (cur) { Node* next = cur->_next; delete cur; cur = next; } _tables[i] = nullptr; } } pair<Iterator, bool> Insert(const T& data) { KeyOfT kot; auto it = Find(kot(data)); if (it != End()) return { it, false }; Hash hs; // 负载因子==1扩容 if (_n == _tables.size()) { //HashTable<K, V> newHT; //newHT._tables.resize(_tables.size()*2); //// 遍历旧表将所有值映射到新表 //for (auto cur : _tables) //{ // while (cur) // { // newHT.Insert(cur->_kv); // cur = cur->_next; // } //} //_tables.swap(newHT._tables); vector<Node*> newtables(__stl_next_prime(_tables.size() + 1)); for (size_t i = 0; i < _tables.size(); i++) { Node* cur = _tables[i]; // 当前桶的节点重新映射挂到新表 while (cur) { Node* next = cur->_next; // 插入到新表 size_t hashi = hs(kot(cur->_data)) % newtables.size(); cur->_next = newtables[hashi]; newtables[hashi] = cur; cur = next; } _tables[i] = nullptr; } _tables.swap(newtables); } size_t hashi = hs(kot(data)) % _tables.size(); // 头插 Node* newNode = new Node(data); newNode->_next = _tables[hashi]; _tables[hashi] = newNode; ++_n; return { Iterator(newNode, this), true }; } Iterator Find(const K& key) { KeyOfT kot; Hash hs; size_t hashi = hs(key) % _tables.size(); Node* cur = _tables[hashi]; while (cur) { if (kot(cur->_data) == key) return { cur, this }; cur = cur->_next; } return End(); } bool Erase(const K& key) { KeyOfT kot; Hash hs; size_t hashi = hs(key) % _tables.size(); Node* prev = nullptr; Node* cur = _tables[hashi]; while (cur) { if (kot(cur->_data) == key) { if (prev == nullptr) { _tables[hashi] = cur->_next; } else { prev->_next = cur->_next; } delete cur; return true; } prev = cur; cur = cur->_next; } return false; } private: //vector<list<pair<K, V>>> _tables; vector<Node*> _tables; size_t _n = 0; // 实际存储的数据个数 };

