【C++】 手撕哈希表:封装 unordered_set和unordered_map


1. 前言

在 STL 中,unordered_setunordered_map的底层都是 哈希表。它们的区别仅仅在于:

容器 存储单元 是否允许重复 键是否可变
unordered_set 单个值
unordered_map 键值对 键不可变,值可变

本文将展示:如何用同一份哈希表代码,通过模板参数差异,分别实现这两种容器。


2. 核心设计思想:模板参数萃取

这是 STL 最精妙的设计之一 👇

cpp 复制代码
template<
    class K,        // 关键码类型
    class T,        // 节点中真正存储的数据
    class KeyOfT,   // 从 T 中提取 K 的方法
    class Hash      // 哈希函数
>
class HashTable;
容器 T KeyOfT
set K 返回自身
map pair<const K, V> 返回 first

✅ 一套代码,两种形态


3. 底层哈希表回顾(简化版)

cpp 复制代码
template<class K, class T, class KeyOfT, class Hash>
class HashTable {
public:
    using Node = HashNode<T>;
    using Iterator = __HashIterator<K, T, KeyOfT, Hash>;

    pair<Iterator, bool> Insert(const T& data) { /* ... */ }
    Iterator Find(const K& key) { /* ... */ }
    bool Erase(const K& key) { /* ... */ }

private:
    vector<Node*> _tables;
    size_t _n = 0;
};

⚠️ 注意:哈希表只关心 T​ 和 如何从 T 取 K,不关心它是 set 还是 map。


4. 封装 unordered_set

4.1 设计要点

  • 存储的是 单个值

  • 值即键

  • 值不可修改

4.2 代码实现

cpp 复制代码
namespace wxx {

template<class K, class Hash = HashFunc<K>>
class unordered_set {
    struct SetKeyOfT {
        const K& operator()(const K& key) const {
            return key;
        }
    };

public:
    using iterator = typename HashTable<K, K, SetKeyOfT, Hash>::Iterator;

    pair<iterator, bool> insert(const K& key) {
        return _ht.Insert(key);
    }

    iterator find(const K& key) {
        return _ht.Find(key);
    }

    bool erase(const K& key) {
        return _ht.Erase(key);
    }

    iterator begin() { return _ht.Begin(); }
    iterator end() { return _ht.End(); }

private:
    HashTable<K, K, SetKeyOfT, Hash> _ht;
};

} // namespace bit

✅ set 的本质:一个"只有键"的哈希表


5. 封装 unordered_map

5.1 设计要点

  • 存储的是 pair<const K, V>

  • 键来自 pair.first

  • 键不可修改,值可修改

5.2 代码实现

cpp 复制代码
namespace wxx {

template<class K, class V, class Hash = HashFunc<K>>
class unordered_map {
    struct MapKeyOfT {
        const K& operator()(const pair<const K, V>& kv) const {
            return kv.first;
        }
    };

public:
    using iterator = typename HashTable<K, pair<const K, V>, MapKeyOfT, Hash>::Iterator;

    pair<iterator, bool> insert(const pair<const K, V>& kv) {
        return _ht.Insert(kv);
    }

    iterator find(const K& key) {
        return _ht.Find(key);
    }

    bool erase(const K& key) {
        return _ht.Erase(key);
    }

    V& operator[](const K& key) {
        auto ret = _ht.Insert({key, V()});
        return ret.first->second;
    }

    iterator begin() { return _ht.Begin(); }
    iterator end() { return _ht.End(); }

private:
    HashTable<K, pair<const K, V>, MapKeyOfT, Hash> _ht;
};

} // namespace wxx

✅ map 的本质:一个"键值对"的哈希表


6. set 与 map 的差异对比

对比项 unordered_set unordered_map
存储类型 K pair<const K, V>
KeyOfT 返回自身 返回 first
operator\[\]
值是否可改 ✅(仅 second)
底层哈希表 同一套 同一套

7. 完整代码

HashTable.h

cpp 复制代码
#pragma once
#include<vector>
#include<string>
#include<iostream>
using namespace std;
template<class K>
struct HashFunc
{
	size_t operator()(const K& key)
	{
		return (size_t)key;
	}
};


