Luogu P5298 PKUWC2018 Minimax 题解 [ 紫 ] [ 树形 dp ] [ 线段树合并 ] [ 概率 dp ]

Minimax:线段树合并优化 dp 好题。

树形 dp

因为要求出每一个值的出现概率,首先我们可以想到一个很暴力的 dp 式子。

定义 \(dp_{i,j}\) 表示在节点 \(i\) 时,权值 \(j\) 的出现概率,设 \(l\) 表示左儿子,\(r\) 表示右儿子,则有如下转移:

  • 当 \(j\) 在左儿子中时,\(dp_{i,j}\gets dp_{l,j}\times(p_i\times\sum_{k=1}^{j-1}dp_{r,k}+(1-p_i)\times\sum_{k=j+1}^{V}dp_{r,k})\),理解的话就是对父亲节点选大的还是选小的进行分讨。
  • 当 \(j\) 在右儿子中时,\(dp_{i,j}\gets dp_{r,j}\times(p_i\times\sum_{k=1}^{j-1}dp_{l,k}+(1-p_i)\times\sum_{k=j+1}^Vdp_{l,k})\)。

直接转移即可,时间复杂度 \(O(nV)\)。

线段树合并优化

显然原来的时间复杂度会炸掉,但是我们发现每个节点最开始时最多只有一个 dp 位置是有值的,所以我们考虑用这种均摊复杂度的线段树合并来优化这个 dp。

因为 dp 转移的时候需要用到前缀和后缀和,所以我们进行 merge 的时候记录节点 \(x,y\) 的前缀和 \(px,py\) 与后缀和 \(sx,sy\) 以及父亲节点的概率 \(p\)。

梳理一下 merge 的流程:

  • 进入节点 \(x,y\)。
  • 如果 \(x,y\) 其中之一是空树,则说明直接更新 dp 值即可。
    • 若 \(x\) 是空树,对应着上述 \(j\) 在右儿子中的转移方式,则我们对 \(y\) 的整颗子树内的 dp 值全部乘上 \((p\times\sum_{k=1}^{j-1}dp_{l,k}+(1-p)\times\sum_{k=j+1}^Vdp_{l,k})=(p\times px+(1-p)\times sx)\) 即可。这个可以用懒标记实现区间乘。
    • 若 \(y\) 是空树,对应着上述 \(j\) 在左儿子中的转移方式,则我们对 \(x\) 的整颗子树内的 dp 值全部乘上 \((p\times\sum_{k=1}^{j-1}dp_{r,k}+(1-p)\times\sum_{k=j+1}^Vdp_{r,k})=(p\times py+(1-p)\times sy)\) 即可。这个可以用懒标记实现区间乘。
  • 否则就说明要递归合并,递归左右儿子的时候记得更新 \(sx,sy,px,py\) 的值。
  • 最后将左右儿子的 dp 值加起来就是这个区间的 dp 值。

时间复杂度 \(O(n\log n)\)。

代码

