【模板】动态 dp 学习笔记(树剖版)

动态 dp 学习笔记(树剖版)

本文同步发表于 cnblogs

本文同步发表于 luogu

前置知识:

  • 简单 dp
  • 树链剖分
  • 矩阵乘法和广义矩阵乘法

P4719 【模板】动态 DP

本文着重讲下修改的具体过程以及代码实现,蒟蒻花了好长时间才明白。

鏖战一天终于通过了板子题啊啊啊!!!

不带修:简单树上 dp。

考虑不带修改,就是一个平凡的树上最大权独立集问题,简单树上 dp 即可求解。

设 \(dp[i][0/1]\) 表示以 \(i\) 为根子树中选\(/\)不选 \(i\) 所得到的树上最大权独立集的大小。

转移是容易的:

\(dp_{i,0}=\sum \max(dp_{son,1},dp_{son,0})\)

\(dp_{i,1}=\sum dp_{to,0}\)


带修了!

但是丧心病狂的出题人加上了修改点权!

考虑动态维护这个问题。

发现树剖有一个性质:跳重链至多 \(O(\log)\) 次。

设:

\(son_p\) 为 \(p\) 的重儿子。

\(top_p\) 为 \(p\) 点所在重链的链头节点。

\(fa_p\) 为 \(p\) 的父亲节点。

然后修正我们的 dp 为:

\(f[p][0/1]\) 表示以 \(p\) 为根子树选\(/\)不选 \(p\) 的答案。

\(g[p][0/1]\) 表示以 \(p\) 为根子树不选重儿子 ,选\(/\)不选 \(p\) 的答案。

设 \(to\) 为点 \(p\) 的儿子节点,那么有转移:

\[\begin{equation*} \begin{aligned} &g_{p,0}=\sum \max(g_{to,0},g_{to,1}) [ to \ne son_p ] \\ &g_{p,1}=val_p+ \sum g_{to,0} \\ &f_{p,0}=g_{p,0}+\max(f_{son_p,0},f_{son_p,1})\\ &f_{p,1}=g_{p,1}+f_{son_p,0}\\ \end{aligned} \end{equation*}\]


转移改为使用广义矩阵乘法:

直接定义广义矩阵乘法 \(C_{i,j}=\max_k(A_{i,k}+B_{k,j})\)。

本文并不在广义矩阵乘这里深入展开,具体可参考其他博客。作者还没完全搞懂。

具体可参考 https://www.cnblogs.com/qkhm/p/19055513/ddp

当然如果你不会证明你也可以直接设三个矩阵然后暴力手算验证该新定义的矩阵是否满足结合律,也是可行的。

发现这个新的矩乘还是满足结合律(不满足交换律),单位矩阵是主对角线上是 \(0\),其他全都是 \(−\infty\)。

本题单位矩阵为

\[ I= \begin{bmatrix} 0 & -\infty \\ -\infty & 0 \end{bmatrix} \]

原树上 点 \(p\) 的答案矩阵 \(B_p\) 为

\[ B_p= \begin{bmatrix} f_{p,0}\\ f_{p,1}\\ \end{bmatrix} \]

那么我们需要求出每个点的转移矩阵,设为 \(A_p\)。

定义完所有所需矩阵之后,转移可以写为:

\[A_p \times B_{son_p} = B_p \]

\[ A_p=\begin{bmatrix} a & b \\ c & d \end{bmatrix} \]

写出完整矩阵转移的柿子为

\[\begin{bmatrix} a & b \\ c & d \end{bmatrix} \times \begin{bmatrix} f_{son_p,0} \\ f_{son_p,1}\\ \end{bmatrix} = \begin{bmatrix} f_{p,0}\\ f_{p,1}\\ \end{bmatrix} \]

由待定系数法

\[\begin{equation*} \begin{aligned} f_{p,0}&=\max(a + f_{son_p,0}, b+f_{son_p,1} )\\ &=g_{p,0}+\max(f_{son_p,0},f_{son_p,1})\\ &=\max({\color{red}g_{p,0}}+f_{son_p,0},{\color{red}g_{p,0}}+f_{son_p,1})\\ \end{aligned} \end{equation*}\]

\[\begin{equation*} \begin{aligned} f_{p,1}&=\max(c+f_{son_p,0}, d+f_{son_p,1})\\ &=g_{p,1}+f_{son_p,0}\\ &=\max({\color{red}g_{p,1}}+f_{son_p,0},{\color{red}-\infty} +f_{son_p,1} )\\ \end{aligned} \end{equation*}\]

