C++:封装红黑树实现map和set

封装红黑树实现map和set

框架

cpp 复制代码
//RBTree.h
enum Colour
{
	BLACK,
	RED
};

template<class T>
struct RBTreeNode
{
	T _data;
	RBTreeNode<T>* _left;
	RBTreeNode<T>* _right;
	RBTreeNode<T>* _parent;
	Colour _col;

	RBTreeNode(const T& data)
		:_data(data)
		,_left(nullptr)
		,_right(nullptr)
		,_parent(nullptr)
		,_col(RED)
	{ }
};

template<class T,class Ref,class Ptr>
struct RBTreeIterator
{
	typedef RBTreeNode<T> Node;
	typedef RBTreeIterator<T, Ref, Ptr> Self;

	Node* _node;
	Node* _root;//用在实现 迭代器--

	RBTreeIterator(Node* node, Node* root)
		:_node(node)
		,_root(root)
	{ }
    
    Self operator++()
    {}
    Self operator--()
    {}
    
    Ref operator*()//解引用
    {}
    Ptr operator->()//取pair里的值
    {}
    
    //两个迭代器作比较
    bool operator!=(const Self& s) const
    {}
    bool operator==(const Self& s) const
    {}    
};

template<class K,class T,class KeyOfT>
class RBTree
{
	typedef RBTreeNode<T> Node;
public:
	typedef RBTreeIterator<T, T&, T*> Iterator;
	typedef RBTreeIterator<T, const T&, const T*> ConstIterator;

	RBTree() = default;
	~RBTree()
	{
		Destroy(_root);
		_root = nullptr;
	}
    
    Iterator Begin()
    {}
    Iterator End()
    {}
    ConstIterator Begin() const
    {}
    ConstIterator End() const
    {}
    
    pair<Iterator, bool> Insert(const T& data)
    {}
    
    Iterator Find(const K& key)
    {}
    int Height()
    {}
    int Size()
    {}
private:
    int _HeightTree(Node* root)
    {}
    int _Size(Node* root)
    {}
    void Destroy(Node* root)
    {}
    void RotateR(Node* parent)
    {}
    void RotateL(Node* parent)
    {}
private:
	Node* _root = nullptr;
};

//Myset.h
#include"RBTree.h"

namespace mine
{
	template<class K>
	class set
	{
		struct SetOfT
		{
			const K& operator()(const K& key)
			{
				return key;
			}
		};
	public:
		typedef typename RBTree<K, const K, SetOfT>::Iterator iterator;
		typedef typename RBTree<K, const K, SetOfT>::ConstIterator const_iterator;

		iterator begin()
		{
			return _t.Begin();
		}
		iterator end()
		{
			return _t.End();
		}
		const_iterator begin() const
		{
			return _t.Begin();
		}
		const_iterator end() const
		{
			return _t.End();
		}

		pair<iterator, bool> insert(const K& key)
		{
			return _t.Insert(key);
		}

		iterator find(const K& key)
		{
			return _t.Find(key);
		}

	private:
		RBTree<K,const K, SetOfT> _t;
	};
}
    

//Mymap.h
#include"RBTree.h"

namespace mine
{
	template<class K,class V>
	class map
	{
		struct MapOfT
		{
			const K& operator()(const pair<K, V>& kv)
			{
				return kv.first;
			}
		};
	public:
		typedef typename RBTree<K, pair<const K, V>, MapOfT>::Iterator iterator;
		typedef typename RBTree<K, pair<const K, V>, MapOfT>::ConstIterator const_iterator;

		iterator begin()
		{
			return _t.Begin();
		}
		iterator end()
		{
			return _t.End();
		}
		const_iterator begin() const
		{
			return _t.Begin();
		}
		const_iterator end() const
		{
			return _t.End();
		}

		pair<iterator, bool> insert(const pair<K, V>& kv)
		{
			return _t.Insert(kv);
		}

		V& operator[](const K& key) 
		{
			pair<iterator, bool> ret = insert({ key,V() });
			return ret.first->second;
		}

		iterator find(const K& key)
		{
			return _t.Find(key);
		}