template<>
struct HashFunc<string>
{
	// BKDR
	size_t operator()(const string& str)
	{
		size_t hash = 0;
		for (auto ch : str)
		{
			hash += ch;
			hash *= 131;
		}

		return hash;
	}
};


inline unsigned long __stl_next_prime(unsigned long n)
{
	// Note: assumes long is at least 32 bits.
	static const int __stl_num_primes = 28;
	static const unsigned long __stl_prime_list[__stl_num_primes] =
	{
		53, 97, 193, 389, 769,
		1543, 3079, 6151, 12289, 24593,
		49157, 98317, 196613, 393241, 786433,
		1572869, 3145739, 6291469, 12582917, 25165843,
		50331653, 100663319, 201326611, 402653189, 805306457,
		1610612741, 3221225473, 4294967291
	};
	const unsigned long* first = __stl_prime_list;
	const unsigned long* last = __stl_prime_list + __stl_num_primes;
	const unsigned long* pos = lower_bound(first, last, n);
	return pos == last ? *(last - 1) : *pos;
}


template<class T>
struct HashNode
{
	T _data;
	HashNode<T>* _next;

	HashNode(const T& data)
		:_data(data)
		, _next(nullptr)
	{
	}
};

// 前置声明
template<class K, class T, class KeyOfT, class Hash>
class HashTable;

template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash>
struct HTIterator
{
	typedef HashNode<T> Node;
	typedef HashTable<K, T, KeyOfT, Hash> HT;
	typedef HTIterator<K, T, Ref, Ptr, KeyOfT, Hash> Self;

	Node* _node;
	const HT* _ht;

	HTIterator(Node* node, const HT* ht)
		:_node(node)
		, _ht(ht)
	{
	}

	Ref operator*()
	{
		return _node->_data;
	}

	Ptr operator->()
	{
		return &_node->_data;
	}

	Self& operator++()
	{
		if (_node->_next)  // 当前还有节点
		{
			_node = _node->_next;
		}
		else  // 当前桶为空,找下一个不为空的桶的第一个
		{
			size_t hashi = Hash()(KeyOfT()(_node->_data)) % _ht->_tables.size();
			++hashi;
			while (hashi != _ht->_tables.size())
			{
				if (_ht->_tables[hashi])
				{
					_node = _ht->_tables[hashi];
					break;
				}

				hashi++;
			}

			// 最后一个桶的最后一个节点已经遍历结束,走到end()去,nullptr充当end()
			if (hashi == _ht->_tables.size())
			{
				_node = nullptr;
			}
		}

		return *this;
	}

	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}

	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}
};


template<class K, class T, class KeyOfT, class Hash>
class HashTable
{
	// 友元声明
	template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash>
	friend struct HTIterator;

	typedef HashNode<T> Node;
public:
	typedef HTIterator<K, T, T&, T*, KeyOfT, Hash> Iterator;
	typedef HTIterator<K, T, const T&, const T*, KeyOfT, Hash> ConstIterator;

	Iterator Begin()
	{
		for (size_t i = 0; i < _tables.size(); i++)
		{
			if (_tables[i])
			{
				return Iterator(_tables[i], this);
			}
		}

		return End();
	}

	Iterator End()
	{
		return Iterator(nullptr, this);
	}

	ConstIterator Begin() const
	{
		for (size_t i = 0; i < _tables.size(); i++)
		{
			if (_tables[i])
			{
				return ConstIterator(_tables[i], this);
			}
		}

		return End();
	}

	ConstIterator End() const
	{
		return ConstIterator(nullptr, this);
	}

	HashTable()
		:_tables(__stl_next_prime(1), nullptr)
		, _n(0)
	{
	}

	~HashTable()
	{
		for (size_t i = 0; i < _tables.size(); i++)
		{
			Node* cur = _tables[i];
			// 当前桶的节点重新映射挂到新表
			while (cur)
			{
				Node* next = cur->_next;
				delete cur;
				cur = next;
			}

			_tables[i] = nullptr;
		}
	}

