最近准备学习 LCT,因此先学习了 Splay。
前置知识
核心操作
基础操作
#define fa(x) t[x].fa
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
int k,rt;//节点数,根
struct tree
{
int ch[2],fa,val,sz;//左右儿子,父亲,值,子树大小
}t[N];
bool dir(int x)//判断x是它父亲的左儿子还是右儿子
{
return x==rs(fa(x));
}
int newnode(int v)//新建节点
{
t[++k].val=v;
t[k].sz=1;
return k;
}
void pushup(int x)//合并儿子信息
{
t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}
旋转操作
旋转操作的本质是把指定节点上移一个位置,并保证树的中序遍历(即二叉搜索树的性质)不变。
旋转分为右旋(Zig) 和左旋(Zag),分别用于处理指定节点是左儿子和右儿子的情况。如下图,由上到下为右旋,由下到上为左旋。


代码按照旋转的定义模拟即可,需要注意的是必须保证 \(0\) 号节点的所有属性都为 \(0\)。
void rotate(int x)
{
int y=fa(x),z=fa(y);
bool f=dir(x);
t[y].ch[f]=t[x].ch[!f];
t[x].ch[!f]=y;
if(z)//判断0号节点
t[z].ch[dir(y)]=x;
if(t[y].ch[f])
fa(t[y].ch[f])=y;
fa(y)=x;
fa(x)=z;
pushup(y);//先更新儿子再更新父亲
pushup(x);
}
Splay 操作
\(Splay(x)\) 的作用是把点 \(x\) 一路旋到根 \(rt\) 上,其由三种类型组成:
Zig / Zag
这种操作仅发生在 \(fa(x)=rt\) 时,将 \(x\) 旋转一次即可。
Zig-Zig / Zag-Zag
当 \(x\) 和 \(fa(x)\) 同为它们父亲的左儿子或右儿子时,先将 \(fa(x)\) 旋转一次,再将 \(x\) 旋转一次。下图为对 \(3\) 号节点进行的一次 Zig-Zig 操作。

Zig-Zag / Zag-Zig
当 \(x\) 和 \(fa(x)\) 相对于父亲是不同方向的儿子时,连续将 \(x\) 旋转两次。下图为对 \(3\) 号节点进行的一次 Zig-Zag 操作。

