Treap树的基本概念
Treap(Tree + Heap)是一种结合二叉搜索树(BST)和堆(Heap)特性的数据结构。每个节点包含两个值:
- 键值(Key):遵循二叉搜索树的性质(左子树键值 ≤ 当前节点 ≤ 右子树键值)。
- 优先级(Priority):遵循堆的性质(通常为最大堆,即父节点优先级 ≥ 子节点优先级)。
通过随机分配优先级,Treap在期望上能保持平衡,避免普通BST退化为链表的极端情况。
Treap的核心操作
插入操作
- 按BST规则插入新节点。
- 若新节点的优先级违反堆性质,通过旋转调整(左旋或右旋)恢复堆序。
删除操作
- 找到目标节点后,通过旋转将其移至叶子节点或仅有一个子节点的位置。
- 直接删除目标节点。
查找操作
与普通BST相同,时间复杂度为 (O(\log n))(期望)。
无旋Treap与有旋Treap对比
- 无旋Treap:基于分裂(Split)和合并(Merge)操作,无需旋转,代码更简洁但常数较大。
- 有旋Treap:直接通过旋转调整,常数更优,适合频繁插入/删除的场景。
应用场景
- 动态数据维护:如排行榜、实时统计。
- 区间操作:通过扩展子树大小字段支持区间查询。
- 替代平衡树:相比AVL或红黑树更易实现,且期望性能相近。
复杂度分析
- 时间:插入、删除、查找的期望时间复杂度均为 (O(\log n))。
- 空间:(O(n)),每个节点需存储额外优先级字段。
旋转Treap树的相关模板:
旋转Treap基础概念
- Treap结构:结合二叉搜索树(BST)和堆(Heap)性质,每个节点包含键值(key)和随机优先级(priority)。
- 旋转操作:通过左旋和右旋维护堆性质,同时保持BST的有序性。
节点定义
cpp
struct Node {
int key, priority, size;
Node *left, *right;
Node(int val) : key(val), priority(rand()), size(1), left(nullptr), right(nullptr) {}
};
插入操作
核心逻辑:按照BST规则插入后,通过旋转恢复堆性质。
cpp
Node* insert(Node* root, int key) {
if (!root) return new Node(key);
if (key <= root->key) {
root->left = insert(root->left, key);
if (root->left->priority > root->priority)
root = rightRotate(root);
} else {
root->right = insert(root->right, key);
if (root->right->priority > root->priority)
root = leftRotate(root);
}
updateSize(root);
return root;
}
删除操作
核心逻辑:通过旋转将待删除节点移至叶子节点后直接删除。
cpp
Node* remove(Node* root, int key) {
if (!root) return nullptr;
if (key < root->key) {
root->left = remove(root->left, key);
} else if (key > root->key) {
root->right = remove(root->right, key);
} else {
if (!root->left) return root->right;
if (!root->right) return root->left;
if (root->left->priority > root->right->priority) {
root = rightRotate(root);
root->right = remove(root->right, key);
} else {
root = leftRotate(root);
root->left = remove(root->left, key);
}
}
updateSize(root);
return root;
}
根据值查询排名
排名定义:小于该值的节点数 + 1。
cpp
int getRank(Node* root, int key) {
if (!root) return 0;
if (key < root->key) return getRank(root->left, key);
if (key > root->key) return 1 + size(root->left) + getRank(root->right, key);
return size(root->left) + 1;
}
根据排名查询值
cpp
int getKth(Node* root, int k) {
int leftSize = size(root->left);
if (k <= leftSize) return getKth(root->left, k);
if (k == leftSize + 1) return root->key;
return getKth(root->right, k - leftSize - 1);
}
查询前驱(第一个比val小的节点)
cpp
int getPrev(Node* root, int key) {
int prev = -1; // 初始值设为无效
while (root) {
if (root->key < key) {
prev = root->key;
root = root->right;
} else {
root = root->left;
}
}
return prev;
}
查询后继(第一个比val大的节点)
cpp
int getNext(Node* root, int key) {
int next = -1; // 初始值设为无效
while (root) {
if (root->key > key) {
next = root->key;
root = root->left;
} else {
root = root->right;
}
}
return next;
}
其他:
cpp
// 更新子树大小
void updateSize(Node* node) {
if (node) node->size = 1 + size(node->left) + size(node->right);
}
// 左旋
Node* leftRotate(Node* x) {
Node* y = x->right;
x->right = y->left;
y->left = x;
updateSize(x);
updateSize(y);
return y;
}
// 右旋
Node* rightRotate(Node* y) {
Node* x = y->left;
y->left = x->right;
x->right = y;
updateSize(y);
updateSize(x);
return x;
}
例题:
P3369 【模板】普通平衡树