	pair<Iterator, bool> Insert(const T& data)
	{
		KeyOfT kot;
		auto it = Find(kot(data));
		if (it != End())
			return { it, false };

		Hash hs;
		// 负载因子==1扩容
		if (_n == _tables.size())
		{
			//HashTable<K, V> newHT;
			//newHT._tables.resize(_tables.size()*2);
			//// 遍历旧表将所有值映射到新表
			//for (auto cur : _tables)
			//{
			//	while (cur)
			//	{
			//		newHT.Insert(cur->_kv);
			//		cur = cur->_next;
			//	}
			//}
			//_tables.swap(newHT._tables);

			vector<Node*> newtables(__stl_next_prime(_tables.size() + 1));
			for (size_t i = 0; i < _tables.size(); i++)
			{
				Node* cur = _tables[i];
				// 当前桶的节点重新映射挂到新表
				while (cur)
				{
					Node* next = cur->_next;

					// 插入到新表
					size_t hashi = hs(kot(cur->_data)) % newtables.size();
					cur->_next = newtables[hashi];
					newtables[hashi] = cur;

					cur = next;
				}

				_tables[i] = nullptr;
			}

			_tables.swap(newtables);
		}

		size_t hashi = hs(kot(data)) % _tables.size();
		// 头插
		Node* newNode = new Node(data);
		newNode->_next = _tables[hashi];
		_tables[hashi] = newNode;

		++_n;
		return { Iterator(newNode, this), true };
	}

	Iterator Find(const K& key)
	{
		KeyOfT kot;
		Hash hs;
		size_t hashi = hs(key) % _tables.size();
		Node* cur = _tables[hashi];
		while (cur)
		{
			if (kot(cur->_data) == key)
				return { cur, this };

			cur = cur->_next;
		}

		return End();
	}

	bool Erase(const K& key)
	{
		KeyOfT kot;
		Hash hs;
		size_t hashi = hs(key) % _tables.size();
		Node* prev = nullptr;
		Node* cur = _tables[hashi];
		while (cur)
		{
			if (kot(cur->_data) == key)
			{
				if (prev == nullptr)
				{
					_tables[hashi] = cur->_next;
				}
				else
				{
					prev->_next = cur->_next;
				}

				delete cur;

				return true;
			}

			prev = cur;
			cur = cur->_next;
		}

		return false;
	}

private:
	//vector<list<pair<K, V>>> _tables;
	vector<Node*> _tables;
	size_t _n = 0;  // 实际存储的数据个数
};

Unordered_Set.h

cpp 复制代码
#include"HashTable.h"

namespace bit
{
	template<class K, class Hash = HashFunc<K>>
	class unordered_set
	{
		struct SetKeyOfT
		{
			const K& operator()(const K& key)
			{
				return key;
			}
		};
	public:
		typedef typename HashTable<K, const K, SetKeyOfT, Hash>::Iterator iterator;
		typedef typename HashTable<K, const K, SetKeyOfT, Hash>::ConstIterator const_iterator;

		iterator begin()
		{
			return _t.Begin();
		}

		iterator end()
		{
			return _t.End();
		}

		const_iterator begin() const
		{
			return _t.Begin();
		}

		const_iterator end() const
		{
			return _t.End();
		}

		pair<iterator, bool> insert(const K& k)
		{
			return _t.Insert(k);
		}

		bool erase(const K& key)
		{
			return _t.Erase(key);
		}

		iterator find(const K& key)
		{
			return _t.Find(key);
		}

	private:
		HashTable<K, const K, SetKeyOfT, Hash> _t;
	};

	void Func(const unordered_set<int>& s)
	{
		auto it1 = s.begin();
		while (it1 != s.end())
		{
			// *it1 = 1;

			cout << *it1 << " ";
			++it1;
		}
		cout << endl;
	}
	struct Date
	{
		int _year;
		int _month;
		int _day;

		bool operator==(const Date& d) const
		{
			return _year == d._year
				&& _month == d._month
				&& _day == d._day;
		}
	};