得出 \(p\) 点的转移矩阵为

\[ A_p= \begin{bmatrix} g_{p,0} & g_{p,0} \\ g_{p,1} & -\infty \end{bmatrix} \]

那么转移可以写成

\[A_p \times \begin{bmatrix} f_{son_p,0} \\ f_{son_p,1}\\ \end{bmatrix} = \begin{bmatrix} f_{p,0}\\ f_{p,1}\\ \end{bmatrix} \]

那么一条重链上某一点 \(p\) 的答案矩阵(\(f\) 值)可以用矩阵连乘计算:

\[B_p=A_p \times A_{son_p} \times A_{son_{son_p}}\times \dots \times \begin{bmatrix} 0\\ 0\\ \end{bmatrix} \]

链尾节点没有 \(son\),所以 \(B_{son_{tail}} = \left[ \begin{smallmatrix} 0\\ 0\\ \end{smallmatrix} \right] \)。

手算发现乘\( \left[ \begin{smallmatrix} 0\\ 0\\ \end{smallmatrix} \right] \)之前的那个 \(2\times 2\) 矩阵提出第一列构成一个矩阵后,恰好是乘之后的答案矩阵。所以我们不用真正乘
\( \left[ \begin{smallmatrix} 0\\ 0\\ \end{smallmatrix} \right] \)

矩阵,直接提取第一列即可。


初始化:

那么我们就可以在线段树上维护区间矩阵连乘积了。注意矩阵乘法不满足交换律,我的做法在合并区间时需要左 \(\times\) 右。

具体实现时比较复杂。

先两次 dfs 进行树链剖分。我们需要额外记录一个 \(bot_i\) 表示一条以 \(i\) 为链首的那条重链的链底。

再来一次 dfs 跑一遍树上 dp 初始化 \(f\) 和 \(g\) 数组。

在 dfn 序上建线段树,每个叶子节点记录它代表的树上点 \(p\) 的转移矩阵 \(A_p\)。

对于非叶子节点,记录的矩阵为左儿子矩阵乘右儿子矩阵。

查询答案时我们需要查询以根为链首的 \(B_{top}\) 矩阵值,可以通过线段树区间查询该重链的矩阵乘积得到。


解决修改问题:

着重讲下修改的具体过程以及代码实现,蒟蒻花了好长时间才明白。

注意,下文对 \(g\) 的修改改的是矩阵里的值 ,而不是 \(g\) 数组的值。

设我们要将树上点 \(p\) 的权值改为 \(k\)。设原先点权为 \(val_p\)。

首先开一个全局临时矩阵 \(X\),令 \(X = A_p\)。

然后修改矩阵 \(X\) 中 \(g_{p,1}\) 的值,也就是矩阵的第二行第一列那个位置的值。容易发现点 \(p\) 对 \(g_{p,1}\) 的贡献由原先的 \(val_p\) 变为 \(k\),变化量为 \(k-val_p\),所以我们在矩阵中更改 \(g_{p,1}=g_{p,1}+k-val_p\) 即可。

然后我们考察 \(g\) 的实际意义,为一个点轻儿子的贡献。考虑有几个点的 \(g\) 值需要更新。发现是 \(p\) 点和跳链过程中每条链链顶的父亲,其他点均不需要修改。

由于每次查询根链最多跳 \(\log n\) 条重链,所以对应的转移矩阵 \(A_p\) 只会有 \(\log n\) 个得到修改,加上线段树的复杂度就是 \(\log^2\) 的。复杂度得到证明。

然后现在节点为 \(p\),开始跳链操作。

先线段树区间查询 \(p\) 所在重链的乘积为矩阵 \(Pre\),\(Pre\) 的第一列即为修改之前 链顶的答案矩阵的值(\(f\) 值)。

然后线段树单点修改 \(p\) 点的转移矩阵,将其变为 \(X\)。

再线段树区间查询 \(p\) 所在重链的乘积为矩阵 \(Nxt\),\(Nxt\) 的第一列即为修改之后 链顶的答案矩阵的值(\(f\) 值)。

设 \(to = top_p\),然后令 \(p = fa_{top_p}\),意为令 \(p\) 跳到 \(p\) 所在重链链首的父亲。

再次线段树单点查询出 \(p\) 点所对应的转移矩阵 \(A_p\),令 \(X=A_p\)。

考察 \(to\) 对 \(g_{p,0}\) 的贡献,为 \(\max(f_{to,0},f_{to,1})\)。