这题就是典型的模板题,用树写是模板,但我下面的代码是用邻接链表来写的。
cpp
#define _CRT_SECURE_NO_WARNINGS
#include<stdio.h>
#include<iostream>
#include<bits/stdc++.h>
using namespace std;
int n, opt, x;
int root = 0, sum_size = 0;
struct Node {
int ch[2];
int rank, val;
int rep_cnt;
int size;
Node(int val) : val(val), rep_cnt(1), size(1) {
ch[0] = 0;
ch[1] = 0;
rank = rand();
//rank 是随机给出的
}
Node() : val(0), rep_cnt(0), size(0), rank(0) {
ch[0] = ch[1] = 0;
}
} a[100005];
void upd_siz(int i) {
a[i].size=a[i].rep_cnt + a[a[i].ch[0]].size + a[a[i].ch[1]].size;
}
int _rotate(int i,int k) {//0为右旋,1为左旋
int q = a[i].ch[k];
a[i].ch[k] = a[q].ch[(k + 1) % 2];
a[q].ch[(k + 1) % 2] = i;
upd_siz(i);
upd_siz(q);
return q;
}
int insert(int val,int q) {
if (q == 0) {
a[++sum_size]= Node(val);//c++特有的
return sum_size;
}
if (val == a[q].val) {
a[q].rep_cnt++;
upd_siz(q);
return q;
}
else if (a[q].val > val) {
a[q].ch[0] = insert(val, a[q].ch[0]);
upd_siz(q);
if (a[q].rank > a[a[q].ch[0]].rank) {
return _rotate(q, 0);
}
return q;
}
else {
a[q].ch[1] = insert(val, a[q].ch[1]);
upd_siz(q);
if (a[q].rank > a[a[q].ch[1]].rank) {
return _rotate(q, 1);
}
return q;
}
}
int Delete(int val, int q) {
if (q == 0) {
return q;
}
if (val == a[q].val) {
if (a[q].rep_cnt > 1) {
a[q].rep_cnt--;
upd_siz(q);
return q;
}
else {
if (a[q].ch[0] == 0&&a[q].ch[1]==0) {
return 0;
}
else if(a[q].ch[0]==0){
return a[q].ch[1];
}
else if (a[q].ch[1] == 0) {
return a[q].ch[0];
}
else {
int o = a[a[q].ch[0]].rank > a[a[q].ch[1]].rank ? 1 : 0;
q = _rotate(q, o);
a[q].ch[(1 + o)%2] = Delete(val, a[q].ch[(1 + o) % 2]);
upd_siz(q);
return q;
}
}
}
else if(val>a[q].val) {
a[q].ch[1]=Delete(val, a[q].ch[1]);
upd_siz(q);
}
else {
a[q].ch[0]=Delete(val, a[q].ch[0]);
upd_siz(q);
}
upd_siz(q);
return q;
}
int _query_rank(int val, int q) {
if (q == 0) {
return 0;
}
if (val < a[q].val) {
return _query_rank(val, a[q].ch[0]);
}
else if (val > a[q].val) {
return a[a[q].ch[0]].size+a[q].rep_cnt+ _query_rank(val, a[q].ch[1]);
}
else {
return a[a[q].ch[0]].size;
}
}
int _query_val(int size, int q) {
if (size > a[a[q].ch[0]].size+a[q].rep_cnt) {
return _query_val(size - a[a[q].ch[0]].size - a[q].rep_cnt, a[q].ch[1]);
}
else if (size <= a[a[q].ch[0]].size) {
return _query_val(size, a[q].ch[0]);
}
else {
return q;
}
}
int _query_prev(int val, int q,int val_max) {
if (q == 0) {
return val_max;
}
if (val > a[q].val) {
return _query_prev(val, a[q].ch[1], val_max > a[q].val ? val_max : a[q].val);
}
else {
return _query_prev(val, a[q].ch[0], val_max);
}
}
int _query_nex(int val, int q, int val_min) {
if (q == 0) {
return val_min;
}
if (val < a[q].val) {
return _query_nex(val, a[q].ch[0], val_min < a[q].val ? val_min : a[q].val);
}
else {
return _query_nex(val, a[q].ch[1], val_min);
}
}
int main(){
ios::sync_with_stdio(false); // 禁用同步
cin.tie(nullptr); // 解除cin与cout绑定
a[0].size = 0;
a[0].rank = INT_MAX;
cin >> n;
for (int i = 0; i < n; i++) {
cin >> opt >> x;
switch (opt) {
case 1:
root=insert(x,root);
break;
case 2:
root = Delete(x, root);
break;
case 3:
cout << 1 + _query_rank(x, root) << endl;
break;
case 4:
cout << a[_query_val(x, root)].val << endl;
break;
case 5:
cout << _query_prev(x, root, INT_MIN) << endl;
break;
case 6:
cout << _query_nex(x, root, INT_MAX) << endl;
break;
}
}
return 0;
}