C++起始之路——封装红黑树实现map和set

💁‍♂️个人主页:进击的荆棘

👇作者其它专栏:

《数据结构与算法》《算法》《C++起始之路》


目录

1.源码及框架分析

2.模拟实现map和set


1.源码及框架分析

SGI-STL30版本源代码,map和set的源代码在map/set/stl_map.h/stl_set.h/stl_tree.h等几个头文件中。

map和set的实现结构框架核心部分截取出来如下:

cpp 复制代码
//set
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_set.h>
#include <stl_multiset.h>

//map
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_map.h>
#include <stl_multimap.h>

//stl_set.h
template<class Key,class Compare=less<Key>,class Alloc=alloc>
class set{
public:
    //typedefs:
    typedef Key key_type;
    typedef Key value_type;
private:
    typedef rb_tree<key_type,value_type,
                    identity<value_type>,key_compare,Alooc> rep_type;
    rep_type t;    //red-black tree representing set
};

//stl_map.h
template<class Key,class T,class Compare=less<Key>,class Alloc=alloc>
class map{
public:
    //typedefs:
    typedef Key key_type;
    typedef T mappped_type;
    typedef pair<const Key,T> value_type;
private:
    typedef rb_tree<key_type,value_type,
                    selectlst<value_type>,key_compare,Alloc> rep_type;
    rep_type t;    //red-black tree representing map
};

//stl_tree.h
struct __rb_tree_node_base{
    typedef __rb_tree_color_type color_type;
    typedef __rb_tree_node_base* base_ptr;

    color_type color;
    base_ptr parent;
    base_ptr left;
    base_ptr right;
};

//stl_tree.h
template<class Key,class Value,class KeyOfValue,class Compare,class Alloc=alloc>
class rb_tree{
protected:
    typedef void* void_pointer;
    typedef __rb_tree_node_base* base_ptr;
    typedef __rb_tree_node<Value> rb_tree_node;
    typedef rb_tree_node* link_type;
    typedef Key key_type;
    typedef Value value_type;
public:
    //insert用的是第二个模板参数左形参
    pair<iterator,bool> insert_unique(const value_type& x);

    //erase和find用第一个模板参数做形参
    size_type erase(const key_type& x);
    iterator find(const key_type& x);
protected:
    size_type node_count;    //keep track of size of tree
    link_type header;
};

template<class Value>
struct __rb_tree_node:public __rb_tree_node_base{
    typedef __rb_tree_node<Value>* link_type;
    Value value_field;
};
    

●通过下图对框架的分析,可以看到源码中rb_tree用了一个巧妙的泛型思想实现,rb_tree是实现key的搜索场景,还是key/value的搜索场景不是直接写死的,而是由第二个模板参数Value决定_rb_tree_node中存储的数据类型。

●set实例化rb_tree时第二个模板参数给的是key,map实例化rb_tree时第二个模板参数给的是pair<const key,T>,这样一颗红黑树既可以实现key搜索场景的set,也可以实现key/value搜索场景的map。

●要注意,源码里面模板参数是用T代表value,而内部写的value_type不是我们日常key/value场景中说的value,源码中的value_type反而是红黑树节点中存储的真实的数据的类型。

●rb_tree第二个模板参数Value已经控制了红黑树节点中存储的数据类型,为什么还要传第一个模板参数Key呢?尤其是set,两个模板参数是一样的。要注意的是对于map和set,find/erase时的函数参数都是Key,所以第一个模板参数是传给find/erase等函数做形参的类型的。对于set而言两个参数是一样的,但是对于map而言就完全不一样l,map insert的pair对象,但是find和erase的是Key对象。

●这里的源码命名风格比较乱,set模板参数用的是Key命名,map用的是Key和T命名,而rb_tree用的又是Key和Value。

2.模拟实现map和set

2.1实现出复用红黑树的框架,并支持insert

●参考源码框架,map和set复用之前我们实现的红黑树。

●以下调整,key参数用K代替,value参数用V代替,红黑树中的数据类型,使用T。

