题目链接
简单转化后的题意
给一个长度为 n n n 的数组 a 1... n a1 ... n a1...n,其中如果 a i ≠ 0 ai \neq 0 ai=0 说明第 i i i 位是个确定切不可修改的值,否则 a i = 0 ai = 0 ai=0 需要将该位填数使得最后生成的 a a\[\] a\[\] 数组是一个 n n n 阶排列。
同时,还有一个 s 1... n s1...n s1...n 数组, s i = 1 si = 1 si=1 表示对应位的 a i ai ai 是"有效的",否则 s i = 0 si = 0 si=0 表示对应位的 a i ai ai 是"无效的"。
我们定义合法的序列是:取出所有"有效的" a i ai ai 组成的序列是一个严格上升子序列。
对于每个生成的排列 a a\[\] a\[\] ,记 f ( a ) f(a) f(a) 表示不同 s s\[\] s\[\] 生成的合法序列的方案数。
请求解: ∑ f ( a ) \sum f(a) ∑f(a)。
解题过程
首先,既然是严格上升子序列,我们发现只与之前出现的最大的"有效的" a i ai ai 有关。于是,试图进行 DP 求解。记 dp[i][v] 表示考虑完前 i i i 位目前的有效最大值为 v v v 的合法序列的方案数。
考虑从 i i i 推给 i + 1 i+1 i+1,则我们发现, i + 1 i+1 i+1 的选择有:有效、更大、无效、更大、无效、更小
这三种类型。我们似乎就可以写成 d p i v → d p i + 1 v ′ dpiv \rightarrow dpi + 1v' dpiv→dpi+1v′、 d p i v → d p i + 1 v dpiv \rightarrow dpi + 1v dpiv→dpi+1v、 d p i v → d p i + 1 v dpiv \rightarrow dpi + 1v dpiv→dpi+1v,按顺序对应上方的三种类型。
稍加分析,我们发现 无效、更大 这里,选了更大的,但是却没有记录,那么在选择 有效、更大 的时候岂不是会发生我们选了一个 v ′ v' v′ 但是,这个 v ′ v' v′ 是已经被使用过了的情况吗?
朴素的,要存储是否使用,我们需要使用到 2 n 2^n 2n 这样的状态量进行存储。
考虑一个不朴素的方法------"贡献延迟计算",我们更改 DP 的含义,dp[i][v][k] 表示前 i i i 位目前的有效最大值为 v v v 并且使用">v"的未出现过的数的个数为 k k k 个的合法序列的方案数。
使用">v"的未出现过的数的个数 指的是,在 a 1... n a1...n a1...n 中不存在的数。
再来看一下这个状态是否可以转移起来。(钦定 dp[i][v][k] 是已知)
(一) a i + 1 ≠ 0 ai + 1 \neq 0 ai+1=0 情况下
(1.1) a i + 1 > v ai+1 > v ai+1>v 情况
d p i v k → d p i + 1 v k dpivk \rightarrow dpi + 1vk dpivk→dpi+1vk,指的是 s i + 1 = 0 si+1 = 0 si+1=0;
d p i v k → d p i + 1 a \[ i + 1 ] k − w dpivk \rightarrow dpi + 1 a\[i+1 ]k - w dpivk→dpi+1a\[i+1]k−w ,指的是 s i + 1 = 1 si+1 = 1 si+1=1,那么发现 k 也要随着变化,因为 k 和 v 是含义绑定的。我们需要枚举用了 w w w 个大小在 [v+1, a[i+1] - 1] 之间的数,同时乘以其对应组合数进行转移。对应组合数为 前 k 个位置用了哪 w 个 以及 有 bet 个数是大小满足范围的但是只取了其中 w 个。
(1.2) a i + 1 < v ai+1 < v ai+1<v 情况
d p i v k → d p i + 1 v k dpivk \rightarrow dpi + 1vk dpivk→dpi+1vk,必然得让其无效了。
(二) a i + 1 = 0 ai + 1 = 0 ai+1=0 情况下
(2.1)选择一个 < v < v <v 的数 v ′ v' v′
d p i v k → d p i + 1 v k dpivk \rightarrow dpi + 1vk dpivk→dpi+1vk,此时必然只能选择让他无效掉了。
(2.2)选择一个 > v > v >v 的数 v ′ v' v′
d p i v k → d p i + 1 v k + 1 dpivk \rightarrow dpi + 1vk+1 dpivk→dpi+1vk+1,选择了,但是 s i + 1 = 0 si+1 = 0 si+1=0,无效掉。
d p i v k → d p i + 1 v ′ k − w dpivk \rightarrow dpi + 1v'k - w dpivk→dpi+1v′k−w, s i + 1 = 1 si+1 = 1 si+1=1,我们需要枚举用了 w w w 个大小在 [v+1, v' - 1] 之间的数,同时乘以其对应组合数进行转移。
代码
cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 505, mod = 998244353;
void Mod(int &x) {
if(x >= mod) x -= mod;
}
int Sum(int x, int y) { x += y; Mod(x); return x; }
int n, a[maxn], pre0[maxn], suf[maxn], c[maxn][maxn], A[maxn][maxn], jc[maxn];
void init() {
jc[0] = 1;
for(int i = 1; i < maxn; i ++) jc[i] = 1LL * jc[i - 1] * i % mod;
c[0][0] = 1;
for(int i = 1; i < maxn; i ++) {
c[i][0] = 1;
for(int j = 1; j <= i; j ++) c[i][j] = Sum(c[i - 1][j - 1], c[i - 1][j]);
}
for(int i = 0; i < maxn; i ++) {
for(int j = 0; j <= i; j ++) {
A[i][j] = 1LL * c[i][j] * jc[j] % mod;
}
}
}
bool vis[maxn];
struct node {
int p, q;
node(int _p = 0, int _q = 0):p(_p), q(_q) {}
friend bool operator < (node e1, node e2) {
return e1.p < e2.p;
}
} t[maxn];
int dp[maxn][maxn][maxn];
void Solve() {
dp[0][0][0] = 1;
int bet, fnt, lit, big;
for(int i = 0; i < n; i ++) {
for(int v = 0; v <= n; v ++) {
for(int k = 0; k <= suf[v + 1]; k ++) {
if(!dp[i][v][k]) continue;
if(a[i + 1]) {
if(a[i + 1] > v) {
bet = suf[v + 1] - suf[a[i + 1]];
for(int w = 0; w <= min(k, bet); w ++) {
dp[i + 1][a[i + 1]][k - w] += 1LL * dp[i][v][k] * A[bet][w] % mod * c[k][w] % mod;
Mod(dp[i + 1][a[i + 1]][k - w]);
}
dp[i + 1][v][k] = Sum(dp[i + 1][v][k], dp[i][v][k]);
}
else {
dp[i + 1][v][k] = Sum(dp[i + 1][v][k], dp[i][v][k]);
}
continue;
}
fnt = pre0[i];
lit = suf[1] - suf[v];
big = k;
if(v && !vis[v]) big ++;
lit = lit - (fnt - big);
if(lit < 0) continue;
if(lit > 0) {
dp[i + 1][v][k] += 1LL * dp[i][v][k] * lit % mod;
Mod(dp[i + 1][v][k]);
}
if(suf[v + 1] > k) {
dp[i + 1][v][k + 1] += dp[i][v][k];
Mod(dp[i + 1][v][k + 1]);
}
for(int nv = v + 1; nv <= n; nv ++) {
if(vis[nv]) continue;
bet = suf[v + 1] - suf[nv];
for(int w = 0; w <= min(k, bet); w ++) {
dp[i + 1][nv][k - w] += 1LL * dp[i][v][k] * A[bet][w] % mod * c[k][w] % mod;
Mod(dp[i + 1][nv][k - w]);
}
}
}
}
}
ll ans = 0;
for(int v = 0; v <= n; v ++) {
for(int k = 0; k <= suf[v + 1]; k ++) {
if(!dp[n][v][k]) continue;
ans = (ans + 1LL * dp[n][v][k] * jc[k]) % mod;
}
}
printf("%lld\n", ans);
}
int main() {
// freopen("permutation.in", "r", stdin);
// freopen("permutation.out", "w", stdout);
init();
scanf("%d", &n);
for(int i = 1; i <= n; i ++) scanf("%d", &t[i].p);
for(int i = 1; i <= n; i ++) scanf("%d", &t[i].q);
sort(t + 1,t + n + 1);
for(int i = 1; i <= n; i ++) a[i] = t[i].q;
for(int i = 1; i <= n; i ++) vis[a[i]] = true;
vis[0] = true;
for(int i = n; i >= 0; i --) {
suf[i] = suf[i + 1];
if(!vis[i]) suf[i] ++;
}
for(int i = 1; i <= n; i ++) {
pre0[i] = pre0[i - 1];
if(!a[i]) pre0[i] ++;
}
Solve();
return 0;
}
/*
3
1 2 3
0 3 0
ans:11
2
1 2
2 0
ans:3
2
1 2
0 0
ans:7
*/