cpp 复制代码
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi=pair<int,int>;
const int N=300005;
const ll mod=998244353;
int n,fa[N],m=0,b[N],son[N][2],cd[N],p[N],ans[N];
ll qpow(ll a,ll b)
{
    ll res=1;
    while(b)
    {
        if(b&1)res=(res*a)%mod;
        b>>=1;
        a=(a*a)%mod;
    }
    return res;
}
int getrk(int x)
{
    return (lower_bound(b+1,b+m+1,x)-b);
}
struct Node{
    int ls,rs;
    ll dp,tag=1;
};
struct Segtree{
    Node tr[20*N];
    int root[N],tot=0;
    void pushup(int p)
    {
        tr[p].dp=(tr[lc(p)].dp+tr[rc(p)].dp)%mod;
    }
    void pushdown(int p)
    {
        if(tr[p].tag!=1)
        {
            tr[lc(p)].tag=(tr[lc(p)].tag*tr[p].tag)%mod;
            tr[rc(p)].tag=(tr[rc(p)].tag*tr[p].tag)%mod;
            tr[lc(p)].dp=(tr[lc(p)].dp*tr[p].tag)%mod;
            tr[rc(p)].dp=(tr[rc(p)].dp*tr[p].tag)%mod;
        }
        tr[p].tag=1;
    }
    void modify(int p,int v)
    {
        tr[p].dp=(tr[p].dp*1ll*v)%mod;
        tr[p].tag=(tr[p].tag*1ll*v)%mod;
    }
    void update(int &u,int ln,int rn,int x,ll k)
    {
        if(u==0)u=++tot;
        if(ln==rn){tr[u].dp+=k;return;}
        int mid=(ln+rn)>>1;
        if(x<=mid)update(lc(u),ln,mid,x,k);
        else update(rc(u),mid+1,rn,x,k);
        pushup(u);
    }
    int merge(int x,int y,int px,int py,int sx,int sy,int p)
    {
        if(x==0&&y==0)return 0;
        if(x==0)
        {
            modify(y,(1ll*p*px%mod+1ll*((1-p)%mod+mod)%mod*sx)%mod);
            return y;
        }
        if(y==0)
        {
            modify(x,(1ll*p*py%mod+1ll*((1-p)%mod+mod)%mod*sy)%mod);
            return x;
        }
        pushdown(x);pushdown(y);
        int lx=tr[lc(x)].dp,rx=tr[rc(x)].dp,ly=tr[lc(y)].dp,ry=tr[rc(y)].dp;
        tr[x].ls=merge(lc(x),lc(y),px,py,(sx+rx)%mod,(sy+ry)%mod,p);
        tr[x].rs=merge(rc(x),rc(y),(px+lx)%mod,(py+ly)%mod,sx,sy,p);
        pushup(x);
        return x;
    }
    void query(int u,int ln,int rn)
    {
        if(ln==rn){ans[ln]=tr[u].dp;return;}
        int mid=(ln+rn)>>1;
        pushdown(u);
        query(lc(u),ln,mid);
        query(rc(u),mid+1,rn);
    }
}tr1;
void dfs1(int u)
{
    if(son[u][0]==0)
    {
        tr1.update(tr1.root[u],1,m,getrk(p[u]),1);
        return;
    }
    if(son[u][1]==0)
    {
        dfs1(son[u][0]);
        tr1.root[u]=tr1.root[son[u][0]];
        return;
    }
    dfs1(son[u][0]);
    dfs1(son[u][1]);
    tr1.root[u]=tr1.merge(tr1.root[son[u][0]],tr1.root[son[u][1]],0,0,0,0,p[u]);
}
int main()
{
    //freopen("sample.in","r",stdin);
    //freopen("sample.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;i++)cin>>fa[i];
    for(int i=1;i<=n;i++)
    {
        son[fa[i]][cd[fa[i]]]=i;
        cd[fa[i]]++;
    }
    for(int i=1;i<=n;i++)
    {
        cin>>p[i];
        if(cd[i])p[i]=p[i]*1ll*qpow(10000,mod-2)%mod;
        else b[++m]=p[i];
    }
    sort(b+1,b+m+1);
    m=unique(b+1,b+m+1)-b-1;
    dfs1(1);
    tr1.query(tr1.root[1],1,m);
    ll res=0;
    for(int i=1;i<=m;i++)res=(res+1ll*i*b[i]%mod*ans[i]%mod*ans[i]%mod)%mod;
    cout<<res;
    return 0;
}
相关推荐
_不会dp不改名_1 小时前
leetcode_80删除有序数组中的重复项 II
数据结构·算法·leetcode
muxue1782 小时前
数据结构:栈
java·开发语言·数据结构
牛奶咖啡.8549 小时前
经典排序算法复习----C语言
c语言·开发语言·数据结构·算法·排序算法
带多刺的玫瑰10 小时前
Leecode刷题C语言之全排列②
java·数据结构·算法
tt55555555555510 小时前
每日一题--数组中只出现一次的两个数字
c语言·数据结构·算法·leetcode
_周游10 小时前
【数据结构】_队列经典算法OJ:循环队列
数据结构·算法
Luo_LA13 小时前
【LeetCode Hot100 堆】第 K 大的元素、前 K 个高频元素
数据结构·算法·leetcode
奇变偶不变072715 小时前
【C/C++】每日温度 [ 栈的应用 ] 蓝桥杯/ACM备赛
c语言·开发语言·数据结构·c++·算法·蓝桥杯
Excuse_lighttime16 小时前
树与二叉树的概念
java·开发语言·数据结构