你是不是遇到过这种这种题,没遇到过算了,给你一颗n个节点的树,需要你维护一个数据结构,支持链上修改,链上求和,子树修改,子树求和,共m次操作
离线的话我会树上差分,骄傲.jpg,但我要在线你不就没招了吗,今天我们学一个新算法树剖,具体说是我们之前学过dfs序,可以做到把一个树上问题转换为序列问题,我们今天要学的重链剖分亦是如此。
树链剖分,顾名思义就是把一个好好的树剖开成很多条链,再往下细分还会有重链剖分和长链剖分,但是我只会重链剖分,所以我们只讲重链剖分。
首先我们先给出一些定义
重儿子:子树大小最大的那个儿子
轻儿子:其他的儿子
重边:连接两个重儿子的边
轻边:其他边
重链:重边首尾相连组成重链
Tips:重链开始于轻儿子
这些定义其实很显然,没必要举例子。
然后这颗本来优美的树就被我们无情的剖开成了一条一条的重链,每一个点都只会属于一个重链。

举个例子,这棵树重链剖分出来长这样。
好了,我们已经知道算法是干什么的啦,那该如何用代码实现它呢?
树剖的核心思想就是两遍dfs预处理
首先,第一遍dfs,我们叫做dfs1,他要预处理四个信息
fa[x]:节点x的父亲是谁
dep[x]:节点x的深度
siz[x]:节点x的子树大小
son[x]:节点x的重儿子是谁
所以我们可写出以下代码
cpp
void dfs1(int x,int f){
fa[x]=f; dep[x]=dep[f]+1; siz[x]=1;
int maxx=-1;
for(int i=head[x];i;i=nxt[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x); siz[x]+=siz[y];
if(siz[y]>maxx) son[x]=y,maxx=siz[y];
}
}
第二遍dfs,我们叫做dfs2,我们维护三个信息
tpf[x]:点x所在的重链的起始节点
dfn[x]:点x的dfs序
wr[x]:dfn序为x的点的权值
所以我们可写出以下代码
cpp
void dfs2(int x,int topf){
tpf[x]=topf; dfn[x]=++dfnidx; wr[dfnidx]=w[x]%mod;
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=head[x];i;i=nxt[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
哎,这样我们就成功的把树上问题转换成了序列问题,由于我们重链剖分的独特性质,所以每个子树的dfs序是连续的,所以子树的修改查询都很好做。
那我们考虑链的问题,我们给出x,y我们考虑让他们跳到重链的最顶端,每次都加重链的和,因为重链的dfs序也是连续的,然后就是一个区间修改区间查询的事,上线段树即可。
cpp
#define lc p<<1
#define rc p<<1|1
int n,rt,mod,w[N];
int fa[N],dep[N],siz[N],son[N];
int tpf[N],dfn[N],wr[N],dfnidx;
ll a[N<<2],lazy[N<<2];
void dfs1(int x,int f){
fa[x]=f; dep[x]=dep[f]+1; siz[x]=1;
int maxx=-1;
for(int i=head[x];i;i=nxt[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x); siz[x]+=siz[y];
if(siz[y]>maxx) son[x]=y,maxx=siz[y];
}
}
void dfs2(int x,int topf){
tpf[x]=topf; dfn[x]=++dfnidx; wr[dfnidx]=w[x]%mod;
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=head[x];i;i=nxt[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
void pushdown(int p,int l,int r){
if(lazy[p]){
int mid=(l+r)>>1;
lazy[lc]+=lazy[p];
lazy[rc]+=lazy[p];
a[lc]+=(mid-l+1)*lazy[p];
a[rc]+=(r-mid)*lazy[p];
a[lc]%=mod; a[rc]%=mod;
lazy[p]=0;
}
}
void build(int p,int l,int r){
if(l==r){ a[p]=wr[l]; return ; }
int mid=(l+r)>>1;
build(lc,l,mid);
build(rc,mid+1,r);
a[p]=(a[lc]+a[rc])%mod;
}
void update(int p,int l,int r,int ql,int qr,int val){
if(ql<=l&&r<=qr){ a[p]+=(r-l+1)*val; a[p]%=mod; lazy[p]+=val; lazy[p]%=mod; return ; }
int mid=(l+r)>>1;
pushdown(p,l,r);
if(ql<=mid) update(lc,l,mid,ql,qr,val);
if(qr>mid) update(rc,mid+1,r,ql,qr,val);
a[p]=(a[lc]+a[rc])%mod;
}
ll ask(int p,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr) return a[p];
int mid=(l+r)>>1;
pushdown(p,l,r);
ll res=0;
if(ql<=mid) res+=ask(lc,l,mid,ql,qr); res%=mod;
if(qr>mid) res+=ask(rc,mid+1,r,ql,qr); res%=mod;
return res;
}
void updatechain(int x,int y,int val){
while(tpf[x]!=tpf[y]){
if(dep[tpf[x]]<dep[tpf[y]]) swap(x,y);
update(1,1,n,dfn[tpf[x]],dfn[x],val);
x=fa[tpf[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,dfn[x],dfn[y],val);
}
ll askchain(int x,int y){
ll ans=0;
while(tpf[x]!=tpf[y]){
if(dep[tpf[x]]<dep[tpf[y]]) swap(x,y);
ans+=ask(1,1,n,dfn[tpf[x]],dfn[x]); ans%=mod;
x=fa[tpf[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=ask(1,1,n,dfn[x],dfn[y]);
return ans%mod;
}
dfs1(rt,0);
dfs2(rt,rt);
build(1,1,n);