C++ 手撕 STL 底层:红黑树封装 mymap/myset

前言

在 C++ 开发中,mapset 几乎是必备的关联式容器。它们自动排序、键唯一、查找效率极高,而这一切强大特性的底层,都来自红黑树

但绝大多数同学只停留在 "会用" 的层面:

  • 不知道 map/set 底层为什么是红黑树
  • 不理解 STL 如何用一棵红黑树同时支持 K 模型(set)KV 模型(map)
  • 看不懂源码里的 KeyOfValueidentityselect1st 是什么
  • 不明白为什么 set 不能修改键,map 只能改 value 不能改 key
  • 不知道 operator[] 到底是怎么实现的

本文基于 SGI STL 源码思想,不依赖任何库,纯手工用 C++ 实现:

  • 通用红黑树(支持泛型、迭代器、旋转、变色)
  • 封装 myset(K 模型,只读迭代器)
    • 封装 mymap(KV 模型,支持 operator [])

一、底层思想:STL 为什么这样设计 map/set?

1.1 map 和 set 的本质区别

  • set:纯 key 模型,只存一个值,用来排序、去重、查找。
  • map:key-value 模型,存键值对,key 用来排序,value 用来存数据。

它们的查找、插入、删除逻辑完全一样 ,都是按 key 比较。所以 STL 并没有给 map 和 set 各写一棵树,而是复用同一棵红黑树

1.2 核心设计:泛型 + 仿函数取键

红黑树本身不知道自己存的是:

  • 一个 key(set)
  • 还是 pair<const K, V>(map)

所以 STL 用了一个仿函数 KeyOfT ,告诉红黑树:"你不用管存的是什么,调用我就能拿到 key 进行比较。"

这就是 STL 最经典的复用设计

  • set 传:SetKeyOfT,直接返回 key
  • map 传:MapKeyOfT,返回 pair.first

1.3 必须遵守的规则

  1. key 不可修改,否则整棵树结构报废。
  2. set 迭代器是 const 迭代器
  3. map 的 pair 第一个参数必须是 const K
  4. 迭代器遍历是 中序遍历,结果有序。

二、整体结构一览

我们要实现 3 个文件:

  1. RBTree.h ------ 通用红黑树(迭代器 + 插入 + 旋转 + 验证)
  2. MySet.h ------ 封装 set(K 模型,const 迭代器)
  3. MyMap.h ------ 封装 map(KV 模型,支持 [])

外部使用和 STL 完全一致:

cpp 复制代码
bit::set<int> s;
bit::map<string, int> m;

三、第一步:实现通用红黑树 RBTree.h

这是最核心的底层结构,支持:

  • 泛型存储任意类型 T
  • 通过仿函数 KeyOfT 获取 key
  • 迭代器(++/--)
  • 左旋、右旋
  • 红黑树插入、变色、调整
  • 返回迭代器的 Insert 接口

3.1 颜色枚举

cpp 复制代码
#pragma once

enum Colour
{
    RED,
    BLACK
};

3.2 红黑树结点

cpp 复制代码
template<class T>
struct RBTreeNode
{
    T _data;
    RBTreeNode<T>* _left;
    RBTreeNode<T>* _right;
    RBTreeNode<T>* _parent;
    Colour _col;

    RBTreeNode(const T& data)
        : _data(data)
        , _left(nullptr)
        , _right(nullptr)
        , _parent(nullptr)
        , _col(RED)
    {}
};

3.3 红黑树迭代器(最难点)

迭代器本质是中序遍历

  • ++:找中序下一个结点
  • --:找中序上一个结点
  • end() 用 nullptr 表示
cpp 复制代码
template<class T, class Ref, class Ptr>
struct RBTreeIterator
{
    typedef RBTreeNode<T> Node;
    typedef RBTreeIterator<T, Ref, Ptr> Self;

    Node* _node;
    Node* _root;

    RBTreeIterator(Node* node, Node* root)
        :_node(node)
        ,_root(root)
    {}

