贡献延迟计算DP

题目链接

简单转化后的题意

给一个长度为 n n n 的数组 a [ 1... n ] a[1 ... n] a[1...n],其中如果 a [ i ] ≠ 0 a[i] \neq 0 a[i]=0 说明第 i i i 位是个确定切不可修改的值,否则 a [ i ] = 0 a[i] = 0 a[i]=0 需要将该位填数使得最后生成的 a [ ] a[] a[] 数组是一个 n n n 阶排列。

同时,还有一个 s [ 1... n ] s[1...n] s[1...n] 数组, s [ i ] = 1 s[i] = 1 s[i]=1 表示对应位的 a [ i ] a[i] a[i] 是"有效的",否则 s [ i ] = 0 s[i] = 0 s[i]=0 表示对应位的 a [ i ] a[i] a[i] 是"无效的"。

我们定义合法的序列是:取出所有"有效的" a [ i ] a[i] a[i] 组成的序列是一个严格上升子序列。

对于每个生成的排列 a [ ] a[] a[] ,记 f ( a ) f(a) f(a) 表示不同 s [ ] s[] s[] 生成的合法序列的方案数。

请求解: ∑ f ( a ) \sum f(a) ∑f(a)。


解题过程

首先,既然是严格上升子序列,我们发现只与之前出现的最大的"有效的" a [ i ] a[i] a[i] 有关。于是,试图进行 DP 求解。记 dp[i][v] 表示考虑完前 i i i 位目前的有效最大值为 v v v 的合法序列的方案数。

考虑从 i i i 推给 i + 1 i+1 i+1,则我们发现, i + 1 i+1 i+1 的选择有:有效、更大无效、更大无效、更小

这三种类型。我们似乎就可以写成 d p [ i ] [ v ] → d p [ i + 1 ] [ v ′ ] dp[i][v] \rightarrow dp[i + 1][v'] dp[i][v]→dp[i+1][v′]、 d p [ i ] [ v ] → d p [ i + 1 ] [ v ] dp[i][v] \rightarrow dp[i + 1][v] dp[i][v]→dp[i+1][v]、 d p [ i ] [ v ] → d p [ i + 1 ] [ v ] dp[i][v] \rightarrow dp[i + 1][v] dp[i][v]→dp[i+1][v],按顺序对应上方的三种类型。

稍加分析,我们发现 无效、更大 这里,选了更大的,但是却没有记录,那么在选择 有效、更大 的时候岂不是会发生我们选了一个 v ′ v' v′ 但是,这个 v ′ v' v′ 是已经被使用过了的情况吗?

朴素的,要存储是否使用,我们需要使用到 2 n 2^n 2n 这样的状态量进行存储。

考虑一个不朴素的方法------"贡献延迟计算",我们更改 DP 的含义,dp[i][v][k] 表示前 i i i 位目前的有效最大值为 v v v 并且使用">v"的未出现过的数的个数为 k k k 个的合法序列的方案数。

使用">v"的未出现过的数的个数 指的是,在 a [ 1... n ] a[1...n] a[1...n] 中不存在的数。

再来看一下这个状态是否可以转移起来。(钦定 dp[i][v][k] 是已知)

(一) a [ i + 1 ] ≠ 0 a[i + 1] \neq 0 a[i+1]=0 情况下

(1.1) a [ i + 1 ] > v a[i+1] > v a[i+1]>v 情况

d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k ] dp[i][v][k] \rightarrow dp[i + 1][v][k] dp[i][v][k]→dp[i+1][v][k],指的是 s [ i + 1 ] = 0 s[i+1] = 0 s[i+1]=0;

d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ a [ i + 1 ] ] [ k − w ] dp[i][v][k] \rightarrow dp[i + 1][ a[i+1] ][k - w] dp[i][v][k]→dp[i+1][a[i+1]][k−w] ,指的是 s [ i + 1 ] = 1 s[i+1] = 1 s[i+1]=1,那么发现 k 也要随着变化,因为 kv 是含义绑定的。我们需要枚举用了 w w w 个大小在 [v+1, a[i+1] - 1] 之间的数,同时乘以其对应组合数进行转移。对应组合数为 前 k 个位置用了哪 w 个 以及 有 bet 个数是大小满足范围的但是只取了其中 w 个

(1.2) a [ i + 1 ] < v a[i+1] < v a[i+1]<v 情况

d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k ] dp[i][v][k] \rightarrow dp[i + 1][v][k] dp[i][v][k]→dp[i+1][v][k],必然得让其无效了。

(二) a [ i + 1 ] = 0 a[i + 1] = 0 a[i+1]=0 情况下

(2.1)选择一个 < v < v <v 的数 v ′ v' v′

d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k ] dp[i][v][k] \rightarrow dp[i + 1][v][k] dp[i][v][k]→dp[i+1][v][k],此时必然只能选择让他无效掉了。

(2.2)选择一个 > v > v >v 的数 v ′ v' v′

d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k + 1 ] dp[i][v][k] \rightarrow dp[i + 1][v][k+1] dp[i][v][k]→dp[i+1][v][k+1],选择了,但是 s [ i + 1 ] = 0 s[i+1] = 0 s[i+1]=0,无效掉。

d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ′ ] [ k − w ] dp[i][v][k] \rightarrow dp[i + 1][v'][k - w] dp[i][v][k]→dp[i+1][v′][k−w], s [ i + 1 ] = 1 s[i+1] = 1 s[i+1]=1,我们需要枚举用了 w w w 个大小在 [v+1, v' - 1] 之间的数,同时乘以其对应组合数进行转移。


