题目链接
简单转化后的题意
给一个长度为 n n n 的数组 a [ 1... n ] a[1 ... n] a[1...n],其中如果 a [ i ] ≠ 0 a[i] \neq 0 a[i]=0 说明第 i i i 位是个确定切不可修改的值,否则 a [ i ] = 0 a[i] = 0 a[i]=0 需要将该位填数使得最后生成的 a [ ] a[] a[] 数组是一个 n n n 阶排列。
同时,还有一个 s [ 1... n ] s[1...n] s[1...n] 数组, s [ i ] = 1 s[i] = 1 s[i]=1 表示对应位的 a [ i ] a[i] a[i] 是"有效的",否则 s [ i ] = 0 s[i] = 0 s[i]=0 表示对应位的 a [ i ] a[i] a[i] 是"无效的"。
我们定义合法的序列是:取出所有"有效的" a [ i ] a[i] a[i] 组成的序列是一个严格上升子序列。
对于每个生成的排列 a [ ] a[] a[] ,记 f ( a ) f(a) f(a) 表示不同 s [ ] s[] s[] 生成的合法序列的方案数。
请求解: ∑ f ( a ) \sum f(a) ∑f(a)。
解题过程
首先,既然是严格上升子序列,我们发现只与之前出现的最大的"有效的" a [ i ] a[i] a[i] 有关。于是,试图进行 DP 求解。记 dp[i][v] 表示考虑完前 i i i 位目前的有效最大值为 v v v 的合法序列的方案数。
考虑从 i i i 推给 i + 1 i+1 i+1,则我们发现, i + 1 i+1 i+1 的选择有:有效、更大、无效、更大、无效、更小
这三种类型。我们似乎就可以写成 d p [ i ] [ v ] → d p [ i + 1 ] [ v ′ ] dp[i][v] \rightarrow dp[i + 1][v'] dp[i][v]→dp[i+1][v′]、 d p [ i ] [ v ] → d p [ i + 1 ] [ v ] dp[i][v] \rightarrow dp[i + 1][v] dp[i][v]→dp[i+1][v]、 d p [ i ] [ v ] → d p [ i + 1 ] [ v ] dp[i][v] \rightarrow dp[i + 1][v] dp[i][v]→dp[i+1][v],按顺序对应上方的三种类型。
稍加分析,我们发现 无效、更大 这里,选了更大的,但是却没有记录,那么在选择 有效、更大 的时候岂不是会发生我们选了一个 v ′ v' v′ 但是,这个 v ′ v' v′ 是已经被使用过了的情况吗?
朴素的,要存储是否使用,我们需要使用到 2 n 2^n 2n 这样的状态量进行存储。
考虑一个不朴素的方法------"贡献延迟计算",我们更改 DP 的含义,dp[i][v][k] 表示前 i i i 位目前的有效最大值为 v v v 并且使用">v"的未出现过的数的个数为 k k k 个的合法序列的方案数。
使用">v"的未出现过的数的个数 指的是,在 a [ 1... n ] a[1...n] a[1...n] 中不存在的数。
再来看一下这个状态是否可以转移起来。(钦定 dp[i][v][k] 是已知)
(一) a [ i + 1 ] ≠ 0 a[i + 1] \neq 0 a[i+1]=0 情况下
(1.1) a [ i + 1 ] > v a[i+1] > v a[i+1]>v 情况
d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k ] dp[i][v][k] \rightarrow dp[i + 1][v][k] dp[i][v][k]→dp[i+1][v][k],指的是 s [ i + 1 ] = 0 s[i+1] = 0 s[i+1]=0;
d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ a [ i + 1 ] ] [ k − w ] dp[i][v][k] \rightarrow dp[i + 1][ a[i+1] ][k - w] dp[i][v][k]→dp[i+1][a[i+1]][k−w] ,指的是 s [ i + 1 ] = 1 s[i+1] = 1 s[i+1]=1,那么发现 k 也要随着变化,因为 k 和 v 是含义绑定的。我们需要枚举用了 w w w 个大小在 [v+1, a[i+1] - 1] 之间的数,同时乘以其对应组合数进行转移。对应组合数为 前 k 个位置用了哪 w 个 以及 有 bet 个数是大小满足范围的但是只取了其中 w 个。
(1.2) a [ i + 1 ] < v a[i+1] < v a[i+1]<v 情况
d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k ] dp[i][v][k] \rightarrow dp[i + 1][v][k] dp[i][v][k]→dp[i+1][v][k],必然得让其无效了。
(二) a [ i + 1 ] = 0 a[i + 1] = 0 a[i+1]=0 情况下
(2.1)选择一个 < v < v <v 的数 v ′ v' v′
d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k ] dp[i][v][k] \rightarrow dp[i + 1][v][k] dp[i][v][k]→dp[i+1][v][k],此时必然只能选择让他无效掉了。
(2.2)选择一个 > v > v >v 的数 v ′ v' v′
d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ] [ k + 1 ] dp[i][v][k] \rightarrow dp[i + 1][v][k+1] dp[i][v][k]→dp[i+1][v][k+1],选择了,但是 s [ i + 1 ] = 0 s[i+1] = 0 s[i+1]=0,无效掉。
d p [ i ] [ v ] [ k ] → d p [ i + 1 ] [ v ′ ] [ k − w ] dp[i][v][k] \rightarrow dp[i + 1][v'][k - w] dp[i][v][k]→dp[i+1][v′][k−w], s [ i + 1 ] = 1 s[i+1] = 1 s[i+1]=1,我们需要枚举用了 w w w 个大小在 [v+1, v' - 1] 之间的数,同时乘以其对应组合数进行转移。
代码
cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 505, mod = 998244353;
void Mod(int &x) {
if(x >= mod) x -= mod;
}
int Sum(int x, int y) { x += y; Mod(x); return x; }
int n, a[maxn], pre0[maxn], suf[maxn], c[maxn][maxn], A[maxn][maxn], jc[maxn];
void init() {
jc[0] = 1;
for(int i = 1; i < maxn; i ++) jc[i] = 1LL * jc[i - 1] * i % mod;
c[0][0] = 1;
for(int i = 1; i < maxn; i ++) {
c[i][0] = 1;
for(int j = 1; j <= i; j ++) c[i][j] = Sum(c[i - 1][j - 1], c[i - 1][j]);
}
for(int i = 0; i < maxn; i ++) {
for(int j = 0; j <= i; j ++) {
A[i][j] = 1LL * c[i][j] * jc[j] % mod;
}
}
}
bool vis[maxn];
struct node {
int p, q;
node(int _p = 0, int _q = 0):p(_p), q(_q) {}
friend bool operator < (node e1, node e2) {
return e1.p < e2.p;
}
} t[maxn];
int dp[maxn][maxn][maxn];
void Solve() {
dp[0][0][0] = 1;
int bet, fnt, lit, big;
for(int i = 0; i < n; i ++) {
for(int v = 0; v <= n; v ++) {
for(int k = 0; k <= suf[v + 1]; k ++) {
if(!dp[i][v][k]) continue;
if(a[i + 1]) {
if(a[i + 1] > v) {
bet = suf[v + 1] - suf[a[i + 1]];
for(int w = 0; w <= min(k, bet); w ++) {
dp[i + 1][a[i + 1]][k - w] += 1LL * dp[i][v][k] * A[bet][w] % mod * c[k][w] % mod;
Mod(dp[i + 1][a[i + 1]][k - w]);
}
dp[i + 1][v][k] = Sum(dp[i + 1][v][k], dp[i][v][k]);
}
else {
dp[i + 1][v][k] = Sum(dp[i + 1][v][k], dp[i][v][k]);
}
continue;
}
fnt = pre0[i];
lit = suf[1] - suf[v];
big = k;
if(v && !vis[v]) big ++;
lit = lit - (fnt - big);
if(lit < 0) continue;
if(lit > 0) {
dp[i + 1][v][k] += 1LL * dp[i][v][k] * lit % mod;
Mod(dp[i + 1][v][k]);
}
if(suf[v + 1] > k) {
dp[i + 1][v][k + 1] += dp[i][v][k];
Mod(dp[i + 1][v][k + 1]);
}
for(int nv = v + 1; nv <= n; nv ++) {
if(vis[nv]) continue;
bet = suf[v + 1] - suf[nv];
for(int w = 0; w <= min(k, bet); w ++) {
dp[i + 1][nv][k - w] += 1LL * dp[i][v][k] * A[bet][w] % mod * c[k][w] % mod;
Mod(dp[i + 1][nv][k - w]);
}
}
}
}
}
ll ans = 0;
for(int v = 0; v <= n; v ++) {
for(int k = 0; k <= suf[v + 1]; k ++) {
if(!dp[n][v][k]) continue;
ans = (ans + 1LL * dp[n][v][k] * jc[k]) % mod;
}
}
printf("%lld\n", ans);
}
int main() {
// freopen("permutation.in", "r", stdin);
// freopen("permutation.out", "w", stdout);
init();
scanf("%d", &n);
for(int i = 1; i <= n; i ++) scanf("%d", &t[i].p);
for(int i = 1; i <= n; i ++) scanf("%d", &t[i].q);
sort(t + 1,t + n + 1);
for(int i = 1; i <= n; i ++) a[i] = t[i].q;
for(int i = 1; i <= n; i ++) vis[a[i]] = true;
vis[0] = true;
for(int i = n; i >= 0; i --) {
suf[i] = suf[i + 1];
if(!vis[i]) suf[i] ++;
}
for(int i = 1; i <= n; i ++) {
pre0[i] = pre0[i - 1];
if(!a[i]) pre0[i] ++;
}
Solve();
return 0;
}
/*
3
1 2 3
0 3 0
ans:11
2
1 2
2 0
ans:3
2
1 2
0 0
ans:7
*/