题目大意
给你一个字符串 S S S,求将 S S S拆成若干个 A B AB AB和一个 C C C的方案数,其中 A , B , C A,B,C A,B,C均为非空字符串,且 A A A中出现奇数次的字符数量不超过 C C C中出现奇数次的字符数量。
有 T T T组数据, 1 ≤ T ≤ 5 , 1 ≤ ∣ S ∣ ≤ 2 20 1\leq T\leq 5,1\leq |S|\leq 2^{20} 1≤T≤5,1≤∣S∣≤220
题解
设 s 1 i s1_i s1i表示字符串 S S S前 i i i个字符中出现奇数次的字符数量, s 2 i s2_i s2i表示字符串 S S S后 i i i个字符中出现奇数次的字符数量。
枚举第一组 A B AB AB的末尾位置 i i i,因为 B B B是非空字符串,所以 A A A的末尾位置一定在 i i i之前。用树状数组来维护 s 1 s1 s1的值,因为小写字母只有 26 26 26个,所以树状数组的大小只需要开 26 26 26即可。注意树状数组中 0 0 0的位置要特判处理。
然后,枚举 i + 1 ∼ 2 i , 2 i + 1 ∼ 3 i , ... i+1\sim2i,2i+1\sim 3i,\dots i+1∼2i,2i+1∼3i,...是否与 1 ∼ i 1\sim i 1∼i相同。如果 ( k − 1 ) i + 1 ∼ k i (k-1)i+1\sim ki (k−1)i+1∼ki与 1 ∼ i 1\sim i 1∼i相同,则令 k i + 1 ki+1 ki+1为 C C C的起始节点,在树状数组查询小于等于 s 2 k i + 1 s2_{ki+1} s2ki+1的 s 1 s1 s1值的个数,并将其统计到答案中;否则,退出当前循环。
判断两个子串是否相同,用哈希可以 O ( 1 ) O(1) O(1)解决。
对于每个 i i i,最多需要枚举 n i \dfrac ni in次。那么,最多会枚举 ∑ i = 1 n n i ≈ n ln n \sum\limits_{i=1}^n\dfrac ni\approx n\ln n i=1∑nin≈nlnn次,每次都会查询一次树状数组,所以总时间复杂度为 O ( n ln n log 26 ) O(n\ln n\log 26) O(nlnnlog26),实现得好的话是可以过的。
code
cpp
#include<bits/stdc++.h>
using namespace std;
const int N=1<<20;
const long long mod=998244353;
int T,n,now,t0,v[35],s1[N+5],s2[N+5],tr[35];
long long ans,hs[N+5],pw[N+5];
char s[N+5];
int lb(int i){
return i&(-i);
}
void pt(int i){
if(i==0){
++t0;return;
}
while(i<=26){
++tr[i];
i+=lb(i);
}
}
int find(int i){
int re=t0;
while(i){
re+=tr[i];
i-=lb(i);
}
return re;
}
long long up(long long x){
if(x<0) x+=mod;
return x;
}
int main()
{
pw[0]=1;
for(int i=1;i<=N;i++){
pw[i]=pw[i-1]*26%mod;
}
scanf("%d",&T);
while(T--){
scanf("%s",s+1);
n=strlen(s+1);
for(int i=0;i<26;i++) v[i]=0;
now=0;
for(int i=1,p;i<=n;i++){
p=s[i]-'a';
++v[p];
if(v[p]&1) ++now;
else --now;
s1[i]=now;
}
for(int i=0;i<26;i++) v[i]=0;
now=0;
for(int i=n,p;i>=1;i--){
p=s[i]-'a';
++v[p];
if(v[p]&1) ++now;
else --now;
s2[i]=now;
}
for(int i=1;i<=n;i++){
hs[i]=(hs[i-1]*26+s[i]-'a')%mod;
}
ans=t0=0;
for(int i=1;i<=26;i++) tr[i]=0;
pt(s1[1]);
for(int i=2;i<n;i++){
for(int j=i;j<n;j+=i){
if(hs[i]!=up(hs[j]-hs[j-i]*pw[i]%mod)) break;
ans+=find(s2[j+1]);
}
pt(s1[i]);
}
printf("%lld\n",ans);
}
return 0;
}