代码

cpp 复制代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 505, mod = 998244353;
void Mod(int &x) {
    if(x >= mod) x -= mod;
}
int Sum(int x, int y) { x += y; Mod(x); return x; }
int n, a[maxn], pre0[maxn], suf[maxn], c[maxn][maxn], A[maxn][maxn], jc[maxn];
void init() {
    jc[0] = 1;
    for(int i = 1; i < maxn; i ++) jc[i] = 1LL * jc[i - 1] * i % mod;
    c[0][0] = 1;
    for(int i = 1; i < maxn; i ++) {
        c[i][0] = 1;
        for(int j = 1; j <= i; j ++) c[i][j] = Sum(c[i - 1][j - 1], c[i - 1][j]);
    }
    for(int i = 0; i < maxn; i ++) {
        for(int j = 0; j <= i; j ++) {
            A[i][j] = 1LL * c[i][j] * jc[j] % mod;
        }
    }
}
bool vis[maxn];
struct node {
    int p, q;
    node(int _p = 0, int _q = 0):p(_p), q(_q) {}
    friend bool operator < (node e1, node e2) {
        return e1.p < e2.p;
    }
} t[maxn];
int dp[maxn][maxn][maxn];
void Solve() {
    dp[0][0][0] = 1;
    int bet, fnt, lit, big;
    for(int i = 0; i < n; i ++) {
        for(int v = 0; v <= n; v ++) {
            for(int k = 0; k <= suf[v + 1]; k ++) {
                if(!dp[i][v][k]) continue;
                if(a[i + 1]) {
                    if(a[i + 1] > v) {
                        bet = suf[v + 1] - suf[a[i + 1]];
                        for(int w = 0; w <= min(k, bet); w ++) {
                            dp[i + 1][a[i + 1]][k - w] += 1LL * dp[i][v][k] * A[bet][w] % mod * c[k][w] % mod;
                            Mod(dp[i + 1][a[i + 1]][k - w]);
                        }
                        dp[i + 1][v][k] = Sum(dp[i + 1][v][k], dp[i][v][k]);
                    }
                    else {
                        dp[i + 1][v][k] = Sum(dp[i + 1][v][k], dp[i][v][k]);
                    }
                    continue;
                }
                fnt = pre0[i];
                lit = suf[1] - suf[v];
                big = k;
                if(v && !vis[v]) big ++;
                lit = lit - (fnt - big);
                if(lit < 0) continue;
                if(lit > 0) {
                    dp[i + 1][v][k] += 1LL * dp[i][v][k] * lit % mod;
                    Mod(dp[i + 1][v][k]);
                }
                if(suf[v + 1] > k) {
                    dp[i + 1][v][k + 1] += dp[i][v][k];
                    Mod(dp[i + 1][v][k + 1]);
                }
                for(int nv = v + 1; nv <= n; nv ++) {
                    if(vis[nv]) continue;
                    bet = suf[v + 1] - suf[nv];
                    for(int w = 0; w <= min(k, bet); w ++) {
                        dp[i + 1][nv][k - w] += 1LL * dp[i][v][k] * A[bet][w] % mod * c[k][w] % mod;
                        Mod(dp[i + 1][nv][k - w]);
                    }
                }
            }
        }
    }
    ll ans = 0;
    for(int v = 0; v <= n; v ++) {
        for(int k = 0; k <= suf[v + 1]; k ++) {
            if(!dp[n][v][k]) continue;
            ans = (ans + 1LL * dp[n][v][k] * jc[k]) % mod;
        }
    }
    printf("%lld\n", ans);
}
int main() {
//    freopen("permutation.in", "r", stdin);
//    freopen("permutation.out", "w", stdout);
    init();
    scanf("%d", &n);
    for(int i = 1; i <= n; i ++) scanf("%d", &t[i].p);
    for(int i = 1; i <= n; i ++) scanf("%d", &t[i].q);
    sort(t + 1,t + n + 1);
    for(int i = 1; i <= n; i ++) a[i] = t[i].q;
    for(int i = 1; i <= n; i ++) vis[a[i]] = true;
    vis[0] = true;
    for(int i = n; i >= 0; i --) {
        suf[i] = suf[i + 1];
        if(!vis[i]) suf[i] ++;
    }
    for(int i = 1; i <= n; i ++) {
        pre0[i] = pre0[i - 1];
        if(!a[i]) pre0[i] ++;
    }
    Solve();
    return 0;
}
/*
3
1 2 3
0 3 0
ans:11
2
1 2
2 0
ans:3
2
1 2
0 0
ans:7
*/
相关推荐
季明洵14 小时前
Java中哈希
java·算法·哈希
jaysee-sjc14 小时前
【练习十】Java 面向对象实战:智能家居控制系统
java·开发语言·算法·智能家居
cici1587414 小时前
基于MATLAB实现eFAST全局敏感性分析
算法·matlab
gihigo199814 小时前
MATLAB实现K-SVD算法
数据结构·算法·matlab
vegetablesssss14 小时前
=和{}赋值区别
c++
dyyx11114 小时前
C++编译期数据结构
开发语言·c++·算法
曼巴UE514 小时前
UE C++ 组件 非构造函数创建的技巧
开发语言·c++
Swift社区14 小时前
LeetCode 384 打乱数组
算法·leetcode·职场和发展
SJLoveIT14 小时前
架构师视角:深度解构 Redis 底层数据结构的设计哲学
数据结构·数据库·redis
running up that hill14 小时前
日常刷题记录
java·数据结构·算法