	private:
		RBTree<K, pair<const K, V>, MapOfT> _t;
	};
}

key参数用K,value参数用V,红黑树中的数据类型使用T。

RBTree实现了泛型,但不知道T参数是K,还是pair<K, V>。所以我们在map和set层分别实现⼀个MapOfT和SetOfT的仿函数传给

RBTree的KeyOfT,然后RBTree中通过KeyOfT仿函数取出T类型对象中的key,这样才方便比较。

实现iterator

iterator实现的大框架跟list的iterator思路是⼀致的,用一个类型封装结点的指针,再通过重载运算符实现,使迭代器像指针⼀样访问的行为。

迭代器++

迭代器++的核心逻辑就是不看全局,只看局部,只考虑当前中序局部要访问的下⼀个结点。

若set或map里的key是{1,2,3,4,5,6},迭代器it指向的是key为3的节点,迭代器++,即要让it指向key为4的节点。因为set或map里的数据是有序的,所以迭代器++就是要根据底层的红黑树,找到大于当前迭代器指向的key。

迭代器++时,如果it指向的结点的右子树不为空 ,要访问下一个结点是右子树的中序第一个 ,⼀棵树中序第⼀个是最左结点,所以直接找右子树的最左结点即可。因为默认小的在左,大的在右,++it即在右子树中找比现在的it指向的节点的key大的,右子树所有节点的key都比现在it的key大,应该找右子树中最小的key,即右子树的最左节点。

迭代器++时,如果it指向的结点的右子树为空,代表当前结点 及 当前结点所在的子树也访问完了,要访问的下一个结点在当前结点的祖先中,所以要沿着当前结点到根的路径向上找。

迭代器++,it指向的节点右子树为空,如何正确找到下一个要访问的节点?从前面的框架 可以找到it有2个成员变量,其中一个就是指向当前节点的指针_node,让cur为it-> _node,cur指向节点的父亲节点为parent,不断向上回溯,当parent的左孩子为cur时,parent就是++it后it所在的位置。因为it指向节点的右为空说明该节点的key已经是某一子树里最大的key了,回溯分析如下图。

下面这种情况也是同理。

本质 还是二叉搜索树的"小于根的去左边,大于根的去右边",以及set和map要求有序

end()

用nullptr作为end()

源码中在红黑树增加了一个哨兵位作头节点,这个哨兵位作为end(),哨兵位头结点和根互为⽗亲,左指向最左结点,右指向最右结点。

迭代器--

迭代器it为end(),这时--it即找整棵树的最右节点。

--it与++it的思路正好相反,++it是按照中序"左->根->右",--it即"右->根->左" 。当it不是end()时,如果it->_node的左子树不为空,就在左子树中找最右节点

it->_node的左子树为空,让cur指向当前节点,记录cur的父亲节点parent,用cur回溯,直到cur为parent的右。

**总结:**set和map要求有序,所以迭代器++/--就是找比当前key大的或小的,根据二叉搜索树和中序的性质去找,++要么找右子树最左节点,要么回溯找满足parent的左是cur的parent节点;--要么找直接找整棵树的最右节点,要么在左子树中找最右节点,再要么回溯找满足parent的右是cur的parent节点。

实现map的[]

map中的operator[]是用来修改键值对中的value,还兼有插入查找功能。在RBTree中实现的Insert的返回值是pair<Iterator,bool>,Iterator中的_node指向插入节点,用ret接收这个返回值,ret.first即迭代器,ret.first->second即ret.first调用operator->,返回ret.first中 _node的 _data的地址即pair<K,V>的地址,ret.first->second本质是ret.first.operator->()->second(编译器会自动补->),最终得到value。

总的代码

cpp 复制代码
//RBTree.h
enum Colour
{
	BLACK,
	RED
};

template<class T>
struct RBTreeNode
{
	T _data;
	RBTreeNode<T>* _left;
	RBTreeNode<T>* _right;
	RBTreeNode<T>* _parent;
	Colour _col;

	RBTreeNode(const T& data)
		:_data(data)
		,_left(nullptr)
		,_right(nullptr)
		,_parent(nullptr)
		,_col(RED)
	{ }
};