●因为RBTree实现了泛型不知道T参数导致是K,还是pair<K,V>,那么insert内部进行插入逻辑比较时,就没办法进行比较,因为pair的默认支持的是key和value一起参入比较,需要时的任何时候只比较key,所以在map和set层分别实现一个MapKeyOfT和SetKeyOfT的仿函数传给RBTree的KeyOfT,然后RBTree中通过KeyOfT仿函数取出T类型对象中的key,再进行比较。

cpp 复制代码
namespace Achieve{
    template<class K>
    class set{
        struct SetKeyOfT{
            const K& operator()(const K& key){
                return key;
            }
        };
    public:
        bool insert(const K& key){
            return _t.Insert(key);
        }
    private:
        RBTree<K,K,SetKeyOfT> _t;
    };
}

namespace Achieve{
    template<class K,class V>
    class map{
        struct MapKeyOfT{
            const K& operator()(const pair<K,V>& key){
                return key.first;
            }
        };
    public:
        bool insert(const pair<K,V>& kv){
            return _t.Insert(kv);
        }
    private:
        RBTree<K,pair<K,V>,MapKeyOfT> _t;
    };
}

enum Color{
    RED,BLACK
};

template<class T>
struct RBTreeNode{
    //需要parent指针
    T _data;
    RBTreeNode<T>* _left;
    RBTreeNode<T>* _right;
    RBTreeNode<T>* _parent;
    //记录红黑
    Color _color;

    RBTreeNode(const T& data)
        :_data(data)
        ,_left(nullptr)
        ,_right(nullptr)
        ,_parent(nullptr)
    {}
};

//实现步骤:
//1.实现红黑树
//2.封装map和set框架,解决KeyOfT
//3.iterator
//4.const_iterator
//5.key不支持修改的问题
//6.operator[]

