【C++】红黑树模拟实现STL中的map与set

红黑树里面具体存的是什么类型的元素,是由模板参数 T 来决定

  • 如果 T 是 Key 那么就是 set。

  • 如果 T 是 pair<const Key, V>,那么就是 map。

1、定义红黑树的节点结构

cpp 复制代码
// 定义红黑颜色
enum Colour
{
	RED,
	BLACK
};

template<class T>
struct RBTreeNode
{
	RBTreeNode<T>* _left;
	RBTreeNode<T>* _right;
	RBTreeNode<T>* _parent;

    T _data; // 数据域
	Colour _col; // 用来标记节点颜色

	RBTreeNode(const T& data) // 构造函数
		: _left(nullptr)
		, _right(nullptr)
		, _parent(nullptr)
		, _data(data)
	{}
};

2、 定义红黑树结构

定义红黑树结构(KV)

  • K:键值 key 的类型。
  • T:数据的类型,如果是 map,则为 pair<const K, V>;如果是 set,则为 K。

【改造前】
cpp 复制代码
template<class K, class T>
class RBTree
{
	typedef RBTreeNode<T> Node; // 红黑树节点
public:
	RBTree() :_root(nullptr) {}  // 构造函数
    bool Insert(const T& data);  // 插入节点
    // ...
private:
	Node* _root;
};

红黑树的插入节点接口中要先通过比较节点中数据的大小来查找适合插入的位置,但是红黑树并不知道数据 data 到底是 key 还是 pair。如果数据 data 是 key,那么直接取 key 来作比较;如果数据 data 是 pair,则需要取 first 来作比较。

【思考】这该如何去实现对传进来的不同类型的数据都能进行比较呢?

STL 源码是这样实现的,通过传给模板参数 KeyOfValue 的是 set 的仿函数还是 map 的仿函数来应对不同类型数据的比较:


【改造后】

改造后的红黑树结构(增加了仿函数类):(还没完善好,还差迭代器,insert 和 operator[] 接口还没实现)

我们自己写的代码,通过给红黑树增加一个模板参数 KeyOfT,KeyOfT 是一个仿函数类,把 map 和 set 中实现的仿函数传给 KeyOfT,根据传的不同数据类型 T (key / pair) 和该类型对应的仿函数 (SetKey / MapFirst),调用仿函数取出要比较的值(key / first),来进行比较。
红黑树的定义

  • K:键值key的类型。
  • T:数据的类型,如果是 map,则为 pair<const K, V>;如果是 set,则为 K。
  • KeyOfT:通过 T 的类型来获取 key 值的一个仿函数类。
cpp 复制代码
template<class K, class T, class KeyOfT>
class RBTree
{
	typedef RBTreeNode<T> Node; // 红黑树节点
public:
	RBTree() :_root(nullptr) {}  // 构造函数
    bool Insert(const T& data);  // 插入节点(接口返回值目前是bool,后续要改为pair)
    // ...
private:
	Node* _root;
};

【画图说明】

通过 T 的类型和对应的取 T 类型对象的值的仿函数,就可以进行不同类型数据的比较了:

cpp 复制代码
#pragma once

enum Colour
{
	RED,
	BLACK
};

template<class T>
struct RBTreeNode
{
	RBTreeNode<T>* _left;
	RBTreeNode<T>* _right;
	RBTreeNode<T>* _parent;

	T _data;
	Colour _col;

	RBTreeNode(const T& data)
		:_left(nullptr)
		, _right(nullptr)
		, _parent(nullptr)
		, _data(data)
	{}
};

template<class T, class Ref, class Ptr>
struct __RBTreeIterator
{
	typedef RBTreeNode<T> Node;
	typedef __RBTreeIterator<T, Ref, Ptr> Self;
	Node* _node;

	__RBTreeIterator(Node* node)
		:_node(node)
	{}

	Ref operator*()
	{
		return _node->_data;
	}