template<class T,class Ref,class Ptr>
struct RBTreeIterator
{
	typedef RBTreeNode<T> Node;
	typedef RBTreeIterator<T, Ref, Ptr> Self;

	Node* _node;
	Node* _root;//用在实现 迭代器--

	RBTreeIterator(Node* node, Node* root)
		:_node(node)
		,_root(root)
	{ }

	Self operator++()
	{//中序:左 根 右
		if (_node->_right)//当前_node的右不为空,左、根已访问,
		{                                 //       |
			Node* min = _node->_right;    //     _node
			while (min->_left)            //    /     \
			                              //          node
			{	                          //        min
				min = min->_left;
			}
			_node = min;
		}
		else //当前_node右为空,用cur回溯找祖先,直到祖先的左为cur
		{
			Node* cur = _node;
			Node* parent = cur->_parent;
			while (parent && cur == parent->_right)
			{
				cur = parent;
				parent = cur->_parent;
			}
			_node = parent;
		}
		return *this;
	}
	Self operator--()
	{
		if (_node == nullptr)
		{
			//--End()
			Node* rightMost = _root;
			while (rightMost && rightMost->_right)
			{
				rightMost = rightMost->_right;
			}
			_node = rightMost;
		}
		else if (_node->_left)
		{
			//左子树不为空,左子树中序遍历的最后一个,即一直往右找
			Node* rightMost = _node->_left;
			while (rightMost->_right)
			{
				rightMost = rightMost->_right;
			}
			_node = rightMost;
		}
		else
		{
			//_node不为空且左子树为空,用cur回溯找祖先,直到cur为祖先的右
			Node* cur = _node;
			Node* parent = cur->_parent;
			while (parent && cur == parent->_left)
			{
				cur = parent;
				parent = cur->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	Ref operator*()//解引用
	{
		return _node->_data;
	}
	Ptr operator->()//取pair里的值
	{
		return &_node->_data;
	}

	//两个迭代器作比较
	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}
	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}
};

template<class K,class T,class KeyOfT>
class RBTree
{
	typedef RBTreeNode<T> Node;
public:
	typedef RBTreeIterator<T, T&, T*> Iterator;
	typedef RBTreeIterator<T, const T&, const T*> ConstIterator;

	RBTree() = default;
	~RBTree()
	{
		Destroy(_root);
		_root = nullptr;
	}

	Iterator Begin()
	{
		Node* cur = _root;
		while (cur && cur->_left)
		{
			cur = cur->_left;
		}
		return Iterator(cur, _root);
	}
	Iterator End()
	{
		return Iterator(nullptr, _root);
	}
	ConstIterator Begin() const
	{
		Node* cur = _root;
		while (cur && cur->_left)
		{
			cur = cur->_left;
		}
		return ConstIterator(cur, _root);
	}
	ConstIterator End() const
	{
		return ConstIterator(nullptr, _root);
	}