那么矩阵 \(X\) 里的所有 \(g_{p,0}\) 变化量即为 \(\max(Nxt_{1,1},Nxt_{2,1}) - \max(Pre_{1,1},Pre_{2,1})\)。将 \(X_{1,1}\) 和 \(X_{1,2}\) 加上变化量即可。

考察 \(to\) 对 \(g_{p,1}\) 的贡献,为 \(f_{to,0}\)。

那么矩阵 \(X\) 里的所有 \(g_{p,1}\) 变化量即为 \(Nxt_{1,1} - Pre_{1,1}\)。将 \(X_{2,1}\) 加上变化量即可。

矩阵 \(X\) 中 \(X_{2,2}\) 仍为 \(-\infty\)。

重复执行上述过程直至 \(p = 0\) 时结束。

即可完成修改。


实现细节:

  • 矩阵可以用 \(2 \times 2\) 的二维数组存储。

  • 矩阵乘可以循环展开。

  • 线段树上非叶子节点存储的矩阵为左儿子矩阵 \(\times\) 右儿子矩阵。


Code:

cpp 复制代码
#include<bits/stdc++.h>
#define int long long
#define lp (p<<1)
#define rp ((p<<1)|1)
using namespace std;

inline int read()
{
	int x=0,c=getchar(),f=0;
	for(;c>'9'||c<'0';f=c=='-',c=getchar());
	for(;c>='0'&&c<='9';c=getchar())
		x=(x<<1)+(x<<3)+(c^48);
	return f?-x:x;
}
inline void write(int x)
{
	if(x<0) x=-x,putchar('-');
	if(x>9)  write(x/10);
	putchar(x%10+'0');
}

#ifndef ONLINE_JUDGE
#define ONLINE_JUDGE
#endif

const int N=1e6+5;
int n,m;
int a[N];
int fa[N];
int tail[N];
int siz[N];
int son[N];
int dfn[N];
int tod[N];
int top[N];
int id[N];
int tot;
vector<int> E[N];
const int inf=1e12;
struct Martix
{
    int a[2][2]={};
}I,t[N<<2];
Martix nw;
int f[N][2],g[N][2];
int tid[N<<2];

inline int max(int x,int y) { return x>y?x:y; }

Martix operator*(const Martix &x,const Martix &y)
{
    Martix ans;
    ans.a[0][0]=max(x.a[0][0]+y.a[0][0],x.a[0][1]+y.a[1][0]);
    ans.a[0][1]=max(x.a[0][0]+y.a[0][1],x.a[0][1]+y.a[1][1]);
    ans.a[1][0]=max(x.a[1][0]+y.a[0][0],x.a[1][1]+y.a[1][0]);
    ans.a[1][1]=max(x.a[1][0]+y.a[0][1],x.a[1][1]+y.a[1][1]);
    return ans;
}

Martix ksm(Martix x,int p)
{
    Martix ans=I;
    while(p)
    {
        if(p&1) ans=ans*x;
        x=x*x;
        p>>=1;
    }
    return ans;
}

void dfs1(int p,int f)
{
    fa[p]=f;
    siz[p]=1;
    for(int to:E[p])
    {
        if(to==f) continue;
        dfs1(to,p);
        siz[p]+=siz[to];
        if(siz[to]>siz[son[p]]) son[p]=to;
    }
}

void dfs2(int p,int tp)
{
    dfn[p]=++tot;
    id[tot]=p;
    top[p]=tp;
    tail[tp]=p;

    if(son[p]) dfs2(son[p],tp);

    for(int to:E[p])
    if(!dfn[to]) dfs2(to,to);
}

void dfs3(int p,int fa)
{
    g[p][1]=a[p];
    for(int to:E[p])
    {
        if(to==fa) continue;
        
        dfs3(to,p);
        if(to==son[p]) continue;

        g[p][1]+=f[to][0];
        g[p][0]+=max(f[to][0],f[to][1]);
    }
    f[p][0]=g[p][0]+max(f[son[p]][0],f[son[p]][1]);
    f[p][1]=g[p][1]+f[son[p]][0];
}

void pushup(int p)
{
    t[p]=t[lp]*t[rp];
}

void build(int l,int r,int p)
{
    if(l==r)
    {
        tid[id[l]]=p;
        t[p].a[0][0]=t[p].a[0][1]=g[id[l]][0];
        t[p].a[1][0]=g[id[l]][1];
        t[p].a[1][1]=-inf;   

        return;
    }
    int mid=(l+r)>>1;
    build(l,mid,lp);
    build(mid+1,r,rp);
    pushup(p);  
}