	struct DateHashFunc
	{
		// BKDR
		size_t operator()(const Date& d)
		{
			//2025 1 9
			//2025 9 1
			//2025 2 8
			size_t hash = 0;
			hash += d._year;
			hash *= 131;

			hash += d._month;
			hash *= 131;

			hash += d._day;
			hash *= 131;

			return hash;
		}
	};
};

Unordered_Map.h

cpp 复制代码
#include<vector>
#include<string>
using namespace std;
template<class K>
struct HashFunc
{
	size_t operator()(const K& key)
	{
		return (size_t)key;
	}
};


template<>
struct HashFunc<string>
{
	// BKDR
	size_t operator()(const string& str)
	{
		size_t hash = 0;
		for (auto ch : str)
		{
			hash += ch;
			hash *= 131;
		}

		return hash;
	}
};


inline unsigned long __stl_next_prime(unsigned long n)
{
	// Note: assumes long is at least 32 bits.
	static const int __stl_num_primes = 28;
	static const unsigned long __stl_prime_list[__stl_num_primes] =
	{
		53, 97, 193, 389, 769,
		1543, 3079, 6151, 12289, 24593,
		49157, 98317, 196613, 393241, 786433,
		1572869, 3145739, 6291469, 12582917, 25165843,
		50331653, 100663319, 201326611, 402653189, 805306457,
		1610612741, 3221225473, 4294967291
	};
	const unsigned long* first = __stl_prime_list;
	const unsigned long* last = __stl_prime_list + __stl_num_primes;
	const unsigned long* pos = lower_bound(first, last, n);
	return pos == last ? *(last - 1) : *pos;
}


template<class T>
struct HashNode
{
	T _data;
	HashNode<T>* _next;

	HashNode(const T& data)
		:_data(data)
		, _next(nullptr)
	{
	}
};

// 前置声明
template<class K, class T, class KeyOfT, class Hash>
class HashTable;

template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash>
struct HTIterator
{
	typedef HashNode<T> Node;
	typedef HashTable<K, T, KeyOfT, Hash> HT;
	typedef HTIterator<K, T, Ref, Ptr, KeyOfT, Hash> Self;

	Node* _node;
	const HT* _ht;

	HTIterator(Node* node, const HT* ht)
		:_node(node)
		, _ht(ht)
	{
	}

	Ref operator*()
	{
		return _node->_data;
	}

	Ptr operator->()
	{
		return &_node->_data;
	}

	Self& operator++()
	{
		if (_node->_next)  // 当前还有节点
		{
			_node = _node->_next;
		}
		else  // 当前桶为空,找下一个不为空的桶的第一个
		{
			size_t hashi = Hash()(KeyOfT()(_node->_data)) % _ht->_tables.size();
			++hashi;
			while (hashi != _ht->_tables.size())
			{
				if (_ht->_tables[hashi])
				{
					_node = _ht->_tables[hashi];
					break;
				}

				hashi++;
			}

			// 最后一个桶的最后一个节点已经遍历结束,走到end()去,nullptr充当end()
			if (hashi == _ht->_tables.size())
			{
				_node = nullptr;
			}
		}

		return *this;
	}

	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}

	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}
};


template<class K, class T, class KeyOfT, class Hash>
class HashTable
{
	// 友元声明
	template<class K, class T, class Ref, class Ptr, class KeyOfT, class Hash>
	friend struct HTIterator;

	typedef HashNode<T> Node;
public:
	typedef HTIterator<K, T, T&, T*, KeyOfT, Hash> Iterator;
	typedef HTIterator<K, T, const T&, const T*, KeyOfT, Hash> ConstIterator;

	Iterator Begin()
	{
		for (size_t i = 0; i < _tables.size(); i++)
		{
			if (_tables[i])
			{
				return Iterator(_tables[i], this);
			}
		}

		return End();
	}

	Iterator End()
	{
		return Iterator(nullptr, this);
	}

	ConstIterator Begin() const
	{
		for (size_t i = 0; i < _tables.size(); i++)
		{
			if (_tables[i])
			{
				return ConstIterator(_tables[i], this);
			}
		}

		return End();
	}

	ConstIterator End() const
	{
		return ConstIterator(nullptr, this);
	}