	pair<Iterator, bool> Insert(const T& data)
	{
		if (_root == nullptr)
		{
			_root = new Node(data);
			_root->_col = BLACK;
			return { Iterator(_root,_root),true };
		}

		KeyOfT kot;
		Node* cur = _root;
		Node* parent = nullptr;

		while (cur)
		{
			if (kot(cur->_data) > kot(data))
			{
				parent = cur;
				cur = cur->_left;
			}
			else if (kot(cur->_data) < kot(data))
			{
				parent = cur;
				cur = cur->_right;
			}
			else
				return { Iterator(cur,_root),false };
		}

		cur = new Node(data);
		Node* newnode = cur;
		cur->_col = RED;

		if (kot(parent->_data) > kot(cur->_data))//???
			parent->_left = cur;
		else
			parent->_right = cur;
		cur->_parent = parent;

		while (parent && parent->_col == RED)
		{
			Node* grandfather = parent->_parent;

			//    g
			//  p   u
 			if (parent == grandfather->_left)
			{
				Node* uncle = grandfather->_right;

				if (uncle && uncle->_col == RED)
				{
					parent->_col = uncle->_col = BLACK;
					grandfather->_col = RED;

					cur = grandfather;
					parent = cur->_parent;
				}
				else
				{   //     g
					//   p   u
					//  c
					if (cur == parent->_left)
					{
						RotateR(grandfather);
						parent->_col = BLACK;
						grandfather->_col = RED;
					}
					//    g
					//  p   u
					//    c
					else
					{
						RotateL(parent);
						RotateR(grandfather);
						grandfather->_col = RED;
						cur->_col = BLACK;
					}
					break;
				}
			}
			//  g
			//u   p
			else
			{
				Node* uncle = grandfather->_left;
				if (uncle && uncle->_col == RED)
				{
					parent->_col = uncle->_col = BLACK;
					grandfather->_col = RED;

					cur = grandfather;
					parent = cur->_parent;
				}
				else
				{   //    g
					//  u   p
					//       c
					if (cur == parent->_right)
					{
						RotateL(grandfather);
						parent->_col = BLACK;
						grandfather->_col = RED;
					}
					//    g
					//  u   p
					//     c
					else
					{
						RotateR(parent);
						RotateL(grandfather);
						cur->_col = BLACK;
						grandfather->_col = RED;
					}
					break;
				}
			}
		}
		_root->_col = BLACK;
		return { Iterator(newnode,_root),true };
	}

	Iterator Find(const K& key)
	{
		Node* cur = _root;
		KeyOfT kot;
		while (cur)
		{
			if (kot(cur->_data) < key)
				cur = cur->_right;
			else if (kot(cur->_data) > key)
				cur = cur->_left;
			else
				return Iterator(cur, _root);
		}
		return End();
	}
	int Height()
	{
		return _HeightTree(_root);
	}
	int Size()
	{
		return _Size(_root);
	}
private:
	int _HeightTree(Node* root)
	{
		if (root == nullptr)
			return 0;
		int leftHeight = _HeightTree(root->_left);
		int rightHeight = _HeightTree(root->_right);
		return leftHeight > rightHeight ? leftHeight + 1 : rightHeight + 1;
	}
	int _Size(Node* root)
	{
		if (root == nullptr)
			return 0;
		return _Size(root->_left) + _Size(root->_right) + 1;
	}

	void Destroy(Node* root)//后序删除
	{
		if (root == nullptr)
			return;
		Destroy(root->_left);
		Destroy(root->_right);
		delete root;
	}

	void RotateR(Node* parent)
	{
		Node* subL = parent->_left;
		if (subL == nullptr)
			return;
		Node* subLR = subL->_right;

		parent->_left = subLR;
		if (subLR)
			subLR->_parent = parent;

		Node* pParent = parent->_parent;
		subL->_right = parent;
		parent->_parent = subL;
		if (parent == _root)
		{
			subL->_parent = nullptr;
			_root = subL;
		}
		else
		{
			if (parent == pParent->_left)
				pParent->_left = subL;
			else
				pParent->_right = subL;
			subL->_parent = pParent;
		}
	}
	void RotateL(Node* parent)
	{
		Node* subR = parent->_right;
		if (subR == nullptr)
			return;
		Node* subRL = subR->_left;

		parent->_right = subRL;
		if (subRL)
			subRL->_parent = parent;

		Node* pParent = parent->_parent;
		subR->_left = parent;
		parent->_parent = subR;
		if (parent == _root)
		{
			subR->_parent = nullptr;
			_root = subR;
		}
		else
		{
			if (parent == pParent->_left)
				pParent->_left = subR;
			else
				pParent->_right = subR;
			subR->_parent = pParent;
		}
	}

private:
	Node* _root = nullptr;
};

//Myset.h
#include"RBTree.h"

namespace mine
{
	template<class K>
	class set
	{
		struct SetOfT
		{
			const K& operator()(const K& key)
			{
				return key;
			}
		};
	public:
		typedef typename RBTree<K, const K, SetOfT>::Iterator iterator;
		typedef typename RBTree<K, const K, SetOfT>::ConstIterator const_iterator;

		iterator begin()
		{
			return _t.Begin();
		}
		iterator end()
		{
			return _t.End();
		}
		const_iterator begin() const
		{
			return _t.Begin();
		}
		const_iterator end() const
		{
			return _t.End();
		}

