分析
本题中的操作3提到了版本回退,于是可以想到这是用可持久化数据结构做。又看到操作2复杂的区间操作,说明可以用可持久化平衡树解决。
模拟一下操作2,可以发现操作2其实是在将 \([l-k,l-1]\) 这个区间(这个区间长度为 \(k\) )复制出来 \(\lceil \frac {r-l+1} {k} \rceil\)个,然后把它们顺序拼接起来作为一个大区间,最后取前 \(r-l+1\) 个数整体替换掉 \([l,r]\) 区间。这一步可以用类似快速幂的办法去做(因为 \(merge\) 函数满足结合律),每次将这棵代表 \([l-k,l-1]\) 区间的 \(Treap\) 树复制并合并。操作1和操作3直接截取目标区间进行统计或者操作即可。
还有一点就是本题空间限制紧,为了可持久化会产生许多废弃节点,当节点总数超过一个值(我选的 \(1.5e6\) ,可以多调一调这个数的值)之后就应该重建两棵树(一棵是初始的,一棵是经过多次操作的),保证不会爆空间。
代码
RE代码
cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
int n,m;
int a[N];
struct node {
int l,r,key,sz,val;
ll sum;
}tr[N*8];
int rt[5],idx;
vector <int> v[5];
int new_node(int v) {
tr[++idx].sum=v;
tr[idx].val=v;
tr[idx].sz=1;
tr[idx].key=rand();
tr[idx].l=tr[idx].r=0;
return idx;
}
int copy_node(int x) {
tr[++idx]=tr[x];
return idx;
}
void push_up(int rt) {
tr[rt].sum=tr[tr[rt].l].sum+tr[tr[rt].r].sum+tr[rt].val;
tr[rt].sz=tr[tr[rt].l].sz+tr[tr[rt].r].sz+1;
}
void split(int p,int k,int &x,int &y) {
if(!p) {
x=y=0;
return ;
}
if(tr[tr[p].l].sz<k) {
x=copy_node(p);
split(tr[x].r,k-tr[tr[x].l].sz-1,tr[x].r,y);
push_up(x);
}
else {
y=copy_node(p);
split(tr[y].l,k,x,tr[y].l);
push_up(y);
}
}
int merge(int x,int y) {
if(!x||!y) return x|y;
if(tr[x].key<tr[y].key) {
int p=copy_node(x);
tr[p].r=merge(tr[p].r,y);
push_up(p);
return p;
}
else {
int p=copy_node(y);
tr[p].l=merge(x,tr[p].l);
push_up(p);
return p;
}
}
int qpow(int p,int b) {
int a=copy_node(p),ret=0;
for(;b;b>>=1,a=merge(a,a)) if(b&1) ret=merge(ret,a);
return ret;
}
void split_seq(int rt,int l,int r,int &x,int &y,int &z) {
split(rt,r,x,z);
split(x,l-1,x,y);
}
ll find_sum(int &rt,int l,int r) {
int x,y,z;
split_seq(rt,l,r,x,y,z);
ll ret=tr[y].sum;
rt=merge(merge(x,y),z);
return ret;
}
void dfs(int u,int i) {
if(!u) return ;
dfs(tr[u].l,i);
v[i].push_back(tr[u].val);
dfs(tr[u].r,i);
}
int build(int l,int r,int i) {
if(l>r) return 0;
int mid=(l+r)>>1;
int x=new_node(v[i][mid]);
tr[x].l=build(l,mid-1,i);
tr[x].r=build(mid+1,r,i);
push_up(x);
return x;
}
int build_a(int l,int r) {
if(l>r) return 0;
int mid=(l+r)>>1;
int x=new_node(a[mid]);
tr[x].l=build_a(l,mid-1);
tr[x].r=build_a(mid+1,r);
push_up(x);
return x;
}
void rebuild() {
idx=0;
v[0].clear();
v[1].clear();
dfs(rt[0],0);
dfs(rt[1],1);
rt[0]=build(0,n-1,0);
rt[1]=build(0,n-1,1);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
rt[0]=build_a(1,n);
rt[1]=rt[0];
for(int i=1;i<=m;i++) {
int op;
cin>>op;
if(idx>=1500000) rebuild();
if(op==1) {
int l,r;
cin>>l>>r;
cout<<find_sum(rt[1],l,r)<<'\n';
}
else if(op==2) {
int l,r,k;
cin>>l>>r>>k;
int x,y,z;
split_seq(rt[1],l-k,l-1,x,y,z);
int cnt=(r-l+1+k-1)/k;
int Y=qpow(y,cnt);
int u,v;
split(Y,r-l+1,u,v);
int b,c;
split(z,r-l+1,b,c);
rt[1]=merge(merge(merge(x,y),u),c);
}
else {
int l,r;
cin>>l>>r;
int x,y,z,u,v,w;
split_seq(rt[0],l,r,x,y,z);
split_seq(rt[1],l,r,u,v,w);
rt[1]=merge(merge(u,y),w);
}
}
return 0;
}
(我基于上面这个第一版代码不断微调,一直交,结果当然是RE,甚至还微调出了TLE)
我认为RE原因主要是 \(qpow\) 函数不断对现有的 \(Treap\) 树进行复制,导致复制后的两棵 \(Treap\) 树存在大量 \(key\) 相等的节点。而 \(merge\) 函数在 \(key\) 相等时又总是走 \(else\) 分支,导致递归深度过大造成栈溢出(欢迎对这种写法的RE原因进行探讨)。那么解决这个问题的办法可以是当 \(key\) 相等时随机走一个分支,这样即可避免爆栈。
只需在 \(merge\) 函数的 \(if\) 判断中加一个条件即可。
AC代码
cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
int n,m;
int a[N];
struct node {
int l,r,key,sz,val;
ll sum;
}tr[N*8];
int rt[5],idx;
vector <int> v[5];
int new_node(int v) {
tr[++idx].sum=v;
tr[idx].val=v;
tr[idx].sz=1;
tr[idx].key=rand();
tr[idx].l=tr[idx].r=0;
return idx;
}
int copy_node(int x) {
tr[++idx]=tr[x];
return idx;
}
void push_up(int rt) {
tr[rt].sum=tr[tr[rt].l].sum+tr[tr[rt].r].sum+tr[rt].val;
tr[rt].sz=tr[tr[rt].l].sz+tr[tr[rt].r].sz+1;
}
void split(int p,int k,int &x,int &y) {
if(!p) {
x=y=0;
return ;
}
if(tr[tr[p].l].sz<k) {
x=copy_node(p);
split(tr[x].r,k-tr[tr[x].l].sz-1,tr[x].r,y);
push_up(x);
}
else {
y=copy_node(p);
split(tr[y].l,k,x,tr[y].l);
push_up(y);
}
}
int merge(int x,int y) {
if(!x||!y) return x|y;
if(tr[x].key<tr[y].key||(tr[x].key==tr[y].key&&rand()%2==0)) {
int p=copy_node(x);
tr[p].r=merge(tr[p].r,y);
push_up(p);
return p;
}
else {
int p=copy_node(y);
tr[p].l=merge(x,tr[p].l);
push_up(p);
return p;
}
}
int qpow(int p,int b) {
int a=copy_node(p),ret=0;
for(;b;b>>=1,a=merge(a,a)) if(b&1) ret=merge(ret,a);
return ret;
}
void split_seq(int rt,int l,int r,int &x,int &y,int &z) {
split(rt,r,x,z);
split(x,l-1,x,y);
}
ll find_sum(int &rt,int l,int r) {
int x,y,z;
split_seq(rt,l,r,x,y,z);
ll ret=tr[y].sum;
rt=merge(merge(x,y),z);
return ret;
}
void dfs(int u,int i) {
if(!u) return ;
dfs(tr[u].l,i);
v[i].push_back(tr[u].val);
dfs(tr[u].r,i);
}
int build(int l,int r,int i) {
if(l>r) return 0;
int mid=(l+r)>>1;
int x=new_node(v[i][mid]);
tr[x].l=build(l,mid-1,i);
tr[x].r=build(mid+1,r,i);
push_up(x);
return x;
}
int build_a(int l,int r) {
if(l>r) return 0;
int mid=(l+r)>>1;
int x=new_node(a[mid]);
tr[x].l=build_a(l,mid-1);
tr[x].r=build_a(mid+1,r);
push_up(x);
return x;
}
void rebuild() {
idx=0;
v[0].clear();
v[1].clear();
dfs(rt[0],0);
dfs(rt[1],1);
rt[0]=build(0,n-1,0);
rt[1]=build(0,n-1,1);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
rt[0]=build_a(1,n);
rt[1]=rt[0];
for(int i=1;i<=m;i++) {
int op;
cin>>op;
if(idx>=1500000) rebuild();
if(op==1) {
int l,r;
cin>>l>>r;
cout<<find_sum(rt[1],l,r)<<'\n';
}
else if(op==2) {
int l,r,k;
cin>>l>>r>>k;
int x,y,z;
split_seq(rt[1],l-k,l-1,x,y,z);
int cnt=(r-l+1+k-1)/k;
int Y=qpow(y,cnt);
int u,v;
split(Y,r-l+1,u,v);
int b,c;
split(z,r-l+1,b,c);
rt[1]=merge(merge(merge(x,y),u),c);
}
else {
int l,r;
cin>>l>>r;
int x,y,z,u,v,w;
split_seq(rt[0],l,r,x,y,z);
split_seq(rt[1],l,r,u,v,w);
rt[1]=merge(merge(u,y),w);
}
}
return 0;
}
这种方法跑得比较慢,空间也要大一些。有几篇题解采用的是基于子树大小随机合并的方法,耗时比这种方法少 \(40\%\)(同样使用 \(rand\) 生成随机数,并且都是c++14开O2优化),空间上每个节点少一个 \(key\)值,解法很优秀。
把基于子树大小随机合并的代码也放在这里(简要解释一下这种随机方法:以概率 \(\frac {tr[x].sz} {tr[x].sz+tr[y].sz}\) 将 \(x\) 作为根,否则将 \(y\) 作为根。通过取模可以生成一个 \([0,tr[x].sz+tr[y].sz-1]\) 区间内的随机数,然后依据与 \(tr[x].sz\) 大小关系判断选择哪一个(小于 \(tr[x].sz\) 的情况有 \((tr[x].sz-1)-0+1=tr[x].sz\) 种)):
cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
int n,m;
int a[N];
struct node {
int l,r,sz,val;
ll sum;
}tr[N*8];
int rt[5],idx;
vector <int> v[5];
int new_node(int v) {
tr[++idx].sum=v;
tr[idx].val=v;
tr[idx].sz=1;
tr[idx].l=tr[idx].r=0;
return idx;
}
int copy_node(int x) {
tr[++idx]=tr[x];
return idx;
}
void push_up(int rt) {
tr[rt].sum=tr[tr[rt].l].sum+tr[tr[rt].r].sum+tr[rt].val;
tr[rt].sz=tr[tr[rt].l].sz+tr[tr[rt].r].sz+1;
}
void split(int p,int k,int &x,int &y) {
if(!p) {
x=y=0;
return ;
}
if(tr[tr[p].l].sz<k) {
x=copy_node(p);
split(tr[x].r,k-tr[tr[x].l].sz-1,tr[x].r,y);
push_up(x);
}
else {
y=copy_node(p);
split(tr[y].l,k,x,tr[y].l);
push_up(y);
}
}
int merge(int x,int y) {
if(!x||!y) return x|y;
if(rand()%(tr[x].sz+tr[y].sz)<tr[x].sz) {
int p=copy_node(x);
tr[p].r=merge(tr[p].r,y);
push_up(p);
return p;
}
else {
int p=copy_node(y);
tr[p].l=merge(x,tr[p].l);
push_up(p);
return p;
}
}
int qpow(int p,int b) {
int a=copy_node(p),ret=0;
for(;b;b>>=1,a=merge(a,a)) if(b&1) ret=merge(ret,a);
return ret;
}
void split_seq(int rt,int l,int r,int &x,int &y,int &z) {
split(rt,r,x,z);
split(x,l-1,x,y);
}
ll find_sum(int &rt,int l,int r) {
int x,y,z;
split_seq(rt,l,r,x,y,z);
ll ret=tr[y].sum;
rt=merge(merge(x,y),z);
return ret;
}
void dfs(int u,int i) {
if(!u) return ;
dfs(tr[u].l,i);
v[i].push_back(tr[u].val);
dfs(tr[u].r,i);
}
int build(int l,int r,int i) {
if(l>r) return 0;
int mid=(l+r)>>1;
int x=new_node(v[i][mid]);
tr[x].l=build(l,mid-1,i);
tr[x].r=build(mid+1,r,i);
push_up(x);
return x;
}
int build_a(int l,int r) {
if(l>r) return 0;
int mid=(l+r)>>1;
int x=new_node(a[mid]);
tr[x].l=build_a(l,mid-1);
tr[x].r=build_a(mid+1,r);
push_up(x);
return x;
}
void rebuild() {
idx=0;
v[0].clear();
v[1].clear();
dfs(rt[0],0);
dfs(rt[1],1);
rt[0]=build(0,n-1,0);
rt[1]=build(0,n-1,1);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
rt[0]=build_a(1,n);
rt[1]=rt[0];
for(int i=1;i<=m;i++) {
int op;
cin>>op;
if(idx>=1500000) rebuild();
if(op==1) {
int l,r;
cin>>l>>r;
cout<<find_sum(rt[1],l,r)<<'\n';
}
else if(op==2) {
int l,r,k;
cin>>l>>r>>k;
int x,y,z;
split_seq(rt[1],l-k,l-1,x,y,z);
int cnt=(r-l+1+k-1)/k;
int Y=qpow(y,cnt);
int u,v;
split(Y,r-l+1,u,v);
int b,c;
split(z,r-l+1,b,c);
rt[1]=merge(merge(merge(x,y),u),c);
}
else {
int l,r;
cin>>l>>r;
int x,y,z,u,v,w;
split_seq(rt[0],l,r,x,y,z);
split_seq(rt[1],l,r,u,v,w);
rt[1]=merge(merge(u,y),w);
}
}
return 0;
}
欢迎大家指出问题!