HDU6087 Rikka with Sequence 题解

分析

本题中的操作3提到了版本回退,于是可以想到这是用可持久化数据结构做。又看到操作2复杂的区间操作,说明可以用可持久化平衡树解决。

模拟一下操作2,可以发现操作2其实是在将 \([l-k,l-1]\) 这个区间(这个区间长度为 \(k\) )复制出来 \(\lceil \frac {r-l+1} {k} \rceil\)个,然后把它们顺序拼接起来作为一个大区间,最后取前 \(r-l+1\) 个数整体替换掉 \([l,r]\) 区间。这一步可以用类似快速幂的办法去做(因为 \(merge\) 函数满足结合律),每次将这棵代表 \([l-k,l-1]\) 区间的 \(Treap\) 树复制并合并。操作1和操作3直接截取目标区间进行统计或者操作即可。

还有一点就是本题空间限制紧,为了可持久化会产生许多废弃节点,当节点总数超过一个值(我选的 \(1.5e6\) ,可以多调一调这个数的值)之后就应该重建两棵树(一棵是初始的,一棵是经过多次操作的),保证不会爆空间。

代码

RE代码

cpp 复制代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
int n,m;
int a[N];
struct node {
    int l,r,key,sz,val;
    ll sum;
}tr[N*8];
int rt[5],idx;
vector <int> v[5];
int new_node(int v) {
    tr[++idx].sum=v;
    tr[idx].val=v;
    tr[idx].sz=1;
    tr[idx].key=rand();
    tr[idx].l=tr[idx].r=0;
    return idx;
}
int copy_node(int x) {
    tr[++idx]=tr[x];
    return idx;
}
void push_up(int rt) {
    tr[rt].sum=tr[tr[rt].l].sum+tr[tr[rt].r].sum+tr[rt].val;
    tr[rt].sz=tr[tr[rt].l].sz+tr[tr[rt].r].sz+1;
}
void split(int p,int k,int &x,int &y) {
    if(!p) {
        x=y=0;
        return ;
    }
    if(tr[tr[p].l].sz<k) {
        x=copy_node(p);
        split(tr[x].r,k-tr[tr[x].l].sz-1,tr[x].r,y);
        push_up(x);
    }
    else {
        y=copy_node(p);
        split(tr[y].l,k,x,tr[y].l);
        push_up(y);
    }
}
int merge(int x,int y) {
    if(!x||!y) return x|y;
    if(tr[x].key<tr[y].key) {
        int p=copy_node(x);
        tr[p].r=merge(tr[p].r,y);
        push_up(p);
        return p;
    }
    else {
        int p=copy_node(y);
        tr[p].l=merge(x,tr[p].l);
        push_up(p);
        return p;
    }
}
int qpow(int p,int b) {
    int a=copy_node(p),ret=0;
    for(;b;b>>=1,a=merge(a,a)) if(b&1) ret=merge(ret,a);
    return ret;
}
void split_seq(int rt,int l,int r,int &x,int &y,int &z) {
    split(rt,r,x,z);
    split(x,l-1,x,y);
}
ll find_sum(int &rt,int l,int r) {
    int x,y,z;
    split_seq(rt,l,r,x,y,z);
    ll ret=tr[y].sum;
    rt=merge(merge(x,y),z);
    return ret;
}
void dfs(int u,int i) {
    if(!u) return ;
    dfs(tr[u].l,i);
    v[i].push_back(tr[u].val);
    dfs(tr[u].r,i);
}
int build(int l,int r,int i) {
    if(l>r) return 0;
    int mid=(l+r)>>1;
    int x=new_node(v[i][mid]);
    tr[x].l=build(l,mid-1,i);
    tr[x].r=build(mid+1,r,i);
    push_up(x);
    return x;
}
int build_a(int l,int r) {
    if(l>r) return 0;
    int mid=(l+r)>>1;
    int x=new_node(a[mid]);
    tr[x].l=build_a(l,mid-1);
    tr[x].r=build_a(mid+1,r);
    push_up(x);
    return x;
}
void rebuild() {
    idx=0;
    v[0].clear();
    v[1].clear();
    dfs(rt[0],0);
    dfs(rt[1],1);
    rt[0]=build(0,n-1,0);
    rt[1]=build(0,n-1,1);
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>a[i];
    rt[0]=build_a(1,n);
    rt[1]=rt[0];
    for(int i=1;i<=m;i++) {
        int op;
        cin>>op;
        if(idx>=1500000) rebuild();
        if(op==1) {
            int l,r;
            cin>>l>>r;
            cout<<find_sum(rt[1],l,r)<<'\n';
        }
        else if(op==2) {
            int l,r,k;
            cin>>l>>r>>k;
            int x,y,z;
            split_seq(rt[1],l-k,l-1,x,y,z);
            int cnt=(r-l+1+k-1)/k;
            int Y=qpow(y,cnt);
            int u,v;
            split(Y,r-l+1,u,v);
            int b,c;
            split(z,r-l+1,b,c);
            rt[1]=merge(merge(merge(x,y),u),c);
        }
        else {
            int l,r;
            cin>>l>>r;
            int x,y,z,u,v,w;
            split_seq(rt[0],l,r,x,y,z);
            split_seq(rt[1],l,r,u,v,w);
            rt[1]=merge(merge(u,y),w);
        }
    }
    return 0;
}