		pair<iterator, bool> insert(const K& key)
		{
			return _t.Insert(key);
		}

		iterator find(const K& key)
		{
			return _t.Find(key);
		}

	private:
		RBTree<K,const K, SetOfT> _t;
	};
}

//Mymap.h
#include"RBTree.h"

namespace mine
{
	template<class K,class V>
	class map
	{
		struct MapOfT
		{
			const K& operator()(const pair<K, V>& kv)
			{
				return kv.first;
			}
		};
	public:
		typedef typename RBTree<K, pair<const K, V>, MapOfT>::Iterator iterator;
		typedef typename RBTree<K, pair<const K, V>, MapOfT>::ConstIterator const_iterator;

		iterator begin()
		{
			return _t.Begin();
		}
		iterator end()
		{
			return _t.End();
		}
		const_iterator begin() const
		{
			return _t.Begin();
		}
		const_iterator end() const
		{
			return _t.End();
		}

		pair<iterator, bool> insert(const pair<K, V>& kv)
		{
			return _t.Insert(kv);
		}

		V& operator[](const K& key) 
		{
			pair<iterator, bool> ret = insert({ key,V() });
			return ret.first->second;
		}

		iterator find(const K& key)
		{
			return _t.Find(key);
		}

	private:
		RBTree<K, pair<const K, V>, MapOfT> _t;
	};
}

测试

cpp 复制代码
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<vector>
using namespace std;//要放在set和map前,这样pair前无std::才是对的

#include"MyMap.h"
#include"Myset.h"

int main()
{
	mine::set<int> s;
	s.insert(5);
	s.insert(1);
	s.insert(3);
	s.insert(2);
	s.insert(6);
	mine::set<int>::iterator it = s.begin();
	while (it != s.end())
	{
		//*it += 10;
		cout << *it << " ";
		++it;
	}
	cout << endl;
	for (auto& e : s)
	{
		cout << e << " ";
	}
	cout << endl;
	auto It = s.end();
	while (It != s.begin())
	{
		--It;
		cout << *It << " ";
	}
	cout << endl;

	cout << endl;
	mine::map<string, string> dict;
	dict.insert({ "sort","排序" });
	dict.insert({ "left","左" });
	dict.insert({ "right","右" });
	auto mit = dict.begin();
	while (mit != dict.end())
	{
		/*mit->first += 'x';*/
		//mit->second += "xxx";
		cout << mit->first << ":" << mit->second << endl;
		++mit;
	}
	cout << endl;
	dict["left"] = "左,剩余";
	dict["insert"] = "插入";
	dict["string"];
	for (auto& e : dict)
	{
		cout << e.first << ":" << e.second << endl;
	}
	cout << endl;

	auto dit = dict.find("sort");
	cout << dit->first << ":" << dit->second << endl;

	return 0;
}
相关推荐
汉克老师2 小时前
GESP2024年9月认证C++二级( 第一部分选择题(1-8))
c++·算法·循环结构·分支结构·gesp二级·gesp2级
阿猿收手吧!2 小时前
【C++】syncstream:多线程安全输出的终极方案
c++·安全
xyq20242 小时前
Bootstrap5 滚动监听
开发语言
沐知全栈开发2 小时前
SQL PRIMARY KEY(主键)
开发语言
「QT(C++)开发工程师」2 小时前
Day 1:信号槽原理深入
c++·qt
元亓亓亓2 小时前
考研408--数据结构--day11--最小生成树&最短路径问题&拓扑排序&关键路径
数据结构·考研·最小生成树·拓扑排序·最短路径·关键路径
桂花很香,旭很美2 小时前
[7天实战入门Go语言后端] Day 0:预习——关键词、语法与常用包一览
开发语言·后端·golang
gfdgd xi2 小时前
GXDE OS 25.3.1 更新了!修复更多 bug 了!
linux·c++·操作系统·bug·deepin
Trouvaille ~2 小时前
【Linux】TCP vs UDP深度对比:如何选择与用UDP实现可靠传输
linux·网络·c++·tcp/ip·udp·操作系统