    // ++:中序下一个
    Self& operator++()
    {
        if (_node->_right)
        {
            Node* leftMost = _node->_right;
            while (leftMost->_left)
                leftMost = leftMost->_left;

            _node = leftMost;
        }
        else
        {
            Node* cur = _node;
            Node* parent = cur->_parent;
            while (parent && cur == parent->_right)
            {
                cur = parent;
                parent = cur->_parent;
            }
            _node = parent;
        }
        return *this;
    }

    // --:中序上一个
    Self& operator--()
    {
        if (_node == nullptr)
        {
            Node* rightMost = _root;
            while (rightMost && rightMost->_right)
                rightMost = rightMost->_right;

            _node = rightMost;
        }
        else if (_node->_left)
        {
            Node* rightMost = _node->_left;
            while (rightMost->_right)
                rightMost = rightMost->_right;

            _node = rightMost;
        }
        else
        {
            Node* cur = _node;
            Node* parent = cur->_parent;
            while (parent && cur == parent->_left)
            {
                cur = parent;
                parent = cur->_parent;
            }
            _node = parent;
        }
        return *this;
    }

    Ref operator*()
    {
        return _node->_data;
    }

    Ptr operator->()
    {
        return &_node->_data;
    }

    bool operator!= (const Self& s) const
    {
        return _node != s._node;
    }

    bool operator== (const Self& s) const
    {
        return _node == s._node;
    }
};

3.4 完整红黑树(含旋转 + 插入)

cpp 复制代码
template<class K, class T, class KeyOfT>
class RBTree
{
    typedef RBTreeNode<T> Node;
public:
    typedef RBTreeIterator<T, T&, T*> Iterator;
    typedef RBTreeIterator<T, const T&, const T*> ConstIterator;

    Iterator Begin()
    {
        Node* leftMost = _root;
        while (leftMost && leftMost->_left)
            leftMost = leftMost->_left;

        return Iterator(leftMost, _root);
    }

    Iterator End()
    {
        return Iterator(nullptr, _root);
    }

    ConstIterator Begin() const
    {
        Node* leftMost = _root;
        while (leftMost && leftMost->_left)
            leftMost = leftMost->_left;

        return ConstIterator(leftMost, _root);
    }

    ConstIterator End() const
    {
        return ConstIterator(nullptr, _root);
    }

    pair<Iterator, bool> Insert(const T& data)
    {
        if (_root == nullptr)
        {
            _root = new Node(data);
            _root->_col = BLACK;
            return make_pair(Iterator(_root, _root), true);
        }

        KeyOfT kot;
        Node* parent = nullptr;
        Node* cur = _root;

        while (cur)
        {
            if (kot(cur->_data) < kot(data))
            {
                parent = cur;
                cur = cur->_right;
            }
            else if (kot(cur->_data) > kot(data))
            {
                parent = cur;
                cur = cur->_left;
            }
            else
            {
                return make_pair(Iterator(cur, _root), false);
            }
        }

        cur = new Node(data);
        Node* newnode = cur;
        cur->_col = RED;

        if (kot(parent->_data) < kot(data))
            parent->_right = cur;
        else
            parent->_left = cur;

        cur->_parent = parent;

        // 红黑树调整
        while (parent && parent->_col == RED)
        {
            Node* grandfather = parent->_parent;
            if (parent == grandfather->_left)
            {
                Node* uncle = grandfather->_right;
                if (uncle && uncle->_col == RED)
                {
                    parent->_col = BLACK;
                    uncle->_col = BLACK;
                    grandfather->_col = RED;

                    cur = grandfather;
                    parent = cur->_parent;
                }
                else
                {
                    if (cur == parent->_left)
                    {
                        RotateR(grandfather);
                        parent->_col = BLACK;
                        grandfather->_col = RED;
                    }
                    else
                    {
                        RotateL(parent);
                        RotateR(grandfather);
                        cur->_col = BLACK;
                        grandfather->_col = RED;
                    }
                    break;
                }
            }
            else
            {
                Node* uncle = grandfather->_left;
                if (uncle && uncle->_col == RED)
                {
                    parent->_col = BLACK;
                    uncle->_col = BLACK;
                    grandfather->_col = RED;

                    cur = grandfather;
                    parent = cur->_parent;
                }
                else
                {
                    if (cur == parent->_right)
                    {
                        RotateL(grandfather);
                        parent->_col = BLACK;
                        grandfather->_col = RED;
                    }
                    else
                    {
                        RotateR(parent);
                        RotateL(grandfather);
                        cur->_col = BLACK;
                        grandfather->_col = RED;
                    }
                    break;
                }
            }
        }

        _root->_col = BLACK;
        return make_pair(Iterator(newnode, _root), true);
    }