(我基于上面这个第一版代码不断微调,一直交,结果当然是RE,甚至还微调出了TLE)

我认为RE原因主要是 \(qpow\) 函数不断对现有的 \(Treap\) 树进行复制,导致复制后的两棵 \(Treap\) 树存在大量 \(key\) 相等的节点。而 \(merge\) 函数在 \(key\) 相等时又总是走 \(else\) 分支,导致递归深度过大造成栈溢出(欢迎对这种写法的RE原因进行探讨)。那么解决这个问题的办法可以是当 \(key\) 相等时随机走一个分支,这样即可避免爆栈。

只需在 \(merge\) 函数的 \(if\) 判断中加一个条件即可。

AC代码

cpp 复制代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
int n,m;
int a[N];
struct node {
    int l,r,key,sz,val;
    ll sum;
}tr[N*8];
int rt[5],idx;
vector <int> v[5];
int new_node(int v) {
    tr[++idx].sum=v;
    tr[idx].val=v;
    tr[idx].sz=1;
    tr[idx].key=rand();
    tr[idx].l=tr[idx].r=0;
    return idx;
}
int copy_node(int x) {
    tr[++idx]=tr[x];
    return idx;
}
void push_up(int rt) {
    tr[rt].sum=tr[tr[rt].l].sum+tr[tr[rt].r].sum+tr[rt].val;
    tr[rt].sz=tr[tr[rt].l].sz+tr[tr[rt].r].sz+1;
}
void split(int p,int k,int &x,int &y) {
    if(!p) {
        x=y=0;
        return ;
    }
    if(tr[tr[p].l].sz<k) {
        x=copy_node(p);
        split(tr[x].r,k-tr[tr[x].l].sz-1,tr[x].r,y);
        push_up(x);
    }
    else {
        y=copy_node(p);
        split(tr[y].l,k,x,tr[y].l);
        push_up(y);
    }
}
int merge(int x,int y) {
    if(!x||!y) return x|y;
    if(tr[x].key<tr[y].key||(tr[x].key==tr[y].key&&rand()%2==0)) {
        int p=copy_node(x);
        tr[p].r=merge(tr[p].r,y);
        push_up(p);
        return p;
    }
    else {
        int p=copy_node(y);
        tr[p].l=merge(x,tr[p].l);
        push_up(p);
        return p;
    }
}
int qpow(int p,int b) {
    int a=copy_node(p),ret=0;
    for(;b;b>>=1,a=merge(a,a)) if(b&1) ret=merge(ret,a);
    return ret;
}
void split_seq(int rt,int l,int r,int &x,int &y,int &z) {
    split(rt,r,x,z);
    split(x,l-1,x,y);
}
ll find_sum(int &rt,int l,int r) {
    int x,y,z;
    split_seq(rt,l,r,x,y,z);
    ll ret=tr[y].sum;
    rt=merge(merge(x,y),z);
    return ret;
}
void dfs(int u,int i) {
    if(!u) return ;
    dfs(tr[u].l,i);
    v[i].push_back(tr[u].val);
    dfs(tr[u].r,i);
}
int build(int l,int r,int i) {
    if(l>r) return 0;
    int mid=(l+r)>>1;
    int x=new_node(v[i][mid]);
    tr[x].l=build(l,mid-1,i);
    tr[x].r=build(mid+1,r,i);
    push_up(x);
    return x;
}
int build_a(int l,int r) {
    if(l>r) return 0;
    int mid=(l+r)>>1;
    int x=new_node(a[mid]);
    tr[x].l=build_a(l,mid-1);
    tr[x].r=build_a(mid+1,r);
    push_up(x);
    return x;
}
void rebuild() {
    idx=0;
    v[0].clear();
    v[1].clear();
    dfs(rt[0],0);
    dfs(rt[1],1);
    rt[0]=build(0,n-1,0);
    rt[1]=build(0,n-1,1);
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>a[i];
    rt[0]=build_a(1,n);
    rt[1]=rt[0];
    for(int i=1;i<=m;i++) {
        int op;
        cin>>op;
        if(idx>=1500000) rebuild();
        if(op==1) {
            int l,r;
            cin>>l>>r;
            cout<<find_sum(rt[1],l,r)<<'\n';
        }
        else if(op==2) {
            int l,r,k;
            cin>>l>>r>>k;
            int x,y,z;
            split_seq(rt[1],l-k,l-1,x,y,z);
            int cnt=(r-l+1+k-1)/k;
            int Y=qpow(y,cnt);
            int u,v;
            split(Y,r-l+1,u,v);
            int b,c;
            split(z,r-l+1,b,c);
            rt[1]=merge(merge(merge(x,y),u),c);
        }
        else {
            int l,r;
            cin>>l>>r;
            int x,y,z,u,v,w;
            split_seq(rt[0],l,r,x,y,z);
            split_seq(rt[1],l,r,u,v,w);
            rt[1]=merge(merge(u,y),w);
        }
    }
    return 0;
}

