F2 - Long Colorful Strip
很牛的题!
首先,我们可以将颜色相同的一段区间缩成一个点,那么每次加入一个新的颜色时,最多只能将其所覆盖的那个颜色所属的区间分成三部分(原本:00000000
,加入1
后\(\rightarrow\)0001111000
),也就是增加了两个点,那么也就意味着最终所成的点的个数最多只有\(2*n-1\)个,这样的话,\(m\)看着很大,其实也就是\(O(N)\)级别的
对于这种区间覆盖的题,优先考虑区间\(dp\)
设\(dp[l][r]\)表示让区间\([l,r]\)合法的方案数,且满足\([l,r]\)中的颜色不会在\([l,r]\)外出现
设\(p\)为区间\([l,r]\)中最小的颜色,其出现的最早的位置为\(mn_p\),最晚的位置为\(mx_p\),那么显然有我们可以一开始就让\(p\)覆盖区间\([a,b]\)(\([mn_p,mx_p]\in[a,b]\)),那么就有转移式
\[dp[l][r]=\sum_{a\in[l,mn_p]}\sum_{b\in[mx_p,r]}dp[l][a-1]\times dp[a][b]\times dp[b+1][r] \]
其中,对于\(dp[a][b]\),其实就是相当于对于\([a,b]\)的所有方案,直接在每种方案的第一个操作之前给\([a,b]\)覆盖上颜色\(p\)即可(总之就是让第一步变成用\(p\)覆盖\([a,b]\),具体的,若原本的方案没有这一步,加上即可;否则,若其原本只用\(p\)覆盖\([s,t]\)(\([s,t]\in[a,b]\),因为\(p\)也是\([a,b]\)中最小的值,所以原本的用\(p\)覆盖\([s,t]\)的这一步也一定是该方案中的第一步)将其补全成覆盖\([a,b]\)即可
但是可以发现,当\(a=l,b=r\)时,就转不了了,所以考虑将\(dp[a][b]\)拆开
因为\(p\)的最早出现位置为\(mn_p\),最晚出现位置为\(mx_p\),所以显然有所有处于\([mn_p,mx_p]\)中的颜色的出现区间范围也不会超过\([mn_p,mx_p]\),这也意味这\([mn_p,mx_p]\)将\([a,b]\)分隔成了三个区间\([a,mn_p-1],[mn_p,mx_p],[mx_p+1,b]\),这三个区间是"独立的",即它们的颜色集合没有交集
证明的话可以从\(mn_p\)和\(mx_p\)这两个点下手,因为它们的颜色\(p\)是区间\([a,b]\)中最小的颜色,这就意味着区间\([a,b]\)中剩下的所有颜色都不会跨越它们,也就是没有颜色会覆盖上述三个区间中的任意两个,所以所有颜色只会待在对应的一个区间内
\[dp[a][b]=dp[a][mn_p-1]\times dp[mx_p+1][b]\times \sum_{[i,j]}dp[i][j] \]
其中\([i,j]\)表示所有不包含颜色\(p\)的、属于区间\([mn_p,mx_p]\)的极长区间,它实际上如何转移的和上文中\(对dp[a][b]\)的解释一致
所以我们就得到了最终的式子
\[dp[l][r]=\sum_{a\in[l,mn_p]}\sum_{b\in[mx_p,r]}dp[l][a-1]\times dp[a][mn_p-1]\times(\sum_{[i,j]}dp[i][j])\times dp[mx_p+1][b]\times dp[b+1][r] \]
复杂度\(O(N^3)\)
代码中,只要保证了当前区间\([l,r]\)的最小颜色\(p\)的范围属于当前区间,在转移时因为其他转移给\(dp[l][r]\)的\(dp[x][y]\)会保证属于\([x,y]\)的所有颜色只属于\([x,y]\)中,所以最终就一定能保证\([l,r]\)中的所有颜色只属于\([l,r]\)中
也就是归纳证明即可
c++
#include<bits/stdc++.h>
using namespace std;
const int N=505,M=1e6+5,MOD=998244353;
int n,m,col[M],dp[N<<1][N<<1],mn[N],mx[N];
void add(int &x,int y){ x+=y; if(x>=MOD) x-=MOD; }
int ad(int x,int y){ x+=y; if(x>=MOD) x-=MOD; return x; }
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=m;++i){
scanf("%d",&col[i]);
if(col[i]==col[i-1]) --i,--m;
else (!mn[col[i]])&&(mn[col[i]]=i),mx[col[i]]=i;
}
if(m>=(n<<1)) return puts("0"),0;
for(int i=1;i<=m+1;++i) for(int j=0;j<i;++j) dp[i][j]=1;
for(int len=1;len<=m;++len)
for(int l=1,r,p,a,b;l+len-1<=m;++l){
r=l+len-1,p=N,a=0,b=0;
for(int k=l;k<=r;++k) p=min(p,col[k]);
if(mn[p]<l||mx[p]>r) continue;
for(int k=l-1;k<=mn[p]-1;++k) add(a,1ll*dp[l][k]*dp[k+1][mn[p]-1]%MOD);
for(int k=mx[p];k<=r;++k) add(b,1ll*dp[mx[p]+1][k]*dp[k+1][r]%MOD);
dp[l][r]=1ll*a*b%MOD;
for(int i=mn[p]+1,j=mn[p];i<mx[p];i=(j+=2)){
for(;col[j+1]!=p;++j) ;
dp[l][r]=1ll*dp[l][r]*dp[i][j]%MOD;
}
}
printf("%d",dp[1][m]);
return 0;
}