    Iterator Find(const K& key)
    {
        KeyOfT kot;
        Node* cur = _root;
        while (cur)
        {
            if (kot(cur->_data) < key)
                cur = cur->_right;
            else if (kot(cur->_data) > key)
                cur = cur->_left;
            else
                return Iterator(cur, _root);
        }
        return End();
    }

private:
    // 左旋
    void RotateL(Node* parent)
    {
        Node* subR = parent->_right;
        Node* subRL = subR->_left;

        parent->_right = subRL;
        if (subRL)
            subRL->_parent = parent;

        Node* pp = parent->_parent;

        subR->_left = parent;
        parent->_parent = subR;

        if (pp == nullptr)
        {
            _root = subR;
            subR->_parent = nullptr;
        }
        else
        {
            if (pp->_left == parent)
                pp->_left = subR;
            else
                pp->_right = subR;

            subR->_parent = pp;
        }
    }

    // 右旋
    void RotateR(Node* parent)
    {
        Node* subL = parent->_left;
        Node* subLR = subL->_right;

        parent->_left = subLR;
        if (subLR)
            subLR->_parent = parent;

        Node* pp = parent->_parent;

        subL->_right = parent;
        parent->_parent = subL;

        if (pp == nullptr)
        {
            _root = subL;
            subL->_parent = nullptr;
        }
        else
        {
            if (pp->_left == parent)
                pp->_left = subL;
            else
                pp->_right = subL;

            subL->_parent = pp;
        }
    }

private:
    Node* _root = nullptr;
};

四、第二步:封装 MySet.h

set 特点:

  • K 模型
  • 键唯一
  • 迭代器不可修改(const)
  • 底层存储:RBTree<K, const K, SetKeyOfT>
cpp 复制代码
#pragma once
#include "RBTree.h"

namespace bit
{
    template<class K>
    class set
    {
        struct SetKeyOfT
        {
            const K& operator()(const K& key)
            {
                return key;
            }
        };

    public:
        typedef typename RBTree<K, const K, SetKeyOfT>::ConstIterator iterator;
        typedef typename RBTree<K, const K, SetKeyOfT>::ConstIterator const_iterator;

        iterator begin()
        {
            return _t.Begin();
        }

        iterator end()
        {
            return _t.End();
        }

        const_iterator begin() const
        {
            return _t.Begin();
        }

        const_iterator end() const
        {
            return _t.End();
        }

        pair<iterator, bool> insert(const K& key)
        {
            auto ret = _t.Insert(key);
            return make_pair(ret.first, ret.second);
        }

        iterator find(const K& key)
        {
            return _t.Find(key);
        }

    private:
        RBTree<K, const K, SetKeyOfT> _t;
    };
}

五、第三步:封装 MyMap.h

map 特点:

  • KV 模型
  • key 不可修改,value 可修改
  • 支持 operator[]
  • 底层存储:RBTree<K, pair<const K, V>, MapKeyOfT>
cpp 复制代码
#pragma once
#include "RBTree.h"

namespace bit
{
    template<class K, class V>
    class map
    {
        struct MapKeyOfT
        {
            const K& operator()(const pair<K, V>& kv)
            {
                return kv.first;
            }
        };