Martix query(int l,int r,int sl,int sr,int p)
{
    if(p==0) exit(0);
    if(sl<=l&&r<=sr) return t[p];
    int mid=(l+r)>>1;
    Martix ql=I,qr=I;
    if(sl<=mid) ql=query(l,mid,sl,sr,lp);
    if(sr>mid) qr=query(mid+1,r,sl,sr,rp);
    return ql*qr;
}

void change(int l,int r,int x,int p,const Martix &nw)
{
    if(l==r)
    {
        t[p]=nw;
        return;
    }
    int mid=(l+r)>>1;

    if(x<=mid) change(l,mid,x,lp,nw);
    else change(mid+1,r,x,rp,nw);

    pushup(p);
}

void change(int x,int y)
{
    nw=t[tid[x]];
    nw.a[1][0]+=y-a[x];
    a[x]=y;
    while(x)
    {
        Martix last=query(1,n,dfn[top[x]],dfn[tail[top[x]]],1);
        change(1,n,dfn[x],1,nw);
        Martix cur=query(1,n,dfn[top[x]],dfn[tail[top[x]]],1);
        
        x=fa[top[x]];
        nw=t[tid[x]];

        nw.a[0][0]+=max(cur.a[0][0],cur.a[1][0])-max(last.a[0][0],last.a[1][0]);
        nw.a[0][1]+=max(cur.a[0][0],cur.a[1][0])-max(last.a[0][0],last.a[1][0]);
        nw.a[1][0]+=cur.a[0][0]-last.a[0][0];
    }
}   

signed main()
{   
    I.a[0][0]=I.a[1][1]=0;
    I.a[0][1]=I.a[1][0]=-inf;   
    t[0]=I;

    n=read();
    m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read();
        E[u].push_back(v);
        E[v].push_back(u);
    }
    dfs1(1,0);
    memset(siz,0,sizeof(siz));
    dfs2(1,1);
    dfs3(1,0);
    build(1,n,1);

    while(m--)
    {
        int x=read(),y=read();
        change(x,y);
        Martix ans=query(1,n,1,dfn[tail[1]],1);
        write(max({ans.a[0][0],ans.a[1][0],ans.a[1][1],ans.a[0][1]}));
        putchar('\n');
    }
	return 0;
}

P4751 【模板】动态 DP(加强版)

我册出题人竟然卡树剖!

干了哥们儿!

卡常技巧:

  • 矩阵可以用 \(2 \times 2\) 的二维数组存储。

  • 矩阵乘可以循环展开。

  • pushup#define

  • 不开 long long

  • 对每条重链单开线段树维护,优化掉线段树区间查求上文的 \(Pre\) 和 \(Nxt\) 矩阵的过程,改为直接访问该重链根节点记录的矩阵(重要优化)。

  • 使用 _unlocked 以面对大量读入,手写 buf 使用 fread 也可(重要优化)。

  • 你的线段树需要动态开点,精细实现,每遇到一个非叶子节点直接开它的左右儿子,以令内存访问极为连续(重要优化)。

  • 手写 \(\max\)。

  • 再夜深人静时或大早上交,卡评测机波动。

  • 使用 \(\texttt{C++17/C++98/C++23}\) 提交。

Code:

cpp 复制代码
#include<bits/stdc++.h>
// #define int long long

using namespace std;

inline int read()
{
	int x=0,c=getchar_unlocked(),f=0;
	for(;c>'9'||c<'0';f=c=='-',c=getchar_unlocked());
	for(;c>='0'&&c<='9';c=getchar_unlocked())
		x=(x<<1)+(x<<3)+(c^48);
	return f?-x:x;
}
inline void write(int x)
{
	if(x<0) x=-x,putchar_unlocked('-');
	if(x>9)  write(x/10);
	putchar_unlocked(x%10+'0');
}

const int N=1e6+5;
int n,m;
int a[N];
int fa[N];
int tail[N];
int siz[N];
int son[N];
int dfn[N];
int tod[N];
int top[N];
int id[N];
int tot;
vector<int> E[N];
const int inf=1e9;
struct Martix{
    int a[2][2]={};
}I;
struct Tree
{
    int lp,rp;
    Martix a;
}t[N<<2];
int cntroot=0;
int root[N];
int f[N][2],g[N][2];
int tid[N];
Martix nw;

inline int max(int x,int y) { return x>y?x:y; }

