A Simple Mock Implementation of the set, multiset, map, and multimap Classes, with Source Code
- 1. Preface
- 2. Restructuring the Framework
- 3. Multi Design
  - Multi design for find
  - Source code: multi design for find
  - Multi design for insert
  - Source code: multi design for insert
- 4. Iterator Implementation
  - Analysis of the iterator approach
  - Iterator implementation code
- 5. Basic Implementation
- 6. Source Code
  - Red-black tree source code
  - set and multiset
  - map and multimap
- 7. Correctness Tests (VS2022 release)
  - set test
  - multiset test
  - map test
  - multimap test
  - find test (using multimap)
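1. Preface
set, multiset, map, and multimap are all built on a red-black tree, a self-balancing binary search tree. For an explanation of red-black trees and the source code, see: Red-Black Tree: Introduction, Implementation, and Encapsulation (红黑树介绍、实现与封装).
Because these four classes are closely related, their core parts are explained together.
The code below was written for VS2022 C++.
2. Restructuring the Framework
Note that the red-black tree in the STL source adds a sentinel header node that serves as end(). The sentinel and the root are each other's parent, the sentinel's left pointer points to the leftmost node, and its right pointer points to the rightmost node. Compared with the red-black tree in our earlier post, which used nullptr as end(), the difference is small, but this article follows the sentinel design.
Some functions need to be modified. The changes for boundary cases are simple and are omitted here.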
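The sentinel wiring described above can be pictured with a minimal sketch. The `Node` struct and the `attach_sentinel` helper are hypothetical names used only for illustration (color and template parameters are left out); they are not part of the article's source.

```cpp
#include <cassert>

// Hypothetical minimal node: no color, no templates.
struct Node {
    int   data   = 0;
    Node* left   = nullptr;
    Node* right  = nullptr;
    Node* parent = nullptr;
};

// Hypothetical helper: wire a sentinel around an already-built tree.
inline void attach_sentinel(Node& sentry, Node* root) {
    sentry.parent = root;                      // sentinel -> root
    if (root == nullptr) {                     // empty tree: left/right point back at the sentinel
        sentry.left = sentry.right = &sentry;
        return;
    }
    root->parent = &sentry;                    // root -> sentinel (mutual parents)
    Node* leftmost = root;
    while (leftmost->left != nullptr) leftmost = leftmost->left;
    Node* rightmost = root;
    while (rightmost->right != nullptr) rightmost = rightmost->right;
    sentry.left  = leftmost;                   // what begin() would return
    sentry.right = rightmost;                  // what --end() would reach
}

// Builds the tree 1 <- 2 -> 3 by hand and checks the four links.
inline bool sentinel_demo() {
    Node sentry, a, b, c;
    a.data = 1; b.data = 2; c.data = 3;
    b.left = &a;  a.parent = &b;
    b.right = &c; c.parent = &b;
    attach_sentinel(sentry, &b);
    return sentry.parent == &b && b.parent == &sentry
        && sentry.left == &a && sentry.right == &c;
}
```

Calling `sentinel_demo()` returns true when all four links are set as the paragraph describes.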
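Insertion handling
Note that inserting a node away from the leftmost and rightmost nodes, and the four kinds of rotation performed during upward adjustment, do not change which nodes are the leftmost and rightmost of the whole tree.
So it is enough to add a check after insertion: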
cpp
if (_sentry->_left->_left != nullptr) // after insertion, only check whether the leftmost node gained a left child; if so, update
{
_sentry->_left = _sentry->_left->_left;
}
if (_sentry->_right->_right != nullptr) // after insertion, only check whether the rightmost node gained a right child; if so, update
{
_sentry->_right = _sentry->_right->_right;
}
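Deletion handling
As with insertion, deleting an unrelated node and the four kinds of rotation during upward adjustment do not change which nodes are the leftmost and rightmost of the whole tree.
However, if the deleted leftmost (rightmost) node has a right (left) subtree, the new leftmost (rightmost) node has to be searched for inside that subtree: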
cpp
if (_sentry->_left == aim) // the erased node is the sentinel's leftmost node
{
if (aim->_parent == _sentry) // aim is the root
{
_sentry->_left = aim->_right == nullptr ? _sentry : aim->_right;
}
else if (aim->_parent->_left != nullptr) // aim->_parent->_left now holds what used to be aim->_right
{ // find the leftmost node inside the former aim->_right subtree
pNode getMostLeft = aim->_parent->_left;
pNode getMostLeftPrev = getMostLeft;
while (getMostLeft != nullptr)
{
getMostLeftPrev = getMostLeft;
getMostLeft = getMostLeft->_left;
}
_sentry->_left = getMostLeftPrev;
}
else // otherwise the parent is the new leftmost node
{
_sentry->_left = aim->_parent;
}
}
if (_sentry->_right == aim) // the erased node is the sentinel's rightmost node (with only a root, leftmost and rightmost are the same node)
{
if (aim->_parent == _sentry) // aim is the root
{
_sentry->_right = aim->_left == nullptr ? _sentry : aim->_left;
}
else if (aim->_parent->_right != nullptr) // aim->_parent->_right now holds what used to be aim->_left
{ // find the rightmost node inside the former aim->_left subtree
pNode getMostRight = aim->_parent->_right;
pNode getMostRightPrev = getMostRight;
while (getMostRight != nullptr)
{
getMostRightPrev = getMostRight;
getMostRight = getMostRight->_right;
}
_sentry->_right = getMostRightPrev;
}
else // otherwise the parent is the new rightmost node
{
_sentry->_right = aim->_parent;
}
}
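3. Multi Design
multiset and multimap allow the same key to appear more than once. Only part of the low-level logic of the red-black tree's _find and _insert needs to change.
This does require separate code paths for the multi and non-multi cases. Functors are used here: the red-black tree takes the functors as template parameters and calls them to handle the multi and non-multi cases.
Multi design for find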
cpp
pNode _find(const Key& key) const
{
pNode cur = _sentry->_parent;
while (cur != nullptr)
{
if (_comKey(_getKey(cur->_data), key) < 0)
{
cur = cur->_right;
}
else if (_comKey(_getKey(cur->_data), key) > 0)
{
cur = cur->_left;
}
else
{
// functor that separates the multi and non-multi cases
//...
return _mulFindSolve(cur, _getKey);
}
}
return cur;
}
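In the non-multi case the node is returned as soon as the key is found. In the multi case we need the earliest-inserted node with that key that has not been erased, in other words the first such node in in-order traversal.
Rotations move nodes around, but one invariant always holds: the in-order sequence of equal-key nodes matches their insertion order, regardless of whether the tree is sorted in ascending or descending order.
Here we take a more complex multimap case. The key we are looking for is 9, and for ease of observation each value records which insertion of that key the node came from:
In the end we return the node with key 9 and value 1 (if it was erased before the lookup, the node with key 9 and value 2 is returned instead), so the search has to use a loop or recursion.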
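The search for the in-order-first duplicate can be sketched on a tiny hand-built tree. `DupNode`, `first_equal`, and the tree shape are assumptions made for illustration; the loop follows the same idea as the article's functor in a condensed form: on an equal key remember the node and go left, on a different (smaller) key go right.

```cpp
#include <cassert>

// Hypothetical node: `order` records which insertion of the key this node came from.
struct DupNode {
    int key;
    int order;
    DupNode* left;
    DupNode* right;
};

// Starting from any node `hit` whose key equals the target, return the
// equal-key node that comes first in in-order traversal.
inline const DupNode* first_equal(const DupNode* hit) {
    const DupNode* best = hit;
    const DupNode* next = hit->left;
    while (next != nullptr) {
        if (next->key == best->key) {   // equal key: better candidate, keep going left
            best = next;
            next = next->left;
        } else {                        // different key: equal keys can only be to the right
            next = next->right;
        }
    }
    return best;
}

// In-order sequence of the tree below: 5, 9#1, 9#2, 9#3 (the root is 9#2).
inline int first_equal_demo() {
    DupNode n5{5, 0, nullptr, nullptr};
    DupNode n9a{9, 1, nullptr, nullptr};
    DupNode n9c{9, 3, nullptr, nullptr};
    DupNode root{9, 2, &n5, &n9c};
    n5.right = &n9a;
    return first_equal(&root)->order;
}
```

Here the lookup first hits the root (9#2), then walks 5 and 9#1, so `first_equal_demo()` yields 1.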
Source code: multi design for find
cpp
struct IsMultiFindSolve
{
template<class pNode, class GetKey>
pNode operator()(pNode cur, const GetKey& _getKey) const
{
pNode getValue = cur; // candidate: assumed to be the earliest-inserted node with this key
pNode next = cur->_left; // search down to the left
while (next != nullptr) // stop at a null node
{ // same key: keep going down-left
while (next != nullptr && _getKey(getValue->_data) == _getKey(next->_data))
{
getValue = next; // new candidate for the earliest-inserted node with this key
cur = next;
next = next->_left;
} // different key: go down-right
while (next != nullptr && _getKey(getValue->_data) != _getKey(next->_data))
{
cur = next;
next = next->_right;
}
}
return getValue;
}
};
struct NoMultiFindSolve
{
template<class pNode, class GetKey>
pNode operator()(pNode cur, const GetKey& _getKey) const
{
return cur; // non-multi containers have no duplicate keys, so return directly
}
};
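Multi design for insert
The multi insert is the mirror image of find. It locates the last previously inserted node with the same key:
- on an equal key, move down to the right;
- on a different key, move down to the left;
- on a null node, stop.
Because insert goes on to use the nodes it found, they are passed by reference.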
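The effect of these rules, a new duplicate landing after all existing equal keys in in-order sequence, can be reproduced with a simplified sketch. It uses a plain unbalanced binary search tree that goes right whenever the new key is not smaller, which is a stand-in technique and not the article's red-black tree code; `BstNode`, `insert_after_equals`, and the encoding `key * 10 + order` are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Hypothetical node: `order` records which insertion of the key this node came from.
struct BstNode {
    int key;
    int order;
    BstNode* left;
    BstNode* right;
    BstNode(int k, int o) : key(k), order(o), left(nullptr), right(nullptr) {}
};

// Descend left on a smaller key, right otherwise, so equal keys are placed
// after every existing node with the same key.
inline void insert_after_equals(BstNode*& root, BstNode* node) {
    BstNode** slot = &root;
    while (*slot != nullptr) {
        slot = (node->key < (*slot)->key) ? &(*slot)->left : &(*slot)->right;
    }
    *slot = node;
}

// In-order traversal, encoding each node as key * 10 + order.
inline void in_order(const BstNode* n, std::vector<int>& out) {
    if (n == nullptr) return;
    in_order(n->left, out);
    out.push_back(n->key * 10 + n->order);
    in_order(n->right, out);
}

inline std::vector<int> insert_demo() {
    BstNode a(9, 1), b(5, 1), c(9, 2), d(9, 3);
    BstNode* root = nullptr;
    insert_after_equals(root, &a);
    insert_after_equals(root, &b);
    insert_after_equals(root, &c);
    insert_after_equals(root, &d);
    std::vector<int> out;
    in_order(root, out);
    return out;
}
```

The three nodes with key 9 come out in insertion order (91, 92, 93), which is the property the multi find relies on.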
Source code: multi design for insert
cpp
struct IsMultiInsertSolve
{
template<class pNode, class GetKey>
bool operator()(pNode& cur, pNode& parent, const GetKey& _getKey) const
{ // cur and parent are modified through the references
pNode getValue = parent;
while (cur != nullptr) // stop at a null node
{ // same key: go down-right
while (cur != nullptr && _getKey(getValue->_data) == _getKey(cur->_data))
{
parent = cur;
cur = cur->_right;
} // different key: go down-left
while (cur != nullptr && _getKey(getValue->_data) != _getKey(cur->_data))
{
parent = cur;
cur = cur->_left;
}
}
return true; // multi container: the insertion proceeds
}
};
struct NoMultiInsertSolve
{
template<class pNode, class GetKey>
bool operator()(pNode& cur, pNode& parent, const GetKey& _getKey) const
{
return false; // non-multi container: duplicate keys are rejected
}
};
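4. Iterator Implementation
Analysis of the iterator approach
- The overall framework of the iterator follows the same idea as list's iterator: a class wraps a node pointer and overloads operators so that the iterator behaves like a pointer. The hard part is operator++ and operator--. The iterators of all four classes traverse in order (left subtree -> root -> right subtree), so begin() returns an iterator to the first node of the in-order traversal, which is the node the sentinel's left pointer refers to.
- The core idea of ++ is to think locally rather than globally: only consider which node comes next in the in-order traversal from the current node.
- On ++, if the node that it points to has a non-empty right subtree, the current node has been visited and the next node is the first node of the right subtree in in-order traversal. The in-order first node of a tree is its leftmost node, so it is enough to find the leftmost node of the right subtree.
- On ++, if the node's right subtree is empty, the current node and the whole subtree rooted at it have been visited. The next node is among the ancestors, so walk up along the path from the current node to the root.
- On ++, if the node is its parent's left child, then by left subtree -> root -> right subtree the next node is its parent.
- On ++, if the node is its parent's right child, the subtree rooted at the current node and the subtree rooted at its parent have both been visited. Keep climbing until reaching a node that is the left child of its parent; that parent is the next node in the in-order traversal.
- Traversal ends on reaching end(), that is, the sentinel node.
- -- mirrors ++ exactly with the directions reversed, because it visits right subtree -> root -> left subtree.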
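The successor rule in the list above can be sketched in isolation. This version is a simplification that uses nullptr instead of a sentinel as the end marker; `PNode` and `successor` are hypothetical names, not the article's iterator class.

```cpp
#include <cassert>

// Hypothetical node with a parent pointer; nullptr marks "no node".
struct PNode {
    int key;
    PNode* left;
    PNode* right;
    PNode* parent;
};

// In-order successor, following the two cases described above.
inline PNode* successor(PNode* n) {
    if (n->right != nullptr) {             // case 1: leftmost node of the right subtree
        n = n->right;
        while (n->left != nullptr) n = n->left;
        return n;
    }
    PNode* p = n->parent;                  // case 2: climb while we are a right child
    while (p != nullptr && n == p->right) {
        n = p;
        p = p->parent;
    }
    return p;                              // nullptr once the last node has been passed
}

// Tree: 2 is the root, 1 its left child, 4 its right child, 3 the left child of 4.
inline int successor_demo() {
    PNode n1{1, nullptr, nullptr, nullptr};
    PNode n2{2, nullptr, nullptr, nullptr};
    PNode n3{3, nullptr, nullptr, nullptr};
    PNode n4{4, nullptr, nullptr, nullptr};
    n2.left = &n1;  n1.parent = &n2;
    n2.right = &n4; n4.parent = &n2;
    n4.left = &n3;  n3.parent = &n4;
    int digits = 0;
    for (PNode* it = &n1; it != nullptr; it = successor(it)) {
        digits = digits * 10 + it->key;    // collect the visiting order as decimal digits
    }
    return digits;
}
```

Starting from the leftmost node, the walk visits 1, 2, 3, 4, so `successor_demo()` yields 1234.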
Supporting a complete iterator requires changing many more details; see the code below:
Iterator implementation code
cpp
template<class Type, class Reference, class Pointer>
class Iterator
{
typedef RBTreeNode<Type> Node;
typedef Node* pNode;
typedef Iterator<Type, Reference, Pointer> Self;
pNode _point_node = nullptr;
pNode _rb_sentry; // kept for comparisons against end()
template<class ConstKeyType, class Key, class GetKey, class MultiFindSolve, class MultiInsertSolve, class CompareKey>
friend class RBTree_base; // the red-black tree erases nodes through an iterator
private:
pNode get_pNode() const // hands the node pointer to erase
{
return _point_node;
}
public:
Iterator(pNode pointer, pNode rb_sentry)
:_point_node(pointer)
, _rb_sentry(rb_sentry)
{
;
}
Iterator(const Iterator& it)
{
_point_node = it._point_node;
_rb_sentry = it._rb_sentry;
}
bool operator==(const Iterator& it) const
{
return _point_node == it._point_node;
}
bool operator!=(const Iterator& it) const
{
return _point_node != it._point_node;
}
Reference operator*() const
{
return _point_node->_data;
}
Pointer operator->() const
{
return &_point_node->_data;
}
Self& operator++()
{
if (_point_node == _rb_sentry) // ++end() wraps around to the leftmost node
{
_point_node = _point_node->_left;
}
else if (_point_node->_right != nullptr) // 1. the right subtree is not empty
{
pNode prev = _point_node;
pNode cur = _point_node->_right;
while (cur != nullptr)
{
prev = cur;
cur = cur->_left;
}
_point_node = prev;
}
else // 2. the right subtree is empty
{
pNode cur = _point_node;
while (cur == cur->_parent->_right) // 4. the node is its parent's right child
{
cur = cur->_parent;
}
if (cur == _rb_sentry) // reached the sentinel, i.e. end(); stop
{
_point_node = _rb_sentry;
}
else
{
_point_node = cur->_parent; // 3. the node is its parent's left child
}
}
return *this;
}
Self& operator--()
{
if (_point_node == _rb_sentry) // --end() moves to the rightmost node
{
_point_node = _point_node->_right;
}
else if (_point_node->_left != nullptr) // 1. the left subtree is not empty
{
pNode prev = _point_node;
pNode cur = _point_node->_left;
while (cur != nullptr)
{
prev = cur;
cur = cur->_right;
}
_point_node = prev;
}
else // 2. the left subtree is empty
{
pNode cur = _point_node;
while (cur == cur->_parent->_left) // 4. the node is its parent's left child
{
cur = cur->_parent;
}
if (cur == _rb_sentry) // reached the sentinel; stop there
{
_point_node = _rb_sentry;
}
else
{
_point_node = cur->_parent; // 3. the node is its parent's right child
}
}
return *this;
}
};
template<class Type, class Reference, class Pointer>
class ReverseIterator
{
typedef ReverseIterator<Type, Reference, Pointer> Self;
typedef Iterator<Type, Reference, Pointer> iterator;
typedef RBTreeNode<Type> Node;
typedef Node* pNode;
iterator _it;
public:
ReverseIterator(pNode pointer, pNode rb_sentry)
:_it(pointer, rb_sentry)
{
;
}
bool operator==(const ReverseIterator& rit) const
{
return _it == rit._it;
}
bool operator!=(const ReverseIterator& rit) const
{
return _it != rit._it;
}
Reference operator*() const
{
return *_it;
}
Pointer operator->() const
{
return _it.operator->(); // forward to the wrapped iterator so that a Pointer is returned
}
Self& operator++()
{
--_it; // ++ on a reverse iterator is -- on the wrapped iterator
return *this;
}
Self& operator--()
{
++_it; // -- on a reverse iterator is ++ on the wrapped iterator
return *this;
}
};
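5. Basic Implementation
Handling const keys
Iterators of both the key containers and the key/value containers must not allow the key to be modified; otherwise the ordering that the red-black tree relies on would be broken.
There are two approaches.
Whole-type ConstKeyType
Add const to the key inside the Type template parameter, turning it into ConstKeyType, so the red-black tree itself stores a const key. The user then cannot modify the key that an iterator returns.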
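The standard containers take the same stance, which can be checked at compile time. The sketch below only inspects `std::map` and `std::set`; `const_key_demo` is a hypothetical helper name.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <type_traits>
#include <utility>

// std::map stores pair<const Key, Value>: the key part is const inside the node.
static_assert(std::is_same<std::map<int, std::string>::value_type,
                           std::pair<const int, std::string>>::value,
              "map elements carry a const key");

// Dereferencing a std::set iterator yields a reference to a const element.
static_assert(std::is_const<std::remove_reference<
                  decltype(*std::declval<std::set<int>::iterator>())>::type>::value,
              "set elements are read-only through iterators");

// The mapped value stays writable; only the key is protected.
inline bool const_key_demo() {
    std::map<int, std::string> m{{1, "a"}};
    auto it = m.begin();
    it->second = "b";        // allowed: the mapped value is not const
    // it->first = 2;        // would not compile: the key is const
    return m.at(1) == "b";
}
```

If either `static_assert` failed, the file would not compile, so building it is already the check.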
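Node replacement
Note, however, that the replacement-based deletion we used before swaps values. ConstKeyType, whether it is const Key or pair<const Key, Value>, cannot be value-swapped, so deletion has to switch to node replacement:
Because of the sentinel and other factors, node replacement is far harder than value replacement. The following cases need care:
- the erased node's parent is _sentry
- the erased node's right child is prev
- the replacement node's right child is null
Starting from the most basic swap, we resolve the cases above one by one: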
cpp
// 1. aim's parent is _sentry
// 2. aim's right child is prev
// 3. prev's right child is null
std::swap(aim->_left->_parent, prev->_left->_parent); // swap the parent pointers of the two left children
std::swap(aim->_left, prev->_left); // swap the two left-child pointers
std::swap(aim->_right->_parent, prev->_right->_parent); // same for the parent pointers of the right children
std::swap(aim->_right, prev->_right); // same for the two right-child pointers
if (aim->_parent->_left == aim) // repoint the erased node's parent
{
aim->_parent->_left = prev;
}
else
{
aim->_parent->_right = prev;
}
if (prev->_parent->_left == prev) // repoint the replacement node's parent
{
prev->_parent->_left = aim;
}
else
{
prev->_parent->_right = aim;
}
std::swap(aim->_parent, prev->_parent); // swap the two parent pointers
parent = aim->_parent; // update parent
std::swap(aim->_color, prev->_color); // swap colors
Handling case 3, the replacement node's right child is null:
cpp
if (prev->_right != nullptr) // handles case 3
{
std::swap(aim->_right->_parent, prev->_right->_parent);
}
else
{
aim->_right->_parent = prev;
}
std::swap(aim->_right, prev->_right); // if prev == aim->_right, prev's right pointer now points at prev itself
Handling case 1, the erased node's parent is _sentry:
cpp
if (aim->_parent == _sentry) // handles case 1
{
_sentry->_parent = prev;
}
else if (aim->_parent->_left == aim)
{
aim->_parent->_left = prev;
}
else
{
aim->_parent->_right = prev;
}
Handling case 2, the erased node's right child is prev (the replacement node):
cpp
bool self = false; // handles case 2: prev's right pointer points at prev itself
if (prev->_right == prev)
{
prev->_parent = aim->_parent; // the replacement node's parent becomes aim->_parent
prev->_right = aim; // the replacement node's right child becomes aim
self = true;
}
else if (prev->_parent->_left == prev)
{
prev->_parent->_left = aim;
}
else
{
prev->_parent->_right = aim;
}
if (self == false)
{
std::swap(aim->_parent, prev->_parent);
}
else
{
aim->_parent = prev;
}
Node replacement source code
cpp
void replace(pNode& parent, pNode& aim, pNode& checkParent, pNode& checkChild)
{
pNode prev = aim;
pNode cur = aim->_right; // find the minimum of the right subtree
while (cur != nullptr)
{
prev = cur;
cur = cur->_left;
}
// 1. aim's parent is _sentry
// 2. aim's right child is prev
// 3. prev's right child is null
aim->_left->_parent = prev; // prev's left child is always null
std::swap(aim->_left, prev->_left);
if (prev->_right != nullptr) // handles case 3
{
std::swap(aim->_right->_parent, prev->_right->_parent);
}
else
{
aim->_right->_parent = prev;
}
std::swap(aim->_right, prev->_right); // if prev == aim->_right, prev's right pointer now points at prev itself
if (aim->_parent == _sentry) // handles case 1
{
_sentry->_parent = prev;
}
else if (aim->_parent->_left == aim)
{
aim->_parent->_left = prev;
}
else
{
aim->_parent->_right = prev;
}
bool self = false; // handles case 2: prev's right pointer points at prev itself
if (prev->_right == prev)
{
prev->_parent = aim->_parent; // the replacement node's parent becomes aim->_parent
prev->_right = aim; // the replacement node's right child becomes aim
self = true;
}
else if (prev->_parent->_left == prev)
{
prev->_parent->_left = aim;
}
else
{
prev->_parent->_right = aim;
}
if (self == false) // avoids prev->_parent == prev, i.e. a parent pointer that points at itself
{
std::swap(aim->_parent, prev->_parent);
}
else
{
aim->_parent = prev;
}
parent = aim->_parent; // update parent
std::swap(aim->_color, prev->_color); // swap colors
if (aim->_right != nullptr) // non-null: deletion case 2 (one child); otherwise deletion case 1 (no children)
{
erase_NH_or_HN(parent, aim, aim->_right);
}
else
{
eraseNN(parent, aim, checkParent, checkChild);
}
}
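Const key through the iterator
With this approach the red-black tree's key does not need to be const, so value replacement can still be used.
But because ConstKeyType and NoConstKeyType must be told apart, NoConstKeyType has to be added to the red-black tree's template parameters.
When instantiating the iterators, the first template argument is the non-const type, and the second and third are passed as follows: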
cpp
// Type, then reference and pointer to const Key or pair<const Key, Value>
typedef Iterator <NoConstKeyType, ConstKeyType&, ConstKeyType*> iterator;
typedef Iterator <NoConstKeyType, const ConstKeyType&, const ConstKeyType*> const_iterator;
typedef ReverseIterator <NoConstKeyType, ConstKeyType&, ConstKeyType*> reverse_iterator;
typedef ReverseIterator <NoConstKeyType, const ConstKeyType&, const ConstKeyType*> const_reverse_iterator;
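Since the only goal is that the key cannot be modified through the iterator, the iterator's return value is simply cast to ConstKeyType. Note that for pair this is a cast between two distinct types, pair<Key, Value> and pair<const Key, Value>, which the C++ standard does not guarantee to be well-defined, so treat it as an implementation shortcut: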
cpp
Reference operator*() const
{
return (Reference)(_point_node->_data); // _point_node->_data has type NoConstKeyType&
} // Reference is ConstKeyType&, so a cast is required
Pointer operator->() const
{
return (Pointer)(&_point_node->_data); // &_point_node->_data has type NoConstKeyType*
} // Pointer is ConstKeyType*, so a cast is required
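Note, however, that erasing a key from a multi container erases consecutive nodes. With value replacement from the right subtree, the node that is physically removed is the in-order successor, which is exactly the node the next iterator refers to, and that makes consecutive erasure awkward to implement.
Handling the right-child replacement
Switching from right replacement to left replacement is enough: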
cpp
void replace(pNode& parent, pNode& aim, pNode& checkParent, pNode& checkChild)
{
//pNode prev = aim;
//pNode cur = aim->_right; // find the leftmost node of the right subtree
//while (cur != nullptr) // right replacement
//{
// prev = cur;
// cur = cur->_left;
//}
pNode prev = aim;
pNode cur = aim->_left; // find the rightmost node of the left subtree
while (cur != nullptr) // left replacement
{
prev = cur;
cur = cur->_right;
}
std::swap(prev->_data, aim->_data); // swap the values
parent = prev->_parent; // update parent
std::swap(prev, aim); // swap the pointers; aim now names the node to unlink
if (aim->_left != nullptr) // the predecessor never has a right child, so only its left child can exist
{
erase_NH_or_HN(parent, aim, aim->_left);
}
else
{
eraseNN(parent, aim, checkParent, checkChild);
}
}
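The two approaches lead to different red-black tree frameworks. Everything below uses the whole-type ConstKeyType approach.
Basic functions inside the red-black tree
Most functions of the four classes are identical, so the shared ones are implemented inside the red-black tree first and the differences are handled afterwards.
Iterator implementation
The iterator accessors are simple, so the explanation is omitted here.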
cpp
typedef Iterator <ConstKeyType, ConstKeyType&, ConstKeyType*> iterator;
typedef Iterator <ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_iterator;
typedef ReverseIterator <ConstKeyType, ConstKeyType&, ConstKeyType*> reverse_iterator;
typedef ReverseIterator <ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_reverse_iterator;
iterator begin()
{
assert(size());
return iterator(_sentry->_left, _sentry);
}
iterator end()
{
return iterator(_sentry, _sentry);
}
const_iterator begin() const
{
assert(size());
return const_iterator(_sentry->_left, _sentry);
}
const_iterator end() const
{
return const_iterator(_sentry, _sentry);
}
reverse_iterator rbegin()
{
assert(size());
return reverse_iterator(_sentry->_right, _sentry);
}
reverse_iterator rend()
{
return reverse_iterator(_sentry, _sentry);
}
const_reverse_iterator rbegin() const
{
assert(size());
return const_reverse_iterator(_sentry->_right, _sentry);
}
const_reverse_iterator rend() const
{
return const_reverse_iterator(_sentry, _sentry);
}
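Erase implementation
For the erase interface, see:
std::set::erase
std::multiset::erase
std::map::erase
std::multimap::erase
As you can see, apart from key_type versus value_type they are identical: the set versions take a value_type, which for set is the key itself, while the map versions take a key_type. Both are the key, so this can be handled inside the red-black tree and the four classes simply call it.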
cpp
size_t erase(const Key& key) // returns size_t so that the multi containers can report a count
{
if (_size == 0)
{
return 0;
}
iterator eraseElement = find(key);
iterator next = eraseElement;
++next;
size_t count = 0;
bool sameValue = true; // lets the multi containers erase consecutive nodes with the same key
while (eraseElement != end() && sameValue == true)
{
sameValue = (_getKey(*eraseElement) == _getKey(*next));
_erase(eraseElement.get_pNode()); // erase by node pointer
eraseElement = next;
++next; // if next == end(), ++end() moves to the leftmost node
++count;
}
_size -= count;
return count;
}
iterator erase(const_iterator position)
{
iterator next = iterator(position.get_pNode(), _sentry);
++next;
_erase(position.get_pNode()); // erase by node pointer
--_size;
return next;
}
iterator erase(const_iterator first, const_iterator last)
{
iterator next = iterator(last.get_pNode(), _sentry);
size_t count = 0;
while (first != last)
{
++count;
const_iterator temp = first;
++first;
_erase(temp.get_pNode()); // erase by node pointer
}
_size -= count;
return next;
}
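For comparison, the standard containers expose the same three erase forms. The short sketch below uses `std::multiset` only to show the return values that the overloads above are modelled on; the two helper function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

// erase(key) on a multi container removes every equal element and returns the count.
inline std::size_t erase_by_key_count() {
    std::multiset<int> s{1, 2, 2, 2, 5};
    return s.erase(2);
}

// erase(position) returns an iterator to the element after the erased one.
inline int erase_by_position_next() {
    std::multiset<int> s{1, 5};
    std::multiset<int>::iterator next = s.erase(s.begin());
    return *next;
}
```

With these inputs, `erase_by_key_count()` removes the three 2s and `erase_by_position_next()` lands on 5.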
Find implementation
For the find interface, see:
std::set::find
std::multiset::find
std::map::find
std::multimap::find
The same key_type versus value_type observation applies here:
cpp
const_iterator find(const Key& key) const
{
const_iterator it = const_iterator(_find(key), _sentry);
return it == const_iterator(nullptr, _sentry) ? end() : it;
}
iterator find(const Key& key)
{
iterator it = iterator(_find(key), _sentry);
return it == iterator(nullptr, _sentry) ? end() : it;
}
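Insert implementation
For the insert interface, see:
std::set::insert
std::multiset::insert
std::map::insert
std::multimap::insert
Only three overloads are implemented here: iterator-range insertion, initializer-list insertion, and ordinary single-element insertion.
The iterator-range and initializer-list overloads can be implemented directly inside the red-black tree. For ordinary insertion, the tree first implements an insert that returns std::pair<iterator, bool>, which suits the non-multi types, and each of the four classes then adapts it as needed.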
cpp
std::pair<iterator, bool> insert(const ConstKeyType& data)
{
std::pair<iterator, bool> access = _insert(data);
if (access.second == true)
{
++_size;
}
return access;
}
template <class InputIterator>
void insert(InputIterator first, InputIterator last)
{
while (first != last)
{
insert(*first);
++first;
}
}
void insert(std::initializer_list<ConstKeyType> list)
{
for (const auto& e : list)
{
insert(e);
}
}
Count implementation
For the count interface, see:
std::set::count
std::multiset::count
std::map::count
std::multimap::count
To accommodate the multi containers, count in the red-black tree returns size_t:
cpp
size_t count(const Key& key) const
{
const_iterator it = find(key);
size_t thecount = 0;
while (it != end() && _getKey(*it) == key)
{
++thecount;
++it;
}
return thecount;
}
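The same "find the first equal key, then walk forward" idea can be tried against `std::multimap`, where `lower_bound` plays the role of the tree's find. `count_by_walking` and `count_demo` are hypothetical helper names for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <map>

// Count equal keys by starting at the first one and walking forward in order.
inline std::size_t count_by_walking(const std::multimap<int, char>& m, int key) {
    std::size_t n = 0;
    for (auto it = m.lower_bound(key); it != m.end() && it->first == key; ++it) {
        ++n;
    }
    return n;
}

// The walk agrees with the container's own count().
inline bool count_demo() {
    const std::multimap<int, char> m{{9, 'a'}, {3, 'b'}, {9, 'c'}, {9, 'd'}};
    return count_by_walking(m, 9) == 3
        && count_by_walking(m, 9) == m.count(9)
        && count_by_walking(m, 4) == 0;
}
```

The walk is linear in the number of duplicates, which matches the cost of the loop in the count shown above.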
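Handling map operator[]
For the map operator[] interface, see:
std::map::operator[]
map's operator[] is special: if the key is found in the map, it returns a reference to the corresponding value; if not, it inserts the key with a value built by the value type's default constructor.
Only map among the four classes needs it, so it is implemented separately in map: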
cpp
Value& operator[](const Key& key)
{
iterator it = this->find(key);
if (it == this->end()) // not found: insert it
{
auto get = this->_base.insert({ key, Value() });
it = get.first;
}
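// Strictly, it.operator->() returns the address of the pair,
// and a second -> then accesses first or second.
// For readability the compiler applies the second -> automatically,
//return it.operator->()->second; // also valid, but harder to read
return it->second; // shorthand for the line above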
}
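The behaviour being reproduced is that of `std::map::operator[]`, which the following sketch exercises; `subscript_demo` is a hypothetical helper name.

```cpp
#include <cassert>
#include <map>
#include <string>

// operator[] inserts a value-initialized element when the key is missing,
// and returns a reference to the existing element otherwise.
inline bool subscript_demo() {
    std::map<std::string, int> m;
    int& v = m["apple"];                         // key missing: inserted with value 0
    bool inserted_default = (m.size() == 1 && v == 0);
    m["apple"] += 5;                             // key present: same element is updated
    return inserted_default && m.at("apple") == 5 && m.size() == 1;
}
```

Because a lookup through operator[] can insert, it is not available on a const map; `find` or `at` is the read-only alternative.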
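Design of the set, multiset, map, and multimap classes
We can design a layered hierarchy to handle the issues left open above:
- The red-black tree uses templates to delegate the core key versus key/value and multi versus non-multi questions to the next layer.
- set_base and map_base are composed with the red-black tree. They settle the key versus key/value question and provide the shared functions, and pass the multi versus non-multi question on to the next layer.
- set and multiset inherit from set_base and resolve the multi versus non-multi question along with the few functions that differ; map and multimap inherit from map_base and do the same.
A diagram makes this clearer:
The source code is in the next section.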
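As a compact recap of the layering described in this section, here is a toy sketch. Every name in it (`toy_tree`, `toy_set_base`, `toy_set`, `toy_multiset`, and the two policy structs) is hypothetical, and a sorted `std::vector` stands in for the red-black tree, so it only illustrates how the policy and the inheritance fit together, not the article's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Policy functors, in the spirit of NoMultiInsertSolve / IsMultiInsertSolve.
struct RejectDuplicates { bool operator()() const { return false; } };
struct AllowDuplicates  { bool operator()() const { return true;  } };

// Bottom layer: a stand-in for the tree (a sorted vector, NOT a red-black tree).
template <class Key, class DuplicatePolicy>
class toy_tree {
    std::vector<Key> _data;
public:
    bool insert(const Key& key) {
        auto pos = std::upper_bound(_data.begin(), _data.end(), key);
        bool exists = (pos != _data.begin() && !(*(pos - 1) < key));
        if (exists && !DuplicatePolicy()()) return false;  // policy decides on duplicates
        _data.insert(pos, key);                            // lands after existing equal keys
        return true;
    }
    std::size_t size() const { return _data.size(); }
};

// Middle layer: owns the tree by composition and hosts the shared functions.
template <class Key, class DuplicatePolicy>
class toy_set_base {
protected:
    toy_tree<Key, DuplicatePolicy> _base;
public:
    std::size_t size() const { return _base.size(); }
};

// Top layer: picks the policy and adapts the functions whose signatures differ.
template <class Key>
class toy_set : public toy_set_base<Key, RejectDuplicates> {
public:
    bool insert(const Key& key) { return this->_base.insert(key); }
};

template <class Key>
class toy_multiset : public toy_set_base<Key, AllowDuplicates> {
public:
    void insert(const Key& key) { this->_base.insert(key); }
};

inline bool hierarchy_demo() {
    toy_set<int> s;
    toy_multiset<int> ms;
    for (int k : {2, 2, 1}) { s.insert(k); ms.insert(k); }
    return s.size() == 2 && ms.size() == 3;
}
```

The same three keys produce two elements in `toy_set` and three in `toy_multiset`, with the only difference being the policy chosen at the top layer.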
6. Source Code
Red-black tree source code
In RBTree.hpp:
cpp
#pragma once
#include <iostream>
#include <cassert>
#include <cstddef> // size_t
#include <initializer_list> // std::initializer_list, used by insert
#include <utility>
namespace my
{
enum Color
{
RED,
BLACK
};
template<class Type>
struct RBTreeNode
{
RBTreeNode(const Type& data = Type())
:_data(data)
, _left(nullptr)
, _right(nullptr)
, _parent(nullptr)
, _color(RED)
{
;
}
RBTreeNode(const Type& data, bool is_sentry)
:_data(data)
, _left(this)
, _right(this)
, _parent(nullptr)
, _color(RED)
{
;
}
RBTreeNode* _left;
RBTreeNode* _right;
RBTreeNode* _parent;
Type _data;
Color _color;
};
template<class Type, class Reference, class Pointer>
class Iterator
{
typedef RBTreeNode<Type> Node;
typedef Node* pNode;
typedef Iterator<Type, Reference, Pointer> Self;
pNode _point_node = nullptr;
pNode _rb_sentry; // kept for comparisons against end()
template<class ConstKeyType, class Key, class GetKey, class MultiFindSolve, class MultiInsertSolve, class CompareKey>
friend class RBTree_base; // the red-black tree erases nodes through an iterator
private:
pNode get_pNode() const // hands the node pointer to erase
{
return _point_node;
}
public:
Iterator(pNode pointer, pNode rb_sentry)
:_point_node(pointer)
, _rb_sentry(rb_sentry)
{
;
}
Iterator(const Iterator& it)
{
_point_node = it._point_node;
_rb_sentry = it._rb_sentry;
}
bool operator==(const Iterator& it) const
{
return _point_node == it._point_node;
}
bool operator!=(const Iterator& it) const
{
return _point_node != it._point_node;
}
Reference operator*() const
{
return _point_node->_data;
}
Pointer operator->() const
{
return &_point_node->_data;
}
Self& operator++()
{
if (_point_node == _rb_sentry) // ++end() wraps around to the leftmost node
{
_point_node = _point_node->_left;
}
else if (_point_node->_right != nullptr) // 1. the right subtree is not empty
{
pNode prev = _point_node;
pNode cur = _point_node->_right;
while (cur != nullptr)
{
prev = cur;
cur = cur->_left;
}
_point_node = prev;
}
else // 2. the right subtree is empty
{
pNode cur = _point_node;
while (cur == cur->_parent->_right) // 4. the node is its parent's right child
{
cur = cur->_parent;
}
if (cur == _rb_sentry) // reached the sentinel, i.e. end(); stop
{
_point_node = _rb_sentry;
}
else
{
_point_node = cur->_parent; // 3. the node is its parent's left child
}
}
return *this;
}
Self& operator--()
{
if (_point_node == _rb_sentry) // --end() moves to the rightmost node
{
_point_node = _point_node->_right;
}
else if (_point_node->_left != nullptr) // 1. the left subtree is not empty
{
pNode prev = _point_node;
pNode cur = _point_node->_left;
while (cur != nullptr)
{
prev = cur;
cur = cur->_right;
}
_point_node = prev;
}
else // 2. the left subtree is empty
{
pNode cur = _point_node;
while (cur == cur->_parent->_left) // 4. the node is its parent's left child
{
cur = cur->_parent;
}
if (cur == _rb_sentry) // reached the sentinel; stop there
{
_point_node = _rb_sentry;
}
else
{
_point_node = cur->_parent; // 3. the node is its parent's right child
}
}
return *this;
}
};
template<class Type, class Reference, class Pointer>
class ReverseIterator
{
typedef ReverseIterator<Type, Reference, Pointer> Self;
typedef Iterator<Type, Reference, Pointer> iterator;
typedef RBTreeNode<Type> Node;
typedef Node* pNode;
iterator _it;
public:
ReverseIterator(pNode pointer, pNode rb_sentry)
:_it(pointer, rb_sentry)
{
;
}
bool operator==(const ReverseIterator& rit) const
{
return _it == rit._it;
}
bool operator!=(const ReverseIterator& rit) const
{
return _it != rit._it;
}
Reference operator*() const
{
return *_it;
}
Pointer operator->() const
{
return _it.operator->(); // forward to the wrapped iterator so that a Pointer is returned
}
Self& operator++()
{
--_it; // ++ on a reverse iterator is -- on the wrapped iterator
return *this;
}
Self& operator--()
{
++_it; // -- on a reverse iterator is ++ on the wrapped iterator
return *this;
}
};
template<class ConstKeyType, class Key, class GetKey, class MultiFindSolve, class MultiInsertSolve, class CompareKey>
class RBTree_base
{
typedef RBTreeNode<ConstKeyType> Node;
typedef Node* pNode;
static constexpr const GetKey _getKey = GetKey();
static constexpr const CompareKey _comKey = CompareKey();
static constexpr const MultiFindSolve _mulFindSolve = MultiFindSolve();
static constexpr const MultiInsertSolve _mulInsertSolve = MultiInsertSolve();
pNode _sentry = new Node(ConstKeyType(), true);
size_t _size = 0;
private:
void rotateR(pNode cur)
{
pNode subL = cur->_left;
pNode subLR = subL->_right;
pNode parent = cur->_parent;
cur->_left = subLR;
if (subLR != nullptr)
{
subLR->_parent = cur;
}
subL->_right = cur;
cur->_parent = subL;
subL->_parent = parent;
if (parent == _sentry)
{
_sentry->_parent = subL;
}
else if (parent->_left == cur)
{
parent->_left = subL;
}
else
{
parent->_right = subL;
}
}
void rotateL(pNode cur)
{
pNode subR = cur->_right;
pNode subRL = subR->_left;
pNode parent = cur->_parent;
cur->_right = subRL;
if (subRL != nullptr)
{
subRL->_parent = cur;
}
subR->_left = cur;
cur->_parent = subR;
subR->_parent = parent;
if (parent == _sentry)
{
_sentry->_parent = subR;
}
else if (parent->_left == cur)
{
parent->_left = subR;
}
else
{
parent->_right = subR;
}
}
private:
static void _treeCopy(pNode& des, pNode src)
{
if (src == nullptr)
{
return;
}
des = new Node(src->_data); // copy the payload, not the node pointer
des->_color = src->_color;
_treeCopy(des->_left, src->_left);
_treeCopy(des->_right, src->_right); // recurse into the source's right subtree
if (des->_left != nullptr)
{
des->_left->_parent = des;
}
if (des->_right != nullptr)
{
des->_right->_parent = des;
}
}
static void _treeDestory(pNode root)
{
if (root == nullptr)
{
return;
}
_treeDestory(root->_left);
_treeDestory(root->_right);
delete root;
}
static size_t _high(pNode root)
{
if (root == nullptr)
{
return 0;
}
size_t left = _high(root->_left);
size_t right = _high(root->_right);
return (right > left ? right : left) + 1;
}
private:
void erase_NH_or_HN(pNode parent, pNode aim, pNode aimChild)
{
if (parent == _sentry) // aim is the root
{
_sentry->_parent = aimChild;
_sentry->_parent->_parent = _sentry;
}
else if (parent->_left == aim) // aim is the left child
{
parent->_left = aimChild;
parent->_left->_parent = parent;
parent->_left->_color = BLACK;
}
else // aim is the right child
{
parent->_right = aimChild;
parent->_right->_parent = parent;
parent->_right->_color = BLACK;
}
}
void replace(pNode& parent, pNode& aim, pNode& checkParent, pNode& checkChild)
{
pNode prev = aim;
pNode cur = aim->_right; // find the minimum of the right subtree
while (cur != nullptr)
{
prev = cur;
cur = cur->_left;
}
// 1. aim's parent is _sentry
// 2. aim's right child is prev
// 3. prev's right child is null
aim->_left->_parent = prev; // prev's left child is always null
std::swap(aim->_left, prev->_left);
if (prev->_right != nullptr) // handles case 3
{
std::swap(aim->_right->_parent, prev->_right->_parent);
}
else
{
aim->_right->_parent = prev;
}
std::swap(aim->_right, prev->_right); // if prev == aim->_right, prev's right pointer now points at prev itself
if (aim->_parent == _sentry) // handles case 1
{
_sentry->_parent = prev;
}
else if (aim->_parent->_left == aim)
{
aim->_parent->_left = prev;
}
else
{
aim->_parent->_right = prev;
}
bool self = false; // handles case 2: prev's right pointer points at prev itself
if (prev->_right == prev)
{
prev->_parent = aim->_parent; // the replacement node's parent becomes aim->_parent
prev->_right = aim; // the replacement node's right child becomes aim
self = true;
}
else if (prev->_parent->_left == prev)
{
prev->_parent->_left = aim;
}
else
{
prev->_parent->_right = aim;
}
if (self == false) // avoids prev->_parent == prev, i.e. a parent pointer that points at itself
{
std::swap(aim->_parent, prev->_parent);
}
else
{
aim->_parent = prev;
}
parent = aim->_parent; // update parent
std::swap(aim->_color, prev->_color); // swap colors
if (aim->_right != nullptr) // non-null: deletion case 2 (one child); otherwise deletion case 1 (no children)
{
erase_NH_or_HN(parent, aim, aim->_right);
}
else
{
eraseNN(parent, aim, checkParent, checkChild);
}
}
void eraseNN(pNode& parent, pNode& aim, pNode& checkParent, pNode& checkChild)
{
if (parent == _sentry)
{
_sentry->_parent = nullptr;
//_sentry->_left = _sentry; // already handled further below
//_sentry->_right = _sentry;
}
else if (aim->_color == RED) // a red leaf can simply be unlinked
{
if (parent->_left == aim)
{
parent->_left = nullptr;
}
else
{
parent->_right = nullptr;
}
}
else if (parent->_color == RED) // enumerate the cases where parent is red
{
if (parent->_left == aim) // aim is the left child
{
parent->_left = nullptr;
pNode uncle = parent->_right; // uncle must be black here
if (uncle->_left == nullptr && uncle->_right == nullptr)
{
parent->_color = BLACK;
uncle->_color = RED;
}
else if (uncle->_left == nullptr)
{
rotateL(parent);
}
else if (uncle->_right == nullptr)
{
rotateR(uncle);
rotateL(parent);
parent->_color = BLACK;
}
else
{
rotateL(parent);
uncle->_color = RED;
parent->_color = uncle->_right->_color = BLACK;
}
}
else // aim is the right child
{
parent->_right = nullptr;
pNode uncle = parent->_left; // uncle must be black here
if (uncle->_left == nullptr && uncle->_right == nullptr)
{
parent->_color = BLACK;
uncle->_color = RED;
}
else if (uncle->_left == nullptr)
{
rotateL(uncle);
rotateR(parent);
parent->_color = BLACK;
}
else if (uncle->_right == nullptr)
{
rotateR(parent);
}
else
{
rotateR(parent);
uncle->_color = RED;
parent->_color = uncle->_left->_color = BLACK;
}
}
}
else // enumerate the cases where parent is black
{
if (parent->_left == aim) // aim is the left child
{
pNode uncle = parent->_right;
parent->_left = nullptr;
if (uncle->_left == nullptr && uncle->_right == nullptr)
{
uncle->_color = RED;
// upward adjustment is needed
checkChild = parent;
checkParent = checkChild->_parent;
}
else if (uncle->_left == nullptr)
{
rotateL(parent);
uncle->_right->_color = BLACK;
}
else if (uncle->_right == nullptr)
{
rotateR(uncle);
rotateL(parent);
uncle->_parent->_color = BLACK;
}
else if (uncle->_color == BLACK)
{
rotateL(parent);
uncle->_right->_color = BLACK;
}
else // local adjustment: rotate, then handle the new configuration
{
parent->_left = aim;
rotateL(parent);
std::swap(parent->_color, uncle->_color);
eraseNN(parent, aim, checkParent, checkChild);
}
}
else // aim is the right child
{
pNode uncle = parent->_left;
parent->_right = nullptr;
if (uncle->_left == nullptr && uncle->_right == nullptr)
{
uncle->_color = RED;
// upward adjustment is needed
checkChild = parent;
checkParent = checkChild->_parent;
}
else if (uncle->_left == nullptr)
{
rotateL(uncle);
rotateR(parent);
uncle->_parent->_color = BLACK;
}
else if (uncle->_right == nullptr)
{
rotateR(parent);
uncle->_left->_color = BLACK;
}
else if (uncle->_color == BLACK)
{
rotateR(parent);
uncle->_left->_color = BLACK;
}
else // local adjustment: rotate, then handle the new configuration
{
parent->_right = aim;
rotateR(parent);
std::swap(parent->_color, uncle->_color);
eraseNN(parent, aim, checkParent, checkChild);
}
}
}
}
void parentIsRedAdjust_constSolve(pNode parent, pNode child)
{
if (parent->_left == child) // child is on the left
{
pNode brother = parent->_right;
if (brother->_left->_color == BLACK && brother->_right->_color == BLACK)
{
parent->_color = BLACK;
brother->_color = RED;
}
else if (brother->_left->_color == BLACK)
{
rotateL(parent);
}
else if (brother->_right->_color == BLACK)
{
std::swap(brother->_left->_color, brother->_color);
rotateR(brother);
rotateL(parent);
}
else
{
rotateR(brother);
rotateL(parent);
parent->_color = BLACK;
}
}
else // child is on the right
{
pNode brother = parent->_left;
if (brother->_left->_color == BLACK && brother->_right->_color == BLACK)
{
parent->_color = BLACK;
brother->_color = RED;
}
else if (brother->_left->_color == BLACK)
{
std::swap(brother->_right->_color, brother->_color);
rotateL(brother);
rotateR(parent);
}
else if (brother->_right->_color == BLACK)
{
rotateR(parent);
}
else
{
rotateL(brother);
rotateR(parent);
parent->_color = BLACK;
}
}
}
bool parentIsBlackAdjust_solve(pNode parent, pNode child)
{
if (parent->_left == child) // child is on the left
{
pNode brother = parent->_right;
if (brother->_color == RED)
{
std::swap(brother->_color, parent->_color);
rotateL(parent);
parentIsRedAdjust_constSolve(parent, child); // reuse the red-parent handling locally
}
else if (brother->_left->_color == BLACK && brother->_right->_color == BLACK)
{
brother->_color = RED;
return true; // another round of upward adjustment is needed
}
else if (brother->_left->_color == BLACK)
{
brother->_right->_color = BLACK;
rotateL(parent);
}
else if (brother->_right->_color == BLACK)
{
brother->_left->_color = BLACK;
rotateR(brother);
rotateL(parent);
}
else
{
brother->_left->_color = BLACK;
rotateR(brother);
rotateL(parent);
}
}
else // child is on the right
{
pNode brother = parent->_left;
if (brother->_color == RED)
{
std::swap(brother->_color, parent->_color);
rotateR(parent);
parentIsRedAdjust_constSolve(parent, child); // reuse the red-parent handling locally
}
else if (brother->_left->_color == BLACK && brother->_right->_color == BLACK)
{
brother->_color = RED;
return true; // another round of upward adjustment is needed
}
else if (brother->_left->_color == BLACK)
{
brother->_right->_color = BLACK;
rotateL(brother);
rotateR(parent);
}
else if (brother->_right->_color == BLACK)
{
brother->_left->_color = BLACK;
rotateR(parent);
}
else
{
brother->_right->_color = BLACK;
rotateL(brother);
rotateR(parent);
}
}
return false; // no further upward adjustment is needed
}
public:
typedef Iterator <ConstKeyType, ConstKeyType&, ConstKeyType*> iterator;
typedef Iterator <ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_iterator;
typedef ReverseIterator <ConstKeyType, ConstKeyType&, ConstKeyType*> reverse_iterator;
typedef ReverseIterator <ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_reverse_iterator;
iterator begin()
{
assert(size());
return iterator(_sentry->_left, _sentry);
}
iterator end()
{
return iterator(_sentry, _sentry);
}
const_iterator begin() const
{
assert(size());
return const_iterator(_sentry->_left, _sentry);
}
const_iterator end() const
{
return const_iterator(_sentry, _sentry);
}
reverse_iterator rbegin()
{
assert(size());
return reverse_iterator(_sentry->_right, _sentry);
}
reverse_iterator rend()
{
return reverse_iterator(_sentry, _sentry);
}
const_reverse_iterator rbegin() const
{
assert(size());
return const_reverse_iterator(_sentry->_right, _sentry);
}
const_reverse_iterator rend() const
{
return const_reverse_iterator(_sentry, _sentry);
}
private:
pNode _find(const Key& key) const
{
pNode cur = _sentry->_parent;
while (cur != nullptr)
{
if (_comKey(_getKey(cur->_data), key) < 0)
{
cur = cur->_right;
}
else if (_comKey(_getKey(cur->_data), key) > 0)
{
cur = cur->_left;
}
else
{
// functor that separates the multi and non-multi cases
//...
return _mulFindSolve(cur, _getKey);
}
}
return cur;
}
std::pair<iterator, bool> _insert(const ConstKeyType& data)
{
if (_sentry->_parent == nullptr) // the tree has no root yet
{
_sentry->_parent = new Node(data);
_sentry->_parent->_parent = _sentry;
_sentry->_left = _sentry->_right = _sentry->_parent;
_sentry->_parent->_color = BLACK;
return { iterator(_sentry->_parent, _sentry), true };
}
pNode parent = _sentry->_parent;
pNode cur = _sentry->_parent;
while (cur != nullptr)
{
parent = cur;
if (_comKey(_getKey(cur->_data), _getKey(data)) < 0)
{
cur = cur->_right;
}
else if (_comKey(_getKey(cur->_data), _getKey(data)) > 0)
{
cur = cur->_left;
}
else
{
// functor that separates the multi and non-multi cases
//......
if (_mulInsertSolve(cur, parent, _getKey) == false)
{
return { iterator(cur, _sentry), false };
}
}
}
cur = new Node(data);
pNode newNodePointer = cur;
if (_comKey(_getKey(parent->_data), _getKey(data)) <= 0)
{
parent->_right = cur;
}
else
{
parent->_left = cur;
}
cur->_parent = parent;
pNode grandparent = parent->_parent;
while (grandparent != _sentry && parent != _sentry)
{
if (parent->_color == BLACK)
{
break;
}
if (grandparent->_left == parent)
{
pNode uncle = grandparent->_right;
if (uncle && uncle->_color == RED)
{
grandparent->_color = RED;
uncle->_color = parent->_color = BLACK;
}
else
{
if (parent->_left == cur)
{
rotateR(grandparent);
grandparent->_color = RED;
parent->_color = BLACK;
}
else
{
rotateL(parent);
rotateR(grandparent);
grandparent->_color = RED;
cur->_color = BLACK;
}
break;
}
}
else
{
pNode uncle = grandparent->_left;
if (uncle && uncle->_color == RED)
{
grandparent->_color = RED;
uncle->_color = parent->_color = BLACK;
}
else
{
if (parent->_right == cur)
{
rotateL(grandparent);
grandparent->_color = RED;
parent->_color = BLACK;
}
else
{
rotateR(parent);
rotateL(grandparent);
grandparent->_color = RED;
cur->_color = BLACK;
}
break;
}
}
cur = grandparent;
parent = cur->_parent;
grandparent = parent->_parent;
}
_sentry->_parent->_color = BLACK;
if (_sentry->_left->_left != nullptr) // after insertion, only check whether the leftmost node gained a left child; if so, update
{
_sentry->_left = _sentry->_left->_left;
}
if (_sentry->_right->_right != nullptr) // likewise for the rightmost node and its right child
{
_sentry->_right = _sentry->_right->_right;
}
return { iterator(newNodePointer, _sentry), true };
}
bool _erase(pNode aim)
{
if (aim == nullptr || aim == _sentry)
{
return false;
}
pNode parent = aim->_parent;
pNode checkParent = _sentry;
pNode checkChild = _sentry;
if (aim->_left == nullptr && aim->_right == nullptr)
{
eraseNN(parent, aim, checkParent, checkChild); // may require upward fix-up
}
else if (aim->_left == nullptr)
{
erase_NH_or_HN(parent, aim, aim->_right);
}
else if (aim->_right == nullptr)
{
erase_NH_or_HN(parent, aim, aim->_left);
}
else
{
replace(parent, aim, checkParent, checkChild); // case 1 may require upward fix-up
}
while (checkParent != _sentry) // exits once the root of the red-black tree has been processed
{
if (checkParent->_color == RED) // checkParent is red: a constant number of changes suffices, then exit
{
parentIsRedAdjust_constSolve(checkParent, checkChild);
break;
}
bool need_to_up = parentIsBlackAdjust_solve(checkParent, checkChild);
checkChild = checkParent; // need_to_up == false means no further upward fix-up; true means continue
checkParent = (need_to_up == false ? _sentry : checkParent->_parent);
}
if (_sentry->_left == aim) // the erased node is the sentry's leftmost node
{
if (aim->_parent == _sentry) // aim is the root
{
_sentry->_left = aim->_right == nullptr ? _sentry : aim->_right;
}
else if (aim->_parent->_left != nullptr) // aim->_parent->_left used to be aim->_right
{ // find the leftmost node in what was aim->_right's subtree
pNode getMostLeft = aim->_parent->_left;
pNode getMostLeftPrev = getMostLeft;
while (getMostLeft != nullptr)
{
getMostLeftPrev = getMostLeft;
getMostLeft = getMostLeft->_left;
}
_sentry->_left = getMostLeftPrev;
}
else // otherwise the parent is the leftmost node
{
_sentry->_left = aim->_parent;
}
}
if (_sentry->_right == aim) // the erased node is the sentry's rightmost node (with only a root, leftmost and rightmost are the same node)
{
if (aim->_parent == _sentry) // aim is the root
{
_sentry->_right = aim->_left == nullptr ? _sentry : aim->_left;
}
else if (aim->_parent->_right != nullptr) // aim->_parent->_right used to be aim->_left
{ // find the rightmost node in what was aim->_left's subtree
pNode getMostRight = aim->_parent->_right;
pNode getMostRightPrev = getMostRight;
while (getMostRight != nullptr)
{
getMostRightPrev = getMostRight;
getMostRight = getMostRight->_right;
}
_sentry->_right = getMostRightPrev;
}
else // otherwise the parent is the rightmost node
{
_sentry->_right = aim->_parent;
}
}
if (_sentry->_parent != nullptr)
{
_sentry->_parent->_color = BLACK;
}
delete aim; // free the node
return true;
}
public:
RBTree_base() = default;
RBTree_base(const RBTree_base& tree)
{
_treeCopy(_sentry->_parent, tree._sentry->_parent);
_size = tree._size;
}
RBTree_base(RBTree_base&& tree) noexcept
{
swap(tree);
}
RBTree_base& operator=(RBTree_base tree) // copy-and-swap: the by-value parameter serves both copy and move assignment
{ // a separate RBTree_base&& overload would be ambiguous with this one for rvalue arguments
swap(tree);
return *this;
}
~RBTree_base()
{
clear();
delete _sentry;
_sentry = nullptr;
}
template<class InputIterator>
RBTree_base(InputIterator begin, InputIterator end)
{
while (begin != end)
{
insert(*begin);
++begin;
}
}
RBTree_base(std::initializer_list<ConstKeyType> list)
{
for (const ConstKeyType& e : list)
{
insert(e);
}
}
public:
void clear()
{
if (_sentry->_parent != nullptr)
{
_treeDestory(_sentry->_parent);
}
_sentry->_parent = nullptr;
_sentry->_left = _sentry;
_sentry->_right = _sentry;
_size = 0;
}
void swap(RBTree_base& tree) // non-const reference: both trees are modified
{
std::swap(_sentry, tree._sentry); // swapping the sentries exchanges root, leftmost and rightmost together
std::swap(_size, tree._size);
}
const_iterator find(const Key& key) const
{
const_iterator it = const_iterator(_find(key), _sentry);
return it == const_iterator(nullptr, _sentry) ? end() : it;
}
iterator find(const Key& key)
{
iterator it = iterator(_find(key), _sentry);
return it == iterator(nullptr, _sentry) ? end() : it;
}
std::pair<iterator, bool> insert(const ConstKeyType& data)
{
std::pair<iterator, bool> access = _insert(data);
if (access.second == true)
{
++_size;
}
return access;
}
template <class InputIterator>
void insert(InputIterator first, InputIterator last)
{
while (first != last)
{
insert(*first);
++first;
}
}
void insert(std::initializer_list<ConstKeyType> list)
{
for (const auto& e : list)
{
insert(e);
}
}
size_t erase(const Key& key) // returns size_t for compatibility with the multi containers
{
if (_size == 0)
{
return 0;
}
iterator eraseElement = find(key);
iterator next = eraseElement;
++next;
size_t count = 0;
bool sameValue = true; // lets the multi containers erase consecutive nodes with the same key
while (eraseElement != end() && sameValue == true)
{ // do not read the sentry's data when next is end()
sameValue = (next != end() && _getKey(*eraseElement) == _getKey(*next));
_erase(eraseElement.get_pNode()); // erase by node pointer
eraseElement = next;
++next; // if next == end(), ++end() wraps to the leftmost node
++count;
}
_size -= count;
return count;
}
iterator erase(const_iterator position)
{
iterator next = iterator(position.get_pNode(), _sentry);
++next;
_erase(position.get_pNode()); // erase by node pointer
--_size;
return next;
}
iterator erase(const_iterator first, const_iterator last)
{
iterator next = iterator(last.get_pNode(), _sentry);
size_t count = 0;
while (first != last)
{
++count;
const_iterator temp = first;
++first;
_erase(temp.get_pNode()); // erase by node pointer
}
_size -= count;
return next;
}
size_t count(const Key& key) const
{
const_iterator it = find(key);
size_t thecount = 0;
while (it != end() && _getKey(*it) == key)
{
++thecount;
++it;
}
return thecount;
}
size_t high() const
{
return _high(_sentry->_parent);
}
size_t size() const
{
return _size;
}
};
template<class Key>
struct less
{
int operator()(const Key& one, const Key& two) const
{
if (one < two)
{
return -1;
}
else if (one > two)
{
return 1;
}
return 0;
}
};
template<class Key>
struct greater
{
int operator()(const Key& one, const Key& two) const
{
if (one < two)
{
return 1;
}
else if (one > two)
{
return -1;
}
return 0;
}
};
struct IsMultiFindSolve
{
template<class pNode, class GetKey>
pNode operator()(pNode cur, const GetKey& _getKey) const
{
pNode getValue = cur; // candidate for the first-inserted node with this key
pNode next = cur->_left; // search down-left
while (next != nullptr) // stop at null
{ // while keys match, keep going down-left
while (next != nullptr && _getKey(getValue->_data) == _getKey(next->_data))
{
getValue = next; // update the candidate for the first-inserted node with this key
cur = next;
next = next->_left;
} // while keys differ, go down-right
while (next != nullptr && _getKey(getValue->_data) != _getKey(next->_data))
{
cur = next;
next = next->_right;
}
}
return getValue;
}
};
struct NoMultiFindSolve
{
template<class pNode, class GetKey>
pNode operator()(pNode cur, const GetKey& _getKey) const
{
return cur; // non-multi containers have no duplicate keys, so return directly
}
};
struct IsMultiInsertSolve
{
template<class pNode, class GetKey>
bool operator()(pNode& cur, pNode& parent, const GetKey& _getKey) const
{ // cur and parent are modified through the references
pNode getValue = parent;
while (cur != nullptr) // stop at null
{ // while keys match, go down-right
while (cur != nullptr && _getKey(getValue->_data) == _getKey(cur->_data))
{
parent = cur;
cur = cur->_right;
} // while keys differ, go down-left
while (cur != nullptr && _getKey(getValue->_data) != _getKey(cur->_data))
{
parent = cur;
cur = cur->_left;
}
}
return true; // this is a multi container
}
};
struct NoMultiInsertSolve
{
template<class pNode, class GetKey>
bool operator()(pNode& cur, pNode& parent, const GetKey& _getKey) const
{
return false; // non-multi container: duplicate keys are rejected
}
};
}
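The three-way comparators and the multi / non-multi policy functors above can be tried out in isolation. The following is a minimal sketch, not the tree itself: it swaps the red-black tree for a sorted `std::vector`, and `ThreeWayLess`, `RejectDuplicates`, `AllowDuplicates` and `SortedBag` are hypothetical names introduced only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical three-way comparator using the same convention as my::less:
// negative if one < two, positive if one > two, zero if equivalent.
template<class Key>
struct ThreeWayLess
{
    int operator()(const Key& one, const Key& two) const
    {
        if (one < two) return -1;
        if (two < one) return 1;
        return 0;
    }
};

// Hypothetical duplicate policies, in the spirit of
// NoMultiInsertSolve / IsMultiInsertSolve.
struct RejectDuplicates { bool operator()() const { return false; } };
struct AllowDuplicates  { bool operator()() const { return true; } };

// Sorted-vector stand-in for the tree (an assumption of this sketch):
// the policy decides what happens when an equal key is met.
template<class Key, class DupPolicy, class Compare = ThreeWayLess<Key>>
class SortedBag
{
    std::vector<Key> _data;
    Compare _comKey;
    DupPolicy _allowDup;
public:
    bool insert(const Key& key)
    {
        std::size_t pos = 0;
        while (pos < _data.size())
        {
            int order = _comKey(_data[pos], key);
            if (order > 0) break;                          // first larger element: insert before it
            if (order == 0 && !_allowDup()) return false;  // equal key and duplicates rejected
            ++pos;                                         // smaller, or equal with duplicates allowed
        }
        _data.insert(_data.begin() + static_cast<std::ptrdiff_t>(pos), key);
        return true;
    }
    std::size_t size() const { return _data.size(); }
    const std::vector<Key>& data() const { return _data; }
};
```

With `AllowDuplicates`, an equal key is placed after the existing equal keys, so insertion order among duplicates is kept, which is the same ordering goal the tree version pursues by walking down-right on equal keys.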
set and multiset
In set.hpp:
cpp
#pragma once
#include <utility>
#include "RBTree.hpp"
namespace my
{
template<class T, class MultiFindSolve, class MultiInsertSolve, class CompareKey>
class set_base
{
protected:
typedef const T ConstKeyType; // key
typedef T Key;
typedef T Value;
struct GetSetKey
{
const Key& operator()(const ConstKeyType& data) const
{
return data;
}
};
typedef RBTree_base<ConstKeyType, Key, GetSetKey, MultiFindSolve, MultiInsertSolve, CompareKey> RBTree;
RBTree _base; // composition
public:
typedef Iterator<ConstKeyType, ConstKeyType&, ConstKeyType*> iterator;
typedef Iterator<ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_iterator;
typedef ReverseIterator<ConstKeyType, ConstKeyType&, ConstKeyType*> reverse_iterator;
typedef ReverseIterator<ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_reverse_iterator;
iterator begin()
{
return _base.begin();
}
iterator end()
{
return _base.end();
}
const_iterator begin() const
{
return _base.begin();
}
const_iterator end() const
{
return _base.end();
}
reverse_iterator rbegin()
{
return _base.rbegin();
}
reverse_iterator rend()
{
return _base.rend();
}
const_reverse_iterator rbegin() const
{
return _base.rbegin();
}
const_reverse_iterator rend() const
{
return _base.rend();
}
const_iterator cbegin() const
{
return begin();
}
const_iterator cend() const
{
return end();
}
const_reverse_iterator crbegin() const
{
return rbegin();
}
const_reverse_iterator crend() const
{
return rend();
}
public:
set_base() = default;
template<class InputIterator>
set_base(InputIterator begin, InputIterator end)
:_base(begin, end)
{
;
}
set_base(std::initializer_list<ConstKeyType> list)
:_base(list)
{
;
}
public:
bool empty() const
{
return _base.size() == 0;
}
size_t size() const
{
return _base.size();
}
size_t count(const Key& key) const
{
return _base.count(key);
}
const_iterator find(const Key& key) const
{
return const_iterator(_base.find(key));
}
iterator find(const Key& key)
{
return iterator(_base.find(key));
}
template <class InputIterator>
void insert(InputIterator first, InputIterator last)
{
return _base.insert(first, last);
}
void insert(std::initializer_list<ConstKeyType> list)
{
return _base.insert(list);
}
size_t erase(const Key& key)
{
return _base.erase(key);
}
iterator erase(const_iterator first, const_iterator last)
{
return _base.erase(first, last);
}
void clear()
{
_base.clear();
}
};
template<class T, class CompareKey = less<T>>
class set : public set_base<T, NoMultiFindSolve, NoMultiInsertSolve, CompareKey>
{ // inheritance selects multi vs. non-multi behavior
typedef set_base<T, NoMultiFindSolve, NoMultiInsertSolve, CompareKey> set_base;
typedef const T ConstKeyType;
public:
typedef typename set_base::iterator iterator;
set() = default;
template<class InputIterator>
set(InputIterator begin, InputIterator end)
:set_base(begin, end)
{
;
}
set(std::initializer_list<ConstKeyType> list)
:set_base(list)
{
;
}
public:
std::pair<iterator, bool> insert(const ConstKeyType& data) // defined here because the return type differs from multiset
{
return this->_base.insert(data);
}
};
template<class T, class CompareKey = less<T>>
class multiset : public set_base<T, IsMultiFindSolve, IsMultiInsertSolve, CompareKey>
{ // inheritance selects multi vs. non-multi behavior
typedef set_base<T, IsMultiFindSolve, IsMultiInsertSolve, CompareKey> set_base;
typedef const T ConstKeyType;
public:
typedef typename set_base::iterator iterator;
multiset() = default;
template<class InputIterator>
multiset(InputIterator begin, InputIterator end)
:set_base(begin, end)
{
;
}
multiset(std::initializer_list<ConstKeyType> list)
:set_base(list)
{
;
}
public:
iterator insert(const ConstKeyType& data) // defined here because the return type differs from set
{
std::pair<iterator, bool> get = this->_base.insert(data);
return get.first;
}
};
}
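The only member that `set` and `multiset` define separately is the single-element `insert`, because its return type differs. A minimal sketch of that difference, using `std::set` and `std::multiset` as stand-ins for the `my::` classes (the helper names `insertedIntoSet` and `insertIntoMultiset` are hypothetical):

```cpp
#include <cassert>
#include <set>
#include <utility>

// Unique-key set: insert reports through .second whether the value was new.
inline bool insertedIntoSet(std::set<int>& s, int value)
{
    std::pair<std::set<int>::iterator, bool> result = s.insert(value);
    return result.second;
}

// multiset: insert always succeeds, so only the iterator is returned.
inline int insertIntoMultiset(std::multiset<int>& s, int value)
{
    std::multiset<int>::iterator it = s.insert(value);
    return *it; // the element that was just inserted
}
```

Keeping the shared members in a common base and leaving only `insert` to the derived classes is one way to avoid duplicating the rest of the interface.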
map and multimap
In map.hpp:
cpp
#pragma once
#include <utility>
#include "RBTree.hpp"
namespace my
{
template<class Key, class Value, class MultiFindSolve, class MultiInsertSolve, class CompareKey = less<Key>>
class map_base
{
protected:
typedef std::pair<const Key, Value> ConstKeyType; // key / value
struct GetMapKey
{
const Key& operator()(const ConstKeyType& data) const
{
return data.first;
}
};
typedef RBTree_base<ConstKeyType, Key, GetMapKey, MultiFindSolve, MultiInsertSolve, CompareKey> RBTree;
RBTree _base; // composition
public:
map_base() = default;
template<class InputIterator>
map_base(InputIterator begin, InputIterator end)
:_base(begin, end)
{
;
}
map_base(std::initializer_list<ConstKeyType> list)
:_base(list)
{
;
}
public:
typedef Iterator<ConstKeyType, ConstKeyType&, ConstKeyType*> iterator;
typedef Iterator<ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_iterator;
typedef ReverseIterator<ConstKeyType, ConstKeyType&, ConstKeyType*> reverse_iterator;
typedef ReverseIterator<ConstKeyType, const ConstKeyType&, const ConstKeyType*> const_reverse_iterator;
iterator begin()
{
return iterator(_base.begin());
}
iterator end()
{
return iterator(_base.end());
}
const_iterator begin() const
{
return const_iterator(_base.begin());
}
const_iterator end() const
{
return const_iterator(_base.end());
}
reverse_iterator rbegin()
{
return reverse_iterator(_base.rbegin());
}
reverse_iterator rend()
{
return reverse_iterator(_base.rend());
}
const_reverse_iterator rbegin() const
{
return const_reverse_iterator(_base.rbegin());
}
const_reverse_iterator rend() const
{
return const_reverse_iterator(_base.rend());
}
const_iterator cbegin() const
{
return begin();
}
const_iterator cend() const
{
return end();
}
const_reverse_iterator crbegin() const
{
return rbegin();
}
const_reverse_iterator crend() const
{
return rend();
}
public:
bool empty() const
{
return _base.size() == 0;
}
size_t size() const
{
return _base.size();
}
size_t count(const Key& key) const
{
return _base.count(key);
}
const_iterator find(const Key& key) const
{
return _base.find(key);
}
iterator find(const Key& key)
{
return _base.find(key);
}
template<class InputIterator>
void insert(InputIterator first, InputIterator last)
{
return _base.insert(first, last);
}
void insert(std::initializer_list<ConstKeyType> list)
{
return _base.insert(list);
}
size_t erase(const Key& key)
{
return _base.erase(key);
}
iterator erase(const_iterator first, const_iterator last)
{
return _base.erase(first, last);
}
void clear()
{
_base.clear();
}
};
template<class Key, class Value, class CompareKey = less<Key>>
class map : public map_base<Key, Value, NoMultiFindSolve, NoMultiInsertSolve, CompareKey>
{ // inheritance selects multi vs. non-multi behavior
typedef map_base<Key, Value, NoMultiFindSolve, NoMultiInsertSolve, CompareKey> map_base;
typedef std::pair<const Key, Value> ConstKeyType;
public:
typedef typename map_base::iterator iterator;
map() = default;
template<class InputIterator>
map(InputIterator begin, InputIterator end)
:map_base(begin, end)
{
;
}
map(std::initializer_list<ConstKeyType> list)
:map_base(list)
{
;
}
public:
Value& operator[](const Key& key)
{
iterator it = this->find(key);
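if (it == this->end()) // not found: insert a default-constructed value
{
auto get = this->_base.insert({ key, Value() });
it = get.first;
}
// Strictly, it.operator->() returns the address of the pair,
// and a second -> then accesses first or second.
// For readability, the language applies the overloaded -> and then the built-in one, so a single -> suffices.
//return it.operator->()->second; // equivalent, but harder to read
return it->second; // shorthand for the line above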
}
std::pair<iterator, bool> insert(const ConstKeyType& data) // defined here because the return type differs from multimap
{
return this->_base.insert(data);
}
};
template<class Key, class Value, class CompareKey = less<Key>>
class multimap : public map_base<Key, Value, IsMultiFindSolve, IsMultiInsertSolve, CompareKey>
{ // inheritance selects multi vs. non-multi behavior
typedef map_base<Key, Value, IsMultiFindSolve, IsMultiInsertSolve, CompareKey> map_base;
typedef std::pair<const Key, Value> ConstKeyType;
public:
typedef typename map_base::iterator iterator;
multimap() = default;
template<class InputIterator>
multimap(InputIterator begin, InputIterator end)
:map_base(begin, end)
{
;
}
multimap(std::initializer_list<ConstKeyType> list)
:map_base(list)
{
;
}
public:
iterator insert(const ConstKeyType& data) // defined here because the return type differs from map
{
std::pair<iterator, bool> get = this->_base.insert(data);
return get.first;
}
};
}
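The `operator[]` above follows a find-then-insert pattern: look the key up, insert a default-constructed value if it is missing, and return a reference to the mapped value. A minimal sketch of the same logic as a free function over `std::map` (the name `getOrInsertDefault` is hypothetical and used only for illustration):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Hypothetical free-function version of the operator[] logic.
template<class Key, class Value>
Value& getOrInsertDefault(std::map<Key, Value>& m, const Key& key)
{
    typename std::map<Key, Value>::iterator it = m.find(key);
    if (it == m.end())
    {
        // Key is missing: insert a default-constructed value and keep the iterator.
        it = m.insert(std::make_pair(key, Value())).first;
    }
    return it->second; // same as it.operator->()->second
}
```

Because a reference is returned, a caller can write `getOrInsertDefault(m, key) += 1;` to count occurrences without a separate existence check.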
7. Correctness Tests (VS2022 release)
The results are compared against the VS2022 C++ standard library implementation:
set test
Test source:
cpp
void testSet()
{
srand((unsigned int)time(NULL));
vector<int> v;
const int N = 1000000; // one million random numbers
v.reserve(N);
for (int i = 0; i < N; ++i)
{
int random = rand() % 10000 * 10000 + rand() % 10000;
//int random = rand() % 10000;
v.push_back(random);
}
std::set<int> t1;
my::set<int> t2;
int getInsert1 = 0;
int getErase1 = 0;
int getInsert2 = 0;
int getErase2 = 0;
int begin1 = clock();
for (auto& e : v)
{
auto get = t1.insert(e);
getInsert1 += get.second;
}
int end1 = clock();
cout << "std::set insert time: " << end1 - begin1 << endl;
int begin3 = clock();
for (auto& e : v)
{
auto get = t2.insert(e);
getInsert2 += get.second;
}
int end3 = clock();
cout << "my::set insert time: " << end3 - begin3 << endl << endl;
int begin2 = clock();
for (auto& e : v)
{
auto get = t1.erase(e);
getErase1 += get;
}
int end2 = clock();
cout << "std::set erase time: " << end2 - begin2 << endl;
int begin4 = clock();
for (auto& e : v)
{
auto get = t2.erase(e);
getErase2 += get;
}
int end4 = clock();
cout << "my::set erase time: " << end4 - begin4 << endl << endl;
cout << "Insert and erase counts:" << endl;
cout << "std::set insert == " << getInsert1 << endl;
cout << "std::set erase == " << getErase1 << endl;
cout << "my::set insert == " << getInsert2 << endl;
cout << "my::set erase == " << getErase2 << endl;
}
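Each timing in these tests pairs one `clock()` call before a loop with one after it. As an alternative sketch, a small helper can keep the start and stop readings together so they cannot be mismatched; `elapsedMs` is a hypothetical name, and `std::chrono::steady_clock` is used here instead of `clock()`.

```cpp
#include <cassert>
#include <chrono>

// Hypothetical helper: runs work() once and returns the elapsed wall time in milliseconds.
template<class Work>
long long elapsedMs(Work work)
{
    auto start = std::chrono::steady_clock::now();
    work();
    auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
}
```

A call such as `elapsedMs([&] { for (auto& e : v) t1.erase(e); })` would then replace a begin/end pair around the loop.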
multiset test
Test source:
cpp
void testMultiset()
{
srand((unsigned int)time(NULL));
vector<int> v;
const int N = 1000000; // one million random numbers with a high chance of repeats
v.reserve(N);
for (int i = 0; i < N; ++i)
{
//int random = rand() % 10000 * 10000 + rand() % 10000;
int random = rand() % 10000;
v.push_back(random);
}
std::multiset<int> t1;
my::multiset<int> t2;
int getInsert1 = 0;
int getErase1 = 0;
int getInsert2 = 0;
int getErase2 = 0;
int begin1 = clock();
for (auto& e : v)
{
t1.insert(e);
++getInsert1;
}
int end1 = clock();
cout << "std::multiset insert time: " << end1 - begin1 << endl;
int begin3 = clock();
for (auto& e : v)
{
t2.insert(e);
++getInsert2;
}
int end3 = clock();
cout << "my::multiset insert time: " << end3 - begin3 << endl << endl;
int begin2 = clock();
for (auto& e : v)
{
auto get = t1.erase(e);
getErase1 += get;
}
int end2 = clock();
cout << "std::multiset erase time: " << end2 - begin2 << endl;
int begin4 = clock();
for (auto& e : v)
{
auto get = t2.erase(e);
getErase2 += get;
}
int end4 = clock();
cout << "my::multiset erase time: " << end4 - begin4 << endl << endl;
cout << "Insert and erase counts:" << endl;
cout << "std::multiset insert == " << getInsert1 << endl;
cout << "std::multiset erase == " << getErase1 << endl;
cout << "my::multiset insert == " << getInsert2 << endl;
cout << "my::multiset erase == " << getErase2 << endl;
}
map test
Test source:
cpp
void testMap()
{
srand((unsigned int)time(NULL));
vector<int> v;
const int N = 1000000; // one million random numbers
v.reserve(N);
for (int i = 0; i < N; ++i)
{
int random = rand() % 10000 * 10000 + rand() % 10000;
//int random = rand() % 10000;
v.push_back(random);
}
std::map<int, int> t1;
my::map<int, int> t2;
int getInsert1 = 0;
int getErase1 = 0;
int getInsert2 = 0;
int getErase2 = 0;
int begin1 = clock();
for (auto& e : v)
{
auto get = t1.insert({ e, e });
getInsert1 += get.second;
}
int end1 = clock();
cout << "std::map insert time: " << end1 - begin1 << endl;
int begin3 = clock();
for (auto& e : v)
{
auto get = t2.insert({ e, e });
getInsert2 += get.second;
}
int end3 = clock();
cout << "my::map insert time: " << end3 - begin3 << endl << endl;
int begin2 = clock();
for (auto& e : v)
{
auto get = t1.erase(e);
getErase1 += get;
}
int end2 = clock();
cout << "std::map erase time: " << end2 - begin2 << endl;
int begin4 = clock();
for (auto& e : v)
{
auto get = t2.erase(e);
getErase2 += get;
}
int end4 = clock();
cout << "my::map erase time: " << end4 - begin4 << endl << endl;
cout << "Insert and erase counts:" << endl;
cout << "std::map insert == " << getInsert1 << endl;
cout << "std::map erase == " << getErase1 << endl;
cout << "my::map insert == " << getInsert2 << endl;
cout << "my::map erase == " << getErase2 << endl;
}
multimap test
Test source:
cpp
void testMultimap()
{
srand((unsigned int)time(NULL));
vector<int> v;
const int N = 1000000; // one million random numbers with a high chance of repeats
v.reserve(N);
for (int i = 0; i < N; ++i)
{
//int random = rand() % 10000 * 10000 + rand() % 10000;
int random = rand() % 10000;
v.push_back(random);
}
std::multimap<int, int> t1;
my::multimap<int, int> t2;
int getInsert1 = 0;
int getErase1 = 0;
int getInsert2 = 0;
int getErase2 = 0;
int begin1 = clock();
for (auto& e : v)
{
t1.insert({ e, e });
++getInsert1;
}
int end1 = clock();
cout << "std::multimap insert time: " << end1 - begin1 << endl;
int begin3 = clock();
for (auto& e : v)
{
t2.insert({ e, e });
++getInsert2;
}
int end3 = clock();
cout << "my::multimap insert time: " << end3 - begin3 << endl << endl;
int begin2 = clock();
for (auto& e : v)
{
auto get = t1.erase(e);
getErase1 += get;
}
int end2 = clock();
cout << "std::multimap erase time: " << end2 - begin2 << endl;
int begin4 = clock();
for (auto& e : v)
{
auto get = t2.erase(e);
getErase2 += get;
}
int end4 = clock();
cout << "my::multimap erase time: " << end4 - begin4 << endl << endl;
cout << "Insert and erase counts:" << endl;
cout << "std::multimap insert == " << getInsert1 << endl;
cout << "std::multimap erase == " << getErase1 << endl;
cout << "my::multimap insert == " << getInsert2 << endl;
cout << "my::multimap erase == " << getErase2 << endl;
}
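find test (using multimap)
We insert up to 5 distinct keys (a repeated random key is skipped), each 2000 times, with the value increasing by one on every insertion, to check whether find returns the first element inserted for a given key:
Test source: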
cpp
void testFind()
{
srand((unsigned int)time(NULL));
vector<int> v;
const int N = 10000; // ten thousand insertions
v.reserve(N);
std::set<int> st; // de-duplicates the keys
int k = 5; // number of distinct keys to insert
int tk = N / k; // repetitions per key
while (k--)
{
int random = rand() % 1000;
auto ans = st.insert(random);
if (ans.second == false)
{
continue;
}
for (int i = 0; i < tk; ++i)
{
v.push_back(random);
}
}
std::multimap<int, int> t1;
my::multimap<int, int> t2;
int getInsert1 = 0;
int getErase1 = 0;
int getInsert2 = 0;
int getErase2 = 0;
int num1 = 1;
int num2 = 1;
int begin1 = clock();
for (auto& e : v)
{
t1.insert({ e, num1++ });
++getInsert1;
}
int end1 = clock();
cout << "std::multimap insert time: " << end1 - begin1 << endl;
int begin3 = clock();
for (auto& e : v)
{
t2.insert({ e, num2++ });
++getInsert2;
}
int end3 = clock();
cout << "my::multimap insert time: " << end3 - begin3 << endl << endl;
auto findIt = st.begin();
for (int i = 0; i < st.size(); ++i)
{
std::multimap<int, int>::iterator it1 = t1.find(*findIt);
my::multimap<int, int>::iterator it2 = t2.find(*findIt);
size_t count1 = t1.count(*findIt);
size_t count2 = t2.count(*findIt);
cout << "std::find -> key == " << (*it1).first << " " << "value == " << (*it1).second;
cout << " duplicate count: " << count1 << endl;
cout << "my::find -> key == " << (*it2).first << " " << "value == " << (*it2).second;
cout << " duplicate count: " << count2 << endl << endl;
++findIt;
}
int begin2 = clock();
for (auto& e : v)
{
auto get = t1.erase(e);
getErase1 += get;
}
int end2 = clock();
cout << "std::multimap erase time: " << end2 - begin2 << endl;
int begin4 = clock();
for (auto& e : v)
{
auto get = t2.erase(e);
getErase2 += get;
}
int end4 = clock();
cout << "my::multimap erase time: " << end4 - begin4 << endl << endl;
cout << "Insert and erase counts:" << endl;
cout << "std::multimap insert == " << getInsert1 << endl;
cout << "std::multimap erase == " << getErase1 << endl;
cout << "my::multimap insert == " << getInsert2 << endl;
cout << "my::multimap erase == " << getErase2 << endl;
}
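One caveat when comparing against `std::multimap`: the standard only requires `find` to return some element with an equivalent key, so the "first inserted" result is an implementation detail there. What is specified (since C++11) is that `insert` places a new element at the end of the range of equivalent keys, so `lower_bound` reaches the earliest-inserted one. A minimal sketch of a check built on that, with the hypothetical helper name `firstInsertedValue`:

```cpp
#include <cassert>
#include <map>
#include <utility>

// Returns the value of the earliest-inserted element with this key,
// or fallback if the key is absent.
inline int firstInsertedValue(const std::multimap<int, int>& m, int key, int fallback)
{
    std::multimap<int, int>::const_iterator it = m.lower_bound(key);
    if (it == m.end() || it->first != key)
    {
        return fallback;
    }
    return it->second;
}
```

Comparing `my::multimap::find` against this helper, rather than against `std::multimap::find`, would keep the check independent of the standard library implementation in use.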