	Ptr operator->()
	{
		return &_node->_data;
	}

	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}

	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}

	Self& operator++()
	{
		if (_node->_right)
		{
			// 下一个就是右子树的最左节点
			Node* left = _node->_right;
			while (left->_left)
			{
				left = left->_left;
			}
			_node = left;
		}
		else
		{
			// 找祖先里面孩子不是祖先的右的那个
			Node* parent = _node->_parent;
			Node* cur = _node;
			while (parent && cur == parent->_right)
			{
				cur = cur->_parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	Self& operator--()
	{
		if (_node->_left)
		{
			// 下一个是左子树的最右节点
			Node* right = _node->_left;
			while (right->_right)
			{
				right = right->_right;
			}
			_node = right;
		}
		else
		{
			// 孩子不是父亲的左的那个祖先
			Node* parent = _node->_parent;
			Node* cur = _node;
			while (parent && cur == parent->_left)
			{
				cur = cur->_parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this;
	}
};

template<class K, class T, class KeyOfT>
struct RBTree
{
	typedef RBTreeNode<T> Node;
public:
	typedef __RBTreeIterator<T, T&, T*> iterator;

	iterator begin()
	{
		Node* left = _root;
		while (left && left->_left)
		{
			left = left->_left;
		}
		return iterator(left);
	}

	iterator end()
	{
		return iterator(nullptr);
	}

	pair<iterator, bool> Insert(const T& data)
	{
		KeyOfT kot;

		if (_root == nullptr)
		{
			_root = new Node(data);
			_root->_col = BLACK;
			return make_pair(iterator(_root), true);
		}

		Node* parent = nullptr;
		Node* cur = _root;
		while (cur)
		{
			if (kot(cur->_data) < kot(data))
			{
				parent = cur;
				cur = cur->_right;
			}
			else if (kot(cur->_data) > kot(data))
			{
				parent = cur;
				cur = cur->_left;
			}
			else
			{
				return make_pair(iterator(cur), false);
			}
		}

		cur = new Node(data);
		Node* newnode = cur;
		cur->_col = RED;

		if (kot(parent->_data) < kot(data))
		{
			parent->_right = cur;
		}
		else
		{
			parent->_left = cur;
		}

		cur->_parent = parent;

		while (parent && parent->_col == RED)
		{
			Node* grandfater = parent->_parent;
			assert(grandfater);
			assert(grandfater->_col == BLACK);
			// 关键看叔叔
			if (parent == grandfater->_left)
			{
				Node* uncle = grandfater->_right;
				// 情况1:uncle存在且为红,变色+继续往上处理
				if (uncle && uncle->_col == RED)
				{
					parent->_col = uncle->_col = BLACK;
					grandfater->_col = RED;
					// 继续往上处理
					cur = grandfater;
					parent = cur->_parent;
				}// 情况2+3:uncle不存在+存在且为黑
				else
				{
					// 情况2:右单旋+变色
					//     g 
					//   p   u
					// c
					if (cur == parent->_left)
					{
						RotateR(grandfater);
						parent->_col = BLACK;
						grandfater->_col = RED;
					}
					else
					{
						// 情况3:左右单旋+变色
						//     g 
						//   p   u
						//     c
						RotateL(parent);
						RotateR(grandfater);
						cur->_col = BLACK;
						grandfater->_col = RED;
					}

					break;
				}
			}
			else // (parent == grandfater->_right)
			{
				Node* uncle = grandfater->_left;
				// 情况一
				if (uncle && uncle->_col == RED)
				{
					parent->_col = uncle->_col = BLACK;
					grandfater->_col = RED;
					// 继续往上处理
					cur = grandfater;
					parent = cur->_parent;
				}
				else
				{
					// 情况2:左单旋+变色
					//     g 
					//   u   p
					//         c
					if (cur == parent->_right)
					{
						RotateL(grandfater);
						parent->_col = BLACK;
						grandfater->_col = RED;
					}
					else
					{
						// 情况3:右左单旋+变色
						//     g 
						//   u   p
						//     c
						RotateR(parent);
						RotateL(grandfater);
						cur->_col = BLACK;
						grandfater->_col = RED;
					}
					break;
				}
			}
		}
		_root->_col = BLACK;
		return make_pair(iterator(newnode), true);
	}

	void InOrder()
	{
		_InOrder(_root);
		cout << endl;
	}

	bool IsBalance()
	{
		if (_root == nullptr)
		{
			return true;
		}

		if (_root->_col == RED)
		{
			cout << "根节点不是黑色" << endl;
			return false;
		}

		// 黑色节点数量基准值
		int benchmark = 0;
		/*Node* cur = _root;
		while (cur)
		{
		if (cur->_col == BLACK)
		++benchmark;

		cur = cur->_left;
		}*/
		return PrevCheck(_root, 0, benchmark);
	}

private:
	bool PrevCheck(Node* root, int blackNum, int& benchmark)
	{
		if (root == nullptr)
		{
			//cout << blackNum << endl;
			//return;
			if (benchmark == 0)
			{
				benchmark = blackNum;
				return true;
			}

			if (blackNum != benchmark)
			{
				cout << "某条黑色节点的数量不相等" << endl;
				return false;
			}
			else
			{
				return true;
			}
		}

		if (root->_col == BLACK)
		{
			++blackNum;
		}

		if (root->_col == RED && root->_parent->_col == RED)
		{
			cout << "存在连续的红色节点" << endl;
			return false;
		}

		return PrevCheck(root->_left, blackNum, benchmark)
			&& PrevCheck(root->_right, blackNum, benchmark);
	}

	void _InOrder(Node* root)
	{
		if (root == nullptr)
		{
			return;
		}
		_InOrder(root->_left);
		cout << root->_kv.first << ":" << root->_kv.second << endl;
		_InOrder(root->_right);
	}

	void RotateL(Node* parent)
	{
		Node* subR = parent->_right;
		Node* subRL = subR->_left;

		parent->_right = subRL;
		if (subRL)
		{
			subRL->_parent = parent;
		}

		Node* ppNode = parent->_parent;

		subR->_left = parent;
		parent->_parent = subR;

		if (_root == parent)
		{
			_root = subR;
			subR->_parent = nullptr;
		}
		else
		{
			if (ppNode->_left == parent)
			{
				ppNode->_left = subR;
			}
			else
			{
				ppNode->_right = subR;
			}

			subR->_parent = ppNode;
		}
	}

	void RotateR(Node* parent)
	{
		Node* subL = parent->_left;
		Node* subLR = subL->_right;

		parent->_left = subLR;
		if (subLR)
		{
			subLR->_parent = parent;
		}

		Node* ppNode = parent->_parent;

		subL->_right = parent;
		parent->_parent = subL;

		if (_root == parent)
		{
			_root = subL;
			subL->_parent = nullptr;
		}
		else
		{
			if (ppNode->_left == parent)
			{
				ppNode->_left = subL;
			}
			else
			{
				ppNode->_right = subL;
			}

			subL->_parent = ppNode;
		}
	}

private:
	Node* _root = nullptr;
};

一、红黑树的迭代器

迭代器的好处是可以方便遍历,是数据结构的底层实现与用户透明。map 和 set 的迭代器是封装的红黑树的迭代器。

如果想要给红黑树增加迭代器,需要考虑以前问题:

1、begin() 与 end()

STL 明确规定,begin() 与 end() 代表的是一段前闭后开的区间,而对红黑树进行中序遍历 后,可以得到一个有序的序列。

SGI-STL源码中,红黑树有一个哨兵位的头节点,begin() 是放在红黑树中最小节点(即最左侧节点)的位置,end() 是放在 end() 放在头结点的位置。

begin() 可以放在红黑树中最小节点(即最左侧节点)的位置,end() 放在最大节点(最右侧节点)的下一个位置,关键是最大节点的下一个位置在哪块?

因为这里我们只是对它进行模拟,理解它的底层原理即可,为了不让代码太过复杂,我们这里的模拟实现就不设定 header 结点,直接让 end() 为 nullptr即可。

cpp 复制代码
template<class K, class T, class KeyOfT>
struct RBTree
{
	typedef RBTreeNode<T> Node; // 红黑树节点
public:
	typedef _RBTreeIterator<T, T&, T*> iterator; // 迭代器
    
    iterator begin() // begin(): 指向红黑树的最左节点的迭代器
	{
		Node* left = _root;
		while (left && left->_left)
		{
			left = left->_left;
		}
		return iterator(left);
        // 注意:单参数的构造函数支持隐式类型转换,节点会被构造成迭代器
        // 所以也可以这样写:return left;
	}

	iterator end() // end(): 指向nullptr的迭代器
	{
		return iterator(nullptr);
	}
private:
	Node* _root = nullptr;
};

2、operator++()与operator--()

按照中序遍历(左子树 -- 根 -- 右子树) 来走,分为以下几种情况:

1、it 指向节点的右子树不为空

  • 则 it++ 要访问的节点是,右子树中序(左子树 -- 根 -- 右子树)的第一个节点,也就是右子树中的最左节点(即最大节点)。

2、it 指向节点的右子树为空(说明以 it 为根的子树已经访问完了),且 it 的父亲存在,it 是它父亲的右孩子(说明 it 被访问完之后,以 it 父亲为根的子树也就访问完了,此时该访问 it 父亲的父亲了)

  • 则 it++ 要访问的节点是:it 指向节点的父亲的父亲(即节点 13)。

3、it 指向节点的右子树为空,且 it 父亲存在,it 是它父亲的左孩子

  • 则 it++ 要访问的节点是:it 指向节点的父亲节点(即节点 17)。

【注意】

当 it 访问完最后一个节点后,最后一个节点右子树为空,此时整棵树已经访问完了,cur 和 parent 会一直迭代走到根节点,然后返回 _node = parent,parent为空,我们在红黑树中 end() 的值给的也是空,这样当 it 访问完最后一个节点后,就等于 end() 了。


T:数据的类型,如果是 map,则为 pair<const K, V>;如果是 set,则为 K。
  • Ref:数据的引用。
  • Ptr:数据的指针。
cpp 复制代码
template<class T, class Ref, class Ptr>
struct __RBTreeIterator
{
	typedef RBTreeNode<T> Node; // 红黑树节点
	typedef __RBTreeIterator<T, Ref, Ptr> Self; // 迭代器
	Node* _node; // 节点指针

    // 构造函数
	__RBTreeIterator(Node* node)
		:_node(node)
	{}

    // 运算符重载
	Ref operator*()
	{
		return _node->_data; // 返回当前迭代器指向节点中数据的引用
	}

	Ptr operator->()
	{
		return &_node->_data; // 返回当前迭代器指向节点中数据的地址
	}

    // 比较两个迭代器,即比较它们的节点指针,看是否指向同一节点
	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}

    // 比较两个迭代器,即比较它们的节点指针,看是否指向同一节点
	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}

	Self& operator++() // 前置++
	{
        // 按照中序来走,分为两种情况:
        // 1、当前节点右子树不为空,则++访问右子树的最大节点
		if (_node->_right != nullptr)
		{
			// 下一个就是右子树的最左节点
			Node* left = _node->_right;
			while (left->_left)
			{
				left = left->_left;
			}
			_node = left;
		}
        // 2、当前节点右子树为空
		else
		{
			// 找祖先里面孩子不是祖先的右的那个
			Node* parent = _node->_parent;
			Node* cur = _node;

            // (1)cur父亲存在且cur是父亲的右孩子,则++访问cur的父亲的父亲
            // (2)cur父亲存在且cur是父亲的左孩子,则++访问cur的父亲
			while (parent && cur == parent->_right)
			{
				cur = cur->_parent;
				parent = parent->_parent;
			}
			_node = parent; // 现在的parent就是我们要访问的位置
		}
		return *this; // 返回下一个节点的迭代器的引用
	}

    Self& operator--() // 前置--
	{
		if (_node->_left)
		{
			// 下一个是左子树的最右节点
			Node* right = _node->_left;
			while (right->_right)
			{
				right = right->_right;
			}
			_node = right;
		}
		else
		{
			// 孩子不是父亲的左的那个祖先
			Node* parent = _node->_parent;
			Node* cur = _node;
			while (parent && cur == parent->_left)
			{
				cur = cur->_parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this; // 返回上一个节点的迭代器的引用
	}
};

二、map和set的插入

map 和 set 的 insert 是封装的红黑树的插入节点接口,所以我们先要模拟实现红黑树的插入。

功能:插入元素时,先通过该元素的 key 查找并判断该元素是否已在树中:

  • 如果在,返回:pair<指向该元素的迭代器, false>。
  • 如果不在,先插入节点,再返回:pair<指向该元素的迭代器, true>。
  • T:数据的类型,如果是 map,则为 pair<const K, V>;如果是 set,则为 K。
cpp 复制代码
pair<iterator, bool> Insert(const T& data)
{
    //KeyOfT kot; // 实例化仿函数对象

    if (_root == nullptr)
    {
        _root = new Node(data); // 插入新节点
        _root->_col = BLACK; // 根节点为黑色

        return make_pair(iterator(_root), true); // 返回<指向插入节点的迭代器, true>
    }

    Node* parent = nullptr;
    Node* cur = _root; // 记录当前节点和它的父节点
    
    while (cur) // cur为空时,说明找到插入位置了
    {
        if (KeyOfT()(data) > KeyOfT()(cur->_data)) // 键值大于当前节点
        {
            parent = cur;
            cur = cur->_right;
        }
        else if (KeyOfT()(data) < KeyOfT()(cur->_data)) // 键值小于当前节点
        {
            parent = cur;
            cur = cur->_left;
        }
        else // (KeyOfT()(data) == KeyOfT()(cur->_data)) // 键值等于当前节点
        {
            // 不允许数据冗余
            return make_pair(iterator(cur), false); // 返回<指向已有节点的迭代器, false>
        }
    }

    // 插入新节点,颜色为红色(可能会破坏性质3,产生两个连续红色节点)
    cur = new Node(data);
    Node* newnode = cur; // 保存下插入的新节点的位置
    cur->_col = RED;

    // 判断新节点是其父亲的左孩子还是右孩子
    if (KeyOfT()(cur->_data) > KeyOfT()(parent->_data))
    {
        parent->_right = cur;
    }
    else
    {
        parent->_left = cur; 
    }
    cur->_parent = parent; // 更新cur的双亲指针

    while (parent && parent->_col == RED)
	{
		Node* grandfater = parent->_parent;
		assert(grandfater);
		assert(grandfater->_col == BLACK);
		// 关键看叔叔
		if (parent == grandfater->_left)
		{
			Node* uncle = grandfater->_right;
			// 情况1:uncle存在且为红,变色+继续往上处理
			if (uncle && uncle->_col == RED)
			{
				parent->_col = uncle->_col = BLACK;
				grandfater->_col = RED;
				// 继续往上处理
				cur = grandfater;
				parent = cur->_parent;
			}// 情况2+3:uncle不存在+存在且为黑
			else
			{
				// 情况2:右单旋+变色
				//     g 
				//   p   u
				// c
				if (cur == parent->_left)
				{
					RotateR(grandfater);
					parent->_col = BLACK;
					grandfater->_col = RED;
				}
				else
				{
					// 情况3:左右单旋+变色
					//     g 
					//   p   u
					//     c
					RotateL(parent);
					RotateR(grandfater);
					cur->_col = BLACK;
					grandfater->_col = RED;
				}

				break;
			}
		}
		else // (parent == grandfater->_right)
		{
			Node* uncle = grandfater->_left;
			// 情况一
			if (uncle && uncle->_col == RED)
			{
				parent->_col = uncle->_col = BLACK;
				grandfater->_col = RED;
				// 继续往上处理
				cur = grandfater;
				parent = cur->_parent;
			}
			else
			{
				// 情况2:左单旋+变色
				//     g 
				//   u   p
				//         c
				if (cur == parent->_right)
				{
					RotateL(grandfater);
					parent->_col = BLACK;
					grandfater->_col = RED;
				}
				else
				{
					// 情况3:右左单旋+变色
					//     g 
					//   u   p
					//     c
					RotateR(parent);
					RotateL(grandfater);
					cur->_col = BLACK;
					grandfater->_col = RED;
				}
				break;
			}
		}
	}
    _root->_col = BLACK;
    return make_pair(iterator(newnode), true); // 返回<指向插入节点的迭代器, true>
}

三、map 的模拟实现

map 的底层结构就是红黑树,因此在 map 中直接封装一棵红黑树,然后将其接口包装下即可。

cpp 复制代码
namespace xyl
{
    template<class K, class V>
    class map
    {
        // 作用:将value中的key提取出来
        struct MapKeyOfT
		{
			const K& operator()(const pair<K, V>& kv)
			{
				return kv.first; // 返回pair对象中的key
			}
		};

    public:
        // 编译到这里的时候,类模板RBTree<K, pair<const K, V>, MapFirst>可能还没有实例化
        // 那么编译器就不认识这个类模板,更别说去它里面找iterator了
		// 所以要加typename,告诉编译器这是个类型,等它实例化了再去找它
        typedef typename RBTree<K, pair<K, V>, MapKeyOfT>::iterator iterator;

        iterator begin()
		{
			return _t.begin();
		}

		iterator end()
		{
			return _t.end();
		}

        pair<iterator, bool> insert(const pair<K, V>& kv)
		{
			return _t.Insert(kv); // 底层调用红黑树的接口
		}

        // 功能:传入键值key,通过该元素的key查找并判断是否在map中:
		// 1、在map中,返回key对应的映射值的引用
		// 2、不在map中,插入pair<key, value()>,再返回key对应映射值的引用
        V& operator[](const K& key)
		{
            // 注意:这里的V()是缺省值,调用V类型的默认构造函数去构造一个匿名对象
			pair<iterator, bool> ret = insert(make_pair(key, V()));

			return ret.first->second; // 返回key对应映射值的引用
		}

    private:
        RBTree<K, pair<K, V>, MapKeyOfT> _t; // 红黑树
    };
}

四、set 的模拟实现

set 的底层为红黑树,因此只需在 set 内部封装一棵红黑树,即可将该容器实现出来(具体实现可参考 map)。

cpp 复制代码
namespace xyl
{
    template<class K>
    class set
    {
        // 作用是:将value中的key提取出来
        struct SetKeyOfT
		{
			const K& operator()(const K& key)
			{
				return key; 返回key对象中的Key
			}
		};

    public:
        // 编译到这里的时候,类模板RBTree<K, K, SetKey>可能还没有实例化
		// 那么编译器就不认识这个类模板,更别说去它里面找iterator了
		// 所以要加typename,告诉编译器这是个类型,等它实例化了再去找它
        typedef typename RBTree<K, K, SetKeyOfT>::iterator iterator; // 红黑树类型重命名

        iterator begin()
		{
			return _t.begin();
		}

		iterator end()
		{
			return _t.end();
		}

        // 插入元素(key)
		pair<iterator, bool> insert(const K& key)
		{
			return _t.Insert(key); // 底层调用红黑树的接口
		}

        // 中序遍历
		void inorder()
		{
			_t.InOrder();
		}

    private:
        RBTree<K, K, SetKeyOfT> _t; // 红黑树
    };
}
相关推荐
Lenyiin1 分钟前
第146场双周赛:统计符合条件长度为3的子数组数目、统计异或值为给定值的路径数目、判断网格图能否被切割成块、唯一中间众数子序列 Ⅰ
c++·算法·leetcode·周赛·lenyiin
yuanbenshidiaos2 小时前
c++---------数据类型
java·jvm·c++
十年一梦实验室2 小时前
【C++】sophus : sim_details.hpp 实现了矩阵函数 W、其导数,以及其逆 (十七)
开发语言·c++·线性代数·矩阵
taoyong0012 小时前
代码随想录算法训练营第十一天-239.滑动窗口最大值
c++·算法
这是我582 小时前
C++打小怪游戏
c++·其他·游戏·visual studio·小怪·大型·怪物
fpcc2 小时前
跟我学c++中级篇——C++中的缓存利用
c++·缓存
呆萌很3 小时前
C++ 集合 list 使用
c++
诚丞成4 小时前
计算世界之安生:C++继承的文水和智慧(上)
开发语言·c++
东风吹柳4 小时前
观察者模式(sigslot in C++)
c++·观察者模式·信号槽·sigslot
A懿轩A5 小时前
C/C++ 数据结构与算法【栈和队列】 栈+队列详细解析【日常学习,考研必备】带图+详细代码
c语言·数据结构·c++·学习·考研·算法·栈和队列