    public:
        typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::Iterator iterator;
        typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::ConstIterator const_iterator;

        iterator begin()
        {
            return _t.Begin();
        }

        iterator end()
        {
            return _t.End();
        }

        const_iterator begin() const
        {
            return _t.Begin();
        }

        const_iterator end() const
        {
            return _t.End();
        }

        pair<iterator, bool> insert(const pair<K, V>& kv)
        {
            return _t.Insert(kv);
        }

        iterator find(const K& key)
        {
            return _t.Find(key);
        }

        V& operator[](const K& key)
        {
            auto ret = insert(make_pair(key, V()));
            return ret.first->second;
        }

    private:
        RBTree<K, pair<const K, V>, MapKeyOfT> _t;
    };
}

六、最精彩的设计点讲解(面试必考)

6.1 一棵红黑树如何支持 set 和 map?

通过三层泛型 + 仿函数

  • set:T=K,取键就是自己
  • map:T=pair<const K,V>,取键取 first
  • 红黑树不关心数据类型,只通过仿函数获取 key

6.2 为什么 set 迭代器不能修改?

因为 set 迭代器被我们定义为:RBTree<K, const K, SetKeyOfT>::ConstIterator迭代器返回的是 const 引用,无法修改。

6.3 为什么 map 不能改 key 只能改 value?

因为 map 存储的是:pair<const K, V>first 被 const 修饰,无法修改。

6.4 map 的 operator [] 如何实现?

cpp 复制代码
V& operator[](const K& key)
{
    auto ret = insert(make_pair(key, V()));
    return ret.first->second;
}

逻辑:

  1. 尝试插入 key,若不存在则默认构造 value
  2. 返回已存在 / 新插入结点的 value 引用

6.5 迭代器 ++/-- 为什么是中序?

因为红黑树是二叉搜索树,中序遍历结果有序。这也是 map/set 迭代器遍历出来是有序的原因。


七、测试代码

cpp 复制代码
#include <iostream>
#include <string>
#include "MySet.h"
#include "MyMap.h"

using namespace std;
using namespace bit;

void test_set()
{
    set<int> s;
    s.insert(3);
    s.insert(1);
    s.insert(5);
    s.insert(2);
    s.insert(4);

    for (auto e : s)
        cout << e << " ";
    cout << endl;
}

void test_map()
{
    map<string, string> dict;
    dict["sort"] = "排序";
    dict["insert"] = "插入";
    dict["erase"] = "删除";

    for (auto& kv : dict)
        cout << kv.first << " : " << kv.second << endl;

    cout << dict["insert"] << endl;
}

int main()
{
    test_set();
    test_map();
    return 0;
}
相关推荐
求学的小高2 小时前
数据结构Day9(图的遍历、图应用及相关算法)
数据结构·笔记·考研
小卓(friendhan2005)2 小时前
基于Qt的音乐播放器项目
数据库·c++·qt
tankeven2 小时前
贪心算法(Greedy Algorithm)详解:从理论到C++实践
c++·算法
Hesionberger2 小时前
LeetCode72.编辑距离(多维动态规划)
java·开发语言·c++·python·算法
lwf0061642 小时前
逻辑回归学习笔记-梯度下降求解回归方程
算法·机器学习·逻辑回归
郝学胜-神的一滴2 小时前
从底层看透Linux高性能服务器:epoll自定义封装与超时清理实战
linux·服务器·c++·网络协议·tcp/ip·unix
人道领域2 小时前
【LeetCode刷题日记】1047:双栈法与双指针法巧妙消除相邻重复字符
java·算法·leetcode·职场和发展
切糕师学AI2 小时前
布隆过滤器(Bloom Filter)技术详解
数学·算法
礼拜天没时间.2 小时前
力扣热题100实战 | 第33期:搜索旋转排序数组——二分查找的变体艺术
算法·leetcode·职场和发展·旋转数组·搜索旋转排序数组