Martix operator*(const Martix &x,const Martix &y)
{
    Martix ans;
    ans.a[0][0]=max(x.a[0][0]+y.a[0][0],x.a[0][1]+y.a[1][0]);
    ans.a[0][1]=max(x.a[0][0]+y.a[0][1],x.a[0][1]+y.a[1][1]);
    ans.a[1][0]=max(x.a[1][0]+y.a[0][0],x.a[1][1]+y.a[1][0]);
    ans.a[1][1]=max(x.a[1][0]+y.a[0][1],x.a[1][1]+y.a[1][1]);
    return ans;
}

Martix ksm(Martix x,int p)
{
    Martix ans=I;
    while(p)
    {
        if(p&1) ans=ans*x;
        x=x*x;
        p>>=1;
    }
    return ans;
}

void dfs1(int p,int f)
{
    fa[p]=f;
    siz[p]=1;
    for(int to:E[p])
    {
        if(to==f) continue;
        dfs1(to,p);
        siz[p]+=siz[to];
        if(siz[to]>siz[son[p]]) son[p]=to;
    }
}

void dfs2(int p,int tp)
{
    siz[tp]++;
    dfn[p]=++tot;
    id[tot]=p;
    top[p]=tp;
    tail[tp]=p;

    if(son[p]) dfs2(son[p],tp);
    else root[tp]=++cntroot;

    for(int to:E[p])
    if(!dfn[to]) dfs2(to,to);
}

void dfs3(int p,int fa)
{
    g[p][1]=a[p];
    for(int to:E[p])
    {
        if(to==fa) continue;
        
        dfs3(to,p);
        if(to==son[p]) continue;

        g[p][1]+=f[to][0];
        g[p][0]+=max(f[to][0],f[to][1]);
    }
    f[p][0]=g[p][0]+max(f[son[p]][0],f[son[p]][1]);
    f[p][1]=g[p][1]+f[son[p]][0];
}

void build(int rt,int l,int r,int &p)
{
    if(l==r)
    {
        tid[id[rt+l-1]]=p;
        t[p].a.a[0][0]=t[p].a.a[0][1]=g[id[rt+l-1]][0];
        t[p].a.a[1][0]=g[id[rt+l-1]][1];
        t[p].a.a[1][1]=-inf;
        return;
    }
    if(!t[p].lp) t[p].lp=++tot;
    if(!t[p].rp) t[p].rp=++tot;
    int mid=(l+r)>>1;
    build(rt,l,mid,t[p].lp);
    build(rt,mid+1,r,t[p].rp);
    t[p].a=t[t[p].lp].a*t[t[p].rp].a;
}

void change(int l,int r,int x,int p,const Martix &nw)
{
    if(l==r)
    {
        t[p].a=nw;
        return;
    }
    int mid=(l+r)>>1;

    if(x<=mid) change(l,mid,x,t[p].lp,nw);
    else change(mid+1,r,x,t[p].rp,nw);

    t[p].a=t[t[p].lp].a*t[t[p].rp].a;
}

void change(int x,int y)
{
    nw=t[tid[x]].a;
    nw.a[1][0]+=y-a[x];
    a[x]=y;
    while(x)
    {
        Martix last=t[root[top[x]]].a;
        change(1,siz[top[x]],dfn[x]-dfn[top[x]]+1,root[top[x]],nw);
        Martix cur=t[root[top[x]]].a;
        x=fa[top[x]];
        nw=t[tid[x]].a;
        nw.a[0][0]+=max(cur.a[0][0],cur.a[1][0])-max(last.a[0][0],last.a[1][0]);
        nw.a[0][1]+=max(cur.a[0][0],cur.a[1][0])-max(last.a[0][0],last.a[1][0]);
        nw.a[1][0]+=cur.a[0][0]-last.a[0][0];
    }
}   

signed main()
{  
    I.a[0][0]=I.a[1][1]=0;
    I.a[0][1]=I.a[1][0]=-inf;   
    t[0].a=I;

    n=read();
    m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read();
        E[u].push_back(v);
        E[v].push_back(u);
    }
    dfs1(1,0);
    memset(siz,0,sizeof(siz));
    dfs2(1,1);
    tot=0;
    dfs3(1,0);
    tot=cntroot;
    for(int i=1;i<=n;i++) if(i==top[i]) build(dfn[i],1,siz[i],root[i]);
    int lastans=0;

    Martix ans;
    while(m--)
    {
        int x=read()^lastans,y=read();
        change(x,y);
        ans=t[root[top[1]]].a;
        lastans=max(ans.a[0][0],ans.a[1][0]);
        write(lastans);
        putchar_unlocked('\n');

    }
	return 0;
}