	HashTable()
		:_tables(__stl_next_prime(1), nullptr)
		, _n(0)
	{
	}

	~HashTable()
	{
		for (size_t i = 0; i < _tables.size(); i++)
		{
			Node* cur = _tables[i];
			// 当前桶的节点重新映射挂到新表
			while (cur)
			{
				Node* next = cur->_next;
				delete cur;
				cur = next;
			}

			_tables[i] = nullptr;
		}
	}

	pair<Iterator, bool> Insert(const T& data)
	{
		KeyOfT kot;
		auto it = Find(kot(data));
		if (it != End())
			return { it, false };

		Hash hs;
		// 负载因子==1扩容
		if (_n == _tables.size())
		{
			//HashTable<K, V> newHT;
			//newHT._tables.resize(_tables.size()*2);
			//// 遍历旧表将所有值映射到新表
			//for (auto cur : _tables)
			//{
			//	while (cur)
			//	{
			//		newHT.Insert(cur->_kv);
			//		cur = cur->_next;
			//	}
			//}
			//_tables.swap(newHT._tables);

			vector<Node*> newtables(__stl_next_prime(_tables.size() + 1));
			for (size_t i = 0; i < _tables.size(); i++)
			{
				Node* cur = _tables[i];
				// 当前桶的节点重新映射挂到新表
				while (cur)
				{
					Node* next = cur->_next;

					// 插入到新表
					size_t hashi = hs(kot(cur->_data)) % newtables.size();
					cur->_next = newtables[hashi];
					newtables[hashi] = cur;

					cur = next;
				}

				_tables[i] = nullptr;
			}

			_tables.swap(newtables);
		}

		size_t hashi = hs(kot(data)) % _tables.size();
		// 头插
		Node* newNode = new Node(data);
		newNode->_next = _tables[hashi];
		_tables[hashi] = newNode;

		++_n;
		return { Iterator(newNode, this), true };
	}

	Iterator Find(const K& key)
	{
		KeyOfT kot;
		Hash hs;
		size_t hashi = hs(key) % _tables.size();
		Node* cur = _tables[hashi];
		while (cur)
		{
			if (kot(cur->_data) == key)
				return { cur, this };

			cur = cur->_next;
		}

		return End();
	}

	bool Erase(const K& key)
	{
		KeyOfT kot;
		Hash hs;
		size_t hashi = hs(key) % _tables.size();
		Node* prev = nullptr;
		Node* cur = _tables[hashi];
		while (cur)
		{
			if (kot(cur->_data) == key)
			{
				if (prev == nullptr)
				{
					_tables[hashi] = cur->_next;
				}
				else
				{
					prev->_next = cur->_next;
				}

				delete cur;

				return true;
			}

			prev = cur;
			cur = cur->_next;
		}

		return false;
	}

private:
	//vector<list<pair<K, V>>> _tables;
	vector<Node*> _tables;
	size_t _n = 0;  // 实际存储的数据个数
};

相关推荐
Rookie Linux1 小时前
使用Qt6 QML以及第三方库FluentUI、PCapPlusPlus开发一个自定义抓包软件
网络·c++·qt·cmake·qml
江屿风2 小时前
C++图论基础拓扑排序算法流食般投喂
开发语言·c++·笔记·算法·排序算法
郝学胜-神的一滴2 小时前
Qt 高级开发 030:QListWidget 右键菜单全解,从策略配置到精准删除的优雅实现
开发语言·c++·qt·程序人生·用户界面
黄金龙PLUS2 小时前
基于ARX结构的新型序列密码算法FlashLight
算法·网络安全·密码学·哈希算法·同态加密
码上有光2 小时前
map与set的使用讲解
c++·set·map·平衡二叉搜索树·关联式容器
Irissgwe2 小时前
C++ STL unordered系列关联式容器详解
开发语言·c++·stl·关联式容器
fqbqrr10 小时前
2606C++,C++构的多态
开发语言·c++
小欣加油11 小时前
leetcode56 合并区间
c++·算法·leetcode·职场和发展
Yolo_TvT12 小时前
C++:析构函数
c++