template<class k,class T,class KeyOfT>
class RBTree{
    typedef RBTreeNode<T> Node;
public:
    bool Insert(const T& data){
        if(!_root){
            _root=new Node(data);
            //根节点必须为黑色
            _root->_color=BLACK;
            return true;
        }
        KeyOfT kot;
        Node* parent=nullptr;
        Node* cur=_root;
        while(cur){
            if(kot(cur->_data)<kot(data)){
                parent=cur;
                cur=cur->_right;
            }
            else if(kot(cur->_data)>kot(data)){
                parent=cur;
                cur=cur->_left;
            }
            else return false;
        }
        //开始插入
        cur=new Node(data);
        //提前记录当前节点
        Node* newNode=cur;
        cur->_color=RED;
        if(kot(cur->_data)<kot(parent->_data))
            parent->_left=cur;
        else parent->_right=cur;
        //父指针指向父节点
        cur->_parent=parent;
        //当父节点存在,且与新插入节点构成连续红色时
        while(parent&&parent->_color==RED){
            Node* grandfather=parent->_parent;
            //若父节点在爷节点的左边
            if(parent==grandfather->_left){
                //   g
                //p      u
                Node* uncle=grandfather->_right;
                //若u存在且为红色,变色
                if(uncle&&uncle->_color==RED){
                    parent->_color=uncle->_color=BLACK;
                    grandfather->_color=RED;
                    //将c更新到g,继续操作
                    cur=grandfather;
                    parent=cur->_parent;
                }
                //此时u不存在或为黑色
                else{
                    //当插入的cur在parent的左边
                    if(cur==parent->_left){
                        //     g
                        //  p     u
                        //c
                        RotateR(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    //当插入的cur在parent的右边
                    else {
                        //     g
                        //  p     u
                        //    c
                        RotateL(parent);
                        RotateR(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    break;
                }
            }
            //当父节点在爷爷的右边
            else{
                //   g
                //u      p
                Node* uncle=grandfather->_left;
                //若u存在且为红色,变色
                if(uncle&&uncle->_color==RED){
                    parent->_color=uncle->_color=BLACK;
                    grandfather->_color=RED;
                    //将c更新到g,继续操作
                    cur=grandfather;
                    parent=cur->_parent;
                }
                //此时u不存在或为黑色
                else{
                    //当插入的cur在parent的右边
                    if(cur==parent->_right){
                        //     g
                        //  u     p
                        //           c
                        RotateL(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    //当插入的cur在parent的左边
                    else {
                        //     g
                        //  u     p
                        //      c
                        RotateR(parent);
                        RotateL(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    break;
                }
            }
        }
        //根节点必须为黑色
        _root->_color=BLACK;
        return true;
    }
    //右单旋 
    void RotateR(Node* parent){
        Node* subL=parent->_left;
        Node* subLR=subL->_right;
        
        parent->_left=subLR;
        //链接父节点
        if(subLR)
            subLR->_parent=parent;
        //防止找不到父结点的父结点
        Node* pparent=parent->_parent;
        subL->_right=parent;
        parent->_parent=subL;

        if(parent==_root){
            _root=subL;
            subL->_parent=nullptr;
        }
        else{
            if(pparent->_left==parent)
                pparent->_left=subL;
            else pparent->_right=subL;
            subL->_parent=pparent;
        }
    }
    //左单旋
    void RotateL(Node* parent){
        Node* subR=parent->_right;
        Node* subRL=subR->_left;
        
        parent->_right=subRL;
        //链接父节点
        if(subRL)
            subRL->_parent=parent;
        //防止找不到父结点的父结点
        Node* pparent=parent->_parent;
        subR->_left=parent;
        parent->_parent=subR;

        if(pparent==nullptr){
            _root=subR;
            subR->_parent=nullptr;
        }
        else{
            if(parent==pparent->_left)
                pparent->_left=subR;
            else pparent->_right=subR;
            subR->_parent=pparent;
        }
    }
private:
    Node* _root=nullptr;
};

2.2支持iterator的实现

2.2.1iterator核心源代码

cpp 复制代码
struct __rb_base_iterator{
    typedef __rb_tree_node_base::base_ptr base_ptr;
    base_ptr node;
    
    void increment(){
        if(node->right!=){
            node=node->right;
            while(node->left!=0)
                node=node->left;
        }
        else{
            base_ptr y=node->parent;
            while(node==y->right){
                node=y;
                y=y->parent;
            }
            if(node->right!=y)
                node=y;
        }
    }

    void decrement(){
        if(node->color==__rb_tree_red&&node->parent->parent==node)
            node=node->right;
        else if(node->left!=0){
            base_ptr y=node->left;
            while(y->right!=0)
                y=y->right;
            node=y;
        }
        else {
            base_ptr y=node->parent;
            while(node==y->left){
                node=y;
                y=y->parent;
            }
            node=y;
        }
    }
};

2.2.2iterator实现思路分析

●iterator实现的大框架跟list的iterator思路是一致的,用一个类型封装节点的指针,再通过重载运算符实现,迭代器像指针一样访问的行为。

●operator++和operator--的实现。map和set的迭代器走的是中序遍历,左子树->根节点->右子树,那么begin()会返回中序第一个节点的iterator也就是10所在节点的迭代器。

●迭代器的++核心逻辑就是不看全局,只看局部,只考虑当前中序局部要访问的下一个节点(左中右)。

●迭代器++时,若it指向的节点的右子树不为空,代表当前节点已经访问完了,要访问下一个节点是右子树的中序第一个,一棵树中序第一个是最左节点,所以直接找右子树的最左节点即可。

●迭代器++时,若it指向的节点的右子树为空,代表当前节点已经访问完了且当前节点所在的子树也访问完了,要访问的下一个节点在当前节点祖先里面,所以要沿当前节点到根的祖先路径向上找。

●若当前节点是父亲的左,根据中序左子树->根节点->右子树,则下一个访问的节点就是当前节点的父亲;若下图:it指向25,25右为空,25是30的左,所以下一个访问的节点就是30。

●若当前节点是父亲的右,根据中序左子树->根节点->右子树,当前节点所在的子树访问完了,当前节点所在父亲的子树也访问完了,那么下一个访问的需要继续往根的祖先中去找,直到找到孩子是父亲左的那个祖先就是中序要访问的下一个节点。如下图:it指向15,15右为空,15是10的右,15所在子树访问完了,10所在子树也访问完了,继续向上找,10是18的左,那么下一个访问的节点就是18.

●end()如何表示?如下图:当it指向50时,++it时,50是40的右,40是30的右,30是18的右,18到根没有父亲,没有找到孩子是父亲左的那个祖先,这是父亲为空了,那我们就把it中的节点指针置为nullptr,用nullptr去充当end。需要注意的是stl源码空时,红黑树增加了一个哨兵位头节点作为end(),这哨兵位头节点和根互为父亲,左指向最左节点,右指向最右节点。相比用nullptr作为end(),差别不大。只是--end()判断到节点时空,特殊处理以下,让迭代器节点指向最右节点。

2.3map支持[]

●map要支持[]主要修改insert返回值,修改RBTree中的Insert返回值为pair<Iterator,bool> Insert(const T& data)

2.4Achieve::map和Achieve::set代码实现

cpp 复制代码
namespace Achieve{
    template<class K>
    class set{
        struct SetKeyOfT{
            const K& operator()(const K& key){
                return key;
            }
        };
    public:
        //typename明确Iterator是一个类型
        typedef typename RBTree<K,const K,SetKeyOfT>::Iterator iterator;//Key不能修改,否则树会出错
        typedef typename RBTree<K,const K,SetKeyOfT>::ConstIterator const_iterator;
        iterator begin(){
            return _t.Begin();
        }
        iterator end(){
            return _t.End();
        }
        const_iterator begin() const{
            return _t.Begin();
        }
        const_iterator end() const{
            return _t.End();
        }
        pair<iterator,bool> insert(const K& key){
            return _t.Insert(key);
        }
    private:
        RBTree<K,const K,SetKeyOfT> _t;
    };
}

namespace Achieve{
    template<class K,class V>
    class map{
        struct MapKeyOfT{
            const K& operator()(const pair<K,V>& key){
                return key.first;
            }
        };
    public:
        //typename明确Iterator是一个类型
        typedef typename RBTree<K,pair<const K,V>,MapKeyOfT>::Iterator iterator;//Key不能修改,否则树会出错
        typedef typename RBTree<K,pair<const K,V>,MapKeyOfT>::ConstIterator const_iterator;
        iterator begin(){
            return _t.Begin();
        }
        iterator end(){
            return _t.End();
        }
        const_iterator begin() const{
            return _t.Begin();
        }
        const_iterator end() const{
            return _t.End();
        }
        pair<iterator,bool> insert(const pair<K,V>& kv){
            return _t.Insert(kv);
        }
        V& operator[](const K& key){
            pair<iterator,bool> ret=insert({key,V()});
            return ret.first->second;
        }
    private:
        RBTree<K,pair<const K,V>,MapKeyOfT> _t;
    };
}

enum Color{
    RED,BLACK
};
template<class T>
struct RBTreeNode{
    //需要parent指针
    T _data;
    RBTreeNode<T>* _left;
    RBTreeNode<T>* _right;
    RBTreeNode<T>* _parent;
    //记录红黑
    Color _color;

    RBTreeNode(const T& data)
        :_data(data)
        ,_left(nullptr)
        ,_right(nullptr)
        ,_parent(nullptr)
    {}
};
template<class T,class Ref,class Ptr>
struct RBIterator{
    typedef RBTreeNode<T> Node;
    typedef RBIterator<T,Ref,Ptr> Self;
    Node* _node;
    Node* _root;

    RBIterator(Node* node,Node* root)
        :_node(node)
        ,_root(root)
    {}
    Self operator++(){
        //中序遍历,下一个节点为右子树的最小节点
        if(_node->_right){
            Node* min=_node->_right;
            while(min->_left){
                min=min->_left;
            }
            _node=min;
        }
        else{
            //右为空,祖先里面孩子是父亲左的那个祖先
            Node* cur=_node;
            Node* parent=cur->_parent;
            while(parent&&cur==parent->_right){
                cur=parent;
                parent=parent->_parent;
            }
            _node=parent;
        }
        return *this;
    }
    Self operator--(){
        if(_node==nullptr){//--end()
            //--end(),特殊处理,走到中序最后一个节点,树的最右节点
            Node* rightMost=_root;
            while(rightMost&&rightMost->_right){
                rightMost=rightMost->_right;
            }
            _node=rightMost;
        }
        //中序遍历,上一个节点为左子树的最右节点
        if(_node->_left){
            Node* max=_node->_left;
            while(max->_right){
                max=max->_right;
            }
            _node=max;
        }
        else{
            //孩子是父亲右的那个祖先
            Node* cur=_node;
            Node* parent=cur->_parent;
            while(parent&&cur==parent->_left){
                cur=parent;
                parent=parent->_parent;
            }
            _node=parent;
        }
        return *this;
    }
    Ref operator*(){
        return _node->_data;
    }
    Ptr operator->(){
        return &_node->_data;
    }
    bool operator==(const Self& s) const{
        return _node==s._node;
    }
    bool operator!=(const Self& s) const{
        return _node!=s._node;
    }
};
template<class k,class T,class KeyOfT>
class RBTree{
    typedef RBTreeNode<T> Node;
public:
    typedef RBIterator<T,T&,T*> Iterator;
    typedef RBIterator<T,const T&,const T*> ConstIterator;
    Iterator Begin(){
        Node* cur=_root;
        while(cur&&cur->_left){
            cur=cur->_left;
        }
        return Iterator(cur,_root);
    }
    ConstIterator Begin() const{
        Node* cur=_root;
        while(cur&&cur->_left){
            cur=cur->_left;
        }
        return ConstIterator(cur,_root);
    }
    Iterator End(){
        return Iterator(nullptr,_root);
    }
    ConstIterator End() const{
        return ConstIterator(nullptr,_root);
    }
    RBTree()=default;
    ~RBTree(){
        Destroy(_root);
        _root=nullptr;
    }
    pair<Iterator,bool> Insert(const T& data){
        if(!_root){
            _root=new Node(data);
            //根节点必须为黑色
            _root->_color=BLACK;
            return {Iterator(_root,_root),true};
        }
        KeyOfT kot;
        Node* parent=nullptr;
        Node* cur=_root;
        while(cur){
            if(kot(cur->_data)<kot(data)){
                parent=cur;
                cur=cur->_right;
            }
            else if(kot(cur->_data)>kot(data)){
                parent=cur;
                cur=cur->_left;
            }
            else return {Iterator(cur,_root),false};
        }
        //开始插入
        cur=new Node(data);
        //提前记录当前节点
        Node* newNode=cur;
        cur->_color=RED;
        if(kot(cur->_data)<kot(parent->_data))
            parent->_left=cur;
        else parent->_right=cur;
        //父指针指向父节点
        cur->_parent=parent;
        //当父节点存在,且与新插入节点构成连续红色时
        while(parent&&parent->_color==RED){
            Node* grandfather=parent->_parent;
            //若父节点在爷节点的左边
            if(parent==grandfather->_left){
                //   g
                //p      u
                Node* uncle=grandfather->_right;
                //若u存在且为红色,变色
                if(uncle&&uncle->_color==RED){
                    parent->_color=uncle->_color=BLACK;
                    grandfather->_color=RED;
                    //将c更新到g,继续操作
                    cur=grandfather;
                    parent=cur->_parent;
                }
                //此时u不存在或为黑色
                else{
                    //当插入的cur在parent的左边
                    if(cur==parent->_left){
                        //     g
                        //  p     u
                        //c
                        RotateR(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    //当插入的cur在parent的右边
                    else {
                        //     g
                        //  p     u
                        //    c
                        RotateL(parent);
                        RotateR(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    break;
                }
            }
            //当父节点在爷爷的右边
            else{
                //   g
                //u      p
                Node* uncle=grandfather->_left;
                //若u存在且为红色,变色
                if(uncle&&uncle->_color==RED){
                    parent->_color=uncle->_color=BLACK;
                    grandfather->_color=RED;
                    //将c更新到g,继续操作
                    cur=grandfather;
                    parent=cur->_parent;
                }
                //此时u不存在或为黑色
                else{
                    //当插入的cur在parent的右边
                    if(cur==parent->_right){
                        //     g
                        //  u     p
                        //           c
                        RotateL(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    //当插入的cur在parent的左边
                    else {
                        //     g
                        //  u     p
                        //      c
                        RotateR(parent);
                        RotateL(grandfather);
                        parent->_color=BLACK;
                        grandfather->_color=RED;
                    }
                    break;
                }
            }
        }
        //根节点必须为黑色
        _root->_color=BLACK;
        return {Iterator(newNode,_root),true};
    }
    //右单旋 
    void RotateR(Node* parent){
        Node* subL=parent->_left;
        Node* subLR=subL->_right;
        
        parent->_left=subLR;
        //链接父节点
        if(subLR)
            subLR->_parent=parent;
        //防止找不到父结点的父结点
        Node* pparent=parent->_parent;
        subL->_right=parent;
        parent->_parent=subL;

        if(parent==_root){
            _root=subL;
            subL->_parent=nullptr;
        }
        else{
            if(pparent->_left==parent)
                pparent->_left=subL;
            else pparent->_right=subL;
            subL->_parent=pparent;
        }
    }
    //左单旋
    void RotateL(Node* parent){
        Node* subR=parent->_right;
        Node* subRL=subR->_left;
        
        parent->_right=subRL;
        //链接父节点
        if(subRL)
            subRL->_parent=parent;
        //防止找不到父结点的父结点
        Node* pparent=parent->_parent;
        subR->_left=parent;
        parent->_parent=subR;

        if(pparent==nullptr){
            _root=subR;
            subR->_parent=nullptr;
        }
        else{
            if(parent==pparent->_left)
                pparent->_left=subR;
            else pparent->_right=subR;
            subR->_parent=pparent;
        }
    }
	int Height()
	{
		return _Height(_root);
	}
	int Size()
	{
		return _Size(_root);
	}
    Node* Find(const k& key){
        Node* cur=_root;
        KeyOfT kot;
        while(cur){
            if(kot(cur->_data)<key){
                cur=cur->_right;
            }
            else if(kot(cur->_data>key)){
                cur=cur->_left;
            }
            else return cur;
        }
        return nullptr;
    }
private:
	int _Height(Node* root)
	{
		if (root == nullptr)
			return 0;
		int leftHeight = _Height(root->_left);
		int rightHeight = _Height(root->_right);
		return leftHeight > rightHeight ? leftHeight + 1 : rightHeight + 1;
	}

	int _Size(Node* root)
	{
		if (root == nullptr)
			return 0;

		return _Size(root->_left) + _Size(root->_right) + 1;
	}
    void Destroy(Node* root){
        if(!root)
            return ;
        Destory(root->_left);
        Destory(root->_right);
        delete root;
    }
private:
    Node* _root=nullptr;
};
相关推荐
云深麋鹿2 小时前
C++ | 模板
开发语言·c++
t***54410 小时前
Clang 编译器在 Orwell Dev-C++ 中的局限性
开发语言·c++
yolo_guo11 小时前
redis++使用: hmset 与 hmget
c++·redis
handler0112 小时前
拒绝权限报错!三分钟掌握 Linux 权限管理
linux·c语言·c++·笔记·学习
t***54413 小时前
如何在Dev-C++中选择Clang编译器
开发语言·c++
汉克老师13 小时前
GESP2023年9月认证C++三级( 第一部分选择题(9-15))
c++·gesp三级·gesp3级
Wave84516 小时前
C++继承详解
开发语言·c++·算法
Tairitsu_H17 小时前
C++类基础概念:定义、实例化和this指针
开发语言·c++
不想写代码的星星17 小时前
C++17 string_view 观察报告:好用,但有点费命
c++