而 Splay 操作则就是这三种操作的组合。代码如下,为了便于理解(其实是我不会用三目运算符),这里使用较为复杂的 \(if/else\) 实现。
//在常规的平衡树操作中只需要旋转到树根,但是部分操作有旋转到其他祖先的要求,所以这里有一个z表示要旋转到的位置
void splay(int x,int &z=rt)
{
int w=fa(z);//x和z的父亲相等,则表示到位置了
while(fa(x)!=w && fa(fa(x))!=w)
{
if(dir(fa(x))==dir(x))
rotate(fa(x));//Zig-Zig / Zag-Zag
else
rotate(x);//Zig-Zag / Zag-Zig
rotate(x);
}
if(fa(x)!=w)
rotate(x);//最后可能有一次Zig / Zag
z=x;
}
时间复杂度
单次均摊复杂度是 \(O(\log n)\) 的,我不会证,想看证明可以去 oi-wiki。
维护集合操作
需要注意的是,所有操作结束后都应进行 Splay 操作以保证时间复杂度。
插入
从根一直找到应该插入的位置。
void insert(int v)
{
int x=rt,y=0;//y是x的父亲
while(x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
x=newnode(v);
fa(x)=y;
t[y].ch[t[y].val<v]=x;
splay(x);
}
删除
最复杂的操作,有不同的实现方法,这里采用的方法是找到要删除的节点后将其转到根上操作。
void erase(int v)
{
int x=rt,y=0;
while(t[x].val!=v && x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
if(!x)//找不到节点,直接退出
{
splay(y);
return;
}
splay(x);
if(!ls(x) || !rs(x))//如果要删除节点只有一个儿子,将儿子设为根即可
{
rt=ls(x)+rs(x);
fa(ls(x)+rs(x))=0;
return;
}
int p=rt=ls(x);
fa(p)=0;
while(rs(p))
p=rs(p);
rs(p)=rs(x);//将右儿子接在左子树中最大的节点下面
fa(rs(x))=p;
pushup(p);//改变了结构,要额外pushup一次
splay(p);
}
查询排名
在树上搜索的时候统计比 \(v\) 小的节点数量。
int getrnk(int v)
{
int x=rt,y=0,ans=1;
while(x)
{
y=x;
if(t[x].val<v)
{
ans+=t[ls(x)].sz+1;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
查询第 k 大值
类似线段树二分。
int getkth(int v)
{
int x=rt;
while(1)
{
int now=t[ls(x)].sz+1;
if(now==v)
break;
if(now<v)
{
v-=now;
x=rs(x);
}
else
x=ls(x);
}
splay(x);
return t[x].val;
}
查询前驱后继
查询前驱类似于查询排名,只是改为纪录比 \(v\) 小的节点数值;查询后继就是前驱的做法反过来。
int getpre(int v)//前驱
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val<v)
{
ans=t[x].val;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
int getnxt(int v)//后继
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val>v)
{
ans=t[x].val;
x=ls(x);
}
else
x=rs(x);
}
splay(y);
return ans;
}
完整代码
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
#define fa(x) t[x].fa
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
int k,rt,n;
struct tree
{
int ch[2],fa,val,sz;
}t[N];
bool dir(int x)
{
return x==rs(fa(x));
}
int newnode(int v)
{
t[++k].val=v;
t[k].sz=1;
return k;
}
void pushup(int x)
{
t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}
void rotate(int x)
{
int y=fa(x),z=fa(y);
bool f=dir(x);
t[y].ch[f]=t[x].ch[!f];
t[x].ch[!f]=y;
if(z)
t[z].ch[dir(y)]=x;
if(t[y].ch[f])
fa(t[y].ch[f])=y;
fa(y)=x;
fa(x)=z;
pushup(y);
pushup(x);
}
void splay(int x,int &z=rt)
{
int w=fa(z);
while(fa(x)!=w && fa(fa(x))!=w)
{
if(dir(fa(x))==dir(x))
rotate(fa(x));
else
rotate(x);
rotate(x);
}
if(fa(x)!=w)
rotate(x);
z=x;
}
void insert(int v)
{
int x=rt,y=0;
while(x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
x=newnode(v);
fa(x)=y;
t[y].ch[t[y].val<v]=x;
splay(x);
}
void erase(int v)
{
int x=rt,y=0;
while(t[x].val!=v && x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
if(!x)
{
splay(y);
return;
}
splay(x);
if(!ls(x) || !rs(x))
{
rt=ls(x)+rs(x);
fa(ls(x)+rs(x))=0;
return;
}
int p=rt=ls(x);
fa(p)=0;
while(rs(p))
p=rs(p);
rs(p)=rs(x);
fa(rs(x))=p;
pushup(p);
splay(p);
}
int getrnk(int v)
{
int x=rt,y=0,ans=1;
while(x)
{
y=x;
if(t[x].val<v)
{
ans+=t[ls(x)].sz+1;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
int getkth(int v)
{
int x=rt;
while(1)
{
int now=t[ls(x)].sz+1;
if(now==v)
break;
if(now<v)
{
v-=now;
x=rs(x);
}
else
x=ls(x);
}
splay(x);
return t[x].val;
}
int getpre(int v)
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val<v)
{
ans=t[x].val;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
int getnxt(int v)
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val>v)
{
ans=t[x].val;
x=ls(x);
}
else
x=rs(x);
}
splay(y);
return ans;
}
int main()
{
scanf("%d",&n);
while(n--)
{
int op,x;
scanf("%d%d",&op,&x);
if(op==1)
insert(x);
else if(op==2)
erase(x);
else if(op==3)
printf("%d\n",getrnk(x));
else if(op==4)
printf("%d\n",getkth(x));
else if(op==5)
printf("%d\n",getpre(x));
else
printf("%d\n",getnxt(x));
}
return 0;
}