这种方法跑得比较慢,空间也要大一些。有几篇题解采用的是基于子树大小随机合并的方法,耗时比这种方法少 \(40\%\)(同样使用 \(rand\) 生成随机数,并且都是c++14开O2优化),空间上每个节点少一个 \(key\)值,解法很优秀。

把基于子树大小随机合并的代码也放在这里(简要解释一下这种随机方法:以概率 \(\frac {tr[x].sz} {tr[x].sz+tr[y].sz}\) 将 \(x\) 作为根,否则将 \(y\) 作为根。通过取模可以生成一个 \([0,tr[x].sz+tr[y].sz-1]\) 区间内的随机数,然后依据与 \(tr[x].sz\) 大小关系判断选择哪一个(小于 \(tr[x].sz\) 的情况有 \((tr[x].sz-1)-0+1=tr[x].sz\) 种)):

cpp 复制代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
int n,m;
int a[N];
struct node {
    int l,r,sz,val;
    ll sum;
}tr[N*8];
int rt[5],idx;
vector <int> v[5];
int new_node(int v) {
    tr[++idx].sum=v;
    tr[idx].val=v;
    tr[idx].sz=1;
    tr[idx].l=tr[idx].r=0;
    return idx;
}
int copy_node(int x) {
    tr[++idx]=tr[x];
    return idx;
}
void push_up(int rt) {
    tr[rt].sum=tr[tr[rt].l].sum+tr[tr[rt].r].sum+tr[rt].val;
    tr[rt].sz=tr[tr[rt].l].sz+tr[tr[rt].r].sz+1;
}
void split(int p,int k,int &x,int &y) {
    if(!p) {
        x=y=0;
        return ;
    }
    if(tr[tr[p].l].sz<k) {
        x=copy_node(p);
        split(tr[x].r,k-tr[tr[x].l].sz-1,tr[x].r,y);
        push_up(x);
    }
    else {
        y=copy_node(p);
        split(tr[y].l,k,x,tr[y].l);
        push_up(y);
    }
}
int merge(int x,int y) {
    if(!x||!y) return x|y;
    if(rand()%(tr[x].sz+tr[y].sz)<tr[x].sz) {
        int p=copy_node(x);
        tr[p].r=merge(tr[p].r,y);
        push_up(p);
        return p;
    }
    else {
        int p=copy_node(y);
        tr[p].l=merge(x,tr[p].l);
        push_up(p);
        return p;
    }
}
int qpow(int p,int b) {
    int a=copy_node(p),ret=0;
    for(;b;b>>=1,a=merge(a,a)) if(b&1) ret=merge(ret,a);
    return ret;
}
void split_seq(int rt,int l,int r,int &x,int &y,int &z) {
    split(rt,r,x,z);
    split(x,l-1,x,y);
}
ll find_sum(int &rt,int l,int r) {
    int x,y,z;
    split_seq(rt,l,r,x,y,z);
    ll ret=tr[y].sum;
    rt=merge(merge(x,y),z);
    return ret;
}
void dfs(int u,int i) {
    if(!u) return ;
    dfs(tr[u].l,i);
    v[i].push_back(tr[u].val);
    dfs(tr[u].r,i);
}
int build(int l,int r,int i) {
    if(l>r) return 0;
    int mid=(l+r)>>1;
    int x=new_node(v[i][mid]);
    tr[x].l=build(l,mid-1,i);
    tr[x].r=build(mid+1,r,i);
    push_up(x);
    return x;
}
int build_a(int l,int r) {
    if(l>r) return 0;
    int mid=(l+r)>>1;
    int x=new_node(a[mid]);
    tr[x].l=build_a(l,mid-1);
    tr[x].r=build_a(mid+1,r);
    push_up(x);
    return x;
}
void rebuild() {
    idx=0;
    v[0].clear();
    v[1].clear();
    dfs(rt[0],0);
    dfs(rt[1],1);
    rt[0]=build(0,n-1,0);
    rt[1]=build(0,n-1,1);
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>a[i];
    rt[0]=build_a(1,n);
    rt[1]=rt[0];
    for(int i=1;i<=m;i++) {
        int op;
        cin>>op;
        if(idx>=1500000) rebuild();
        if(op==1) {
            int l,r;
            cin>>l>>r;
            cout<<find_sum(rt[1],l,r)<<'\n';
        }
        else if(op==2) {
            int l,r,k;
            cin>>l>>r>>k;
            int x,y,z;
            split_seq(rt[1],l-k,l-1,x,y,z);
            int cnt=(r-l+1+k-1)/k;
            int Y=qpow(y,cnt);
            int u,v;
            split(Y,r-l+1,u,v);
            int b,c;
            split(z,r-l+1,b,c);
            rt[1]=merge(merge(merge(x,y),u),c);
        }
        else {
            int l,r;
            cin>>l>>r;
            int x,y,z,u,v,w;
            split_seq(rt[0],l,r,x,y,z);
            split_seq(rt[1],l,r,u,v,w);
            rt[1]=merge(merge(u,y),w);
        }
    }
    return 0;
}

欢迎大家指出问题!