维护一个长度为 \(n\) 的数组,支持以下操作:
- 区间赋值:将 \(l,r\) 全部赋值为 \(x\)。
- 区间乘法:将 \(l,r\) 全部乘以 \(c\)。
- 区间加法:将 \(l,r\) 全部加上 \(d\)。
- 区间查询:查询 \(\sum_{i=l}^{r} a_i^k\)(\(k\) 次方和)。
记 \(S_t = \sum a_i^t\) 表示区间内所有数的 \(t\) 次方和,我们需要维护 \(t = 0,1,\cdots,K\) 的所有 \(S_t\)。
区间加
对于每个数 \(a_i\),加上 \(d\) 后:
\(a_i + d)\^t = \\sum_{j=0}\^{t} \\binom{t}{j} a_i\^j d\^{t-j} \\
对区间求和,记 \(S_t = \sum a_i^t\):
\S_t' = \\sum_{j=0}\^{t} \\binom{t}{j} d\^{t-j} S_j \\
区间乘
对于每个数 \(a_i\),乘以 \(c\) 后:
\(c \\cdot a_i)\^t = c\^t \\cdot a_i\^t \\
对区间求和:
\S_t' = c\^t \\cdot S_t \\
区间赋值
区间全部变为 \(x\),长度为 \(len\):
\S_t = len \\cdot x\^t \\
朴素实现
时间复杂度为 \(O(n\log n+qk^2\log n)\)。
代码
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e5+5,MAXK=100,P=998244353;
int n,m,a[N],C[MAXK+1][MAXK+1];
struct Node
{
int lz1,lz2,lz3,len,s[MAXK+1];
// lz1:赋值 -1表示无, lz2:乘法 1表示无, lz3:加法 0表示无
}tr[N<<2];
void init()
{
for(int i=0;i<=MAXK;i++)
{
C[i][0]=C[i][i]=1;
for(int j=1;j<i;j++)C[i][j]=(C[i-1][j-1]+C[i-1][j])%P;
}
}
void push_assign(int u,int x)
{
tr[u].lz1=x;
tr[u].lz2=1;
tr[u].lz3=0;
int pw=1;
for(int t=0;t<=MAXK;t++)
{
tr[u].s[t]=1LL*tr[u].len*pw%P;
pw=1LL*pw*x%P;
}
}
void push_mul(int u,int c)
{
tr[u].lz2=1LL*tr[u].lz2*c%P;
tr[u].lz3=1LL*tr[u].lz3*c%P;
int pw=1;
for(int t=0;t<=MAXK;t++)
{
tr[u].s[t]=1LL*tr[u].s[t]*pw%P;
pw=1LL*pw*c%P;
}
}
void push_add(int u,int d)
{
int pd[MAXK+1]={1};
for(int i=1;i<=MAXK;i++)pd[i]=1LL*pd[i-1]*d%P;
int ns[MAXK+1]={0};
for(int t=0;t<=MAXK;t++)
for(int j=0;j<=t;j++)
ns[t]=(ns[t]+1LL*C[t][j]*pd[t-j]%P*tr[u].s[j])%P;
for(int t=0;t<=MAXK;t++)tr[u].s[t]=ns[t];
tr[u].lz3=(tr[u].lz3+d)%P;
}
void pd(int u)
{
if(tr[u].lz1!=-1)
{
push_assign(u<<1,tr[u].lz1);
push_assign(u<<1|1,tr[u].lz1);
tr[u].lz1=-1;
}
if(tr[u].lz2!=1)
{
push_mul(u<<1,tr[u].lz2);
push_mul(u<<1|1,tr[u].lz2);
tr[u].lz2=1;
}
if(tr[u].lz3)
{
push_add(u<<1,tr[u].lz3);
push_add(u<<1|1,tr[u].lz3);
tr[u].lz3=0;
}
}
void up(int u)
{
for(int t=0;t<=MAXK;t++)
tr[u].s[t]=(tr[u<<1].s[t]+tr[u<<1|1].s[t])%P;
}
void build(int u,int l,int r)
{
tr[u].len=r-l+1;
tr[u].lz1=-1;
tr[u].lz2=1;
tr[u].lz3=0;
if(l==r)
{
int pw=1;
for(int t=0;t<=MAXK;t++)
{
tr[u].s[t]=pw;
pw=1LL*pw*a[l]%P;
}
return;
}
int m=(l+r)>>1;
build(u<<1,l,m);
build(u<<1|1,m+1,r);
up(u);
}
void upd_assign(int u,int l,int r,int ql,int qr,int x)
{
if(ql<=l&&r<=qr) { push_assign(u,x);return; }
pd(u);
int m=(l+r)>>1;
if(ql<=m)upd_assign(u<<1,l,m,ql,qr,x);
if(qr>m)upd_assign(u<<1|1,m+1,r,ql,qr,x);
up(u);
}
void upd_mul(int u,int l,int r,int ql,int qr,int c)
{
if(ql<=l&&r<=qr) { push_mul(u,c);return; }
pd(u);
int m=(l+r)>>1;
if(ql<=m)upd_mul(u<<1,l,m,ql,qr,c);
if(qr>m)upd_mul(u<<1|1,m+1,r,ql,qr,c);
up(u);
}
void upd_add(int u,int l,int r,int ql,int qr,int d)
{
if(ql<=l&&r<=qr) { push_add(u,d);return; }
pd(u);
int m=(l+r)>>1;
if(ql<=m)upd_add(u<<1,l,m,ql,qr,d);
if(qr>m)upd_add(u<<1|1,m+1,r,ql,qr,d);
up(u);
}
int qry(int u,int l,int r,int ql,int qr,int k)
{
if(ql<=l&&r<=qr)return tr[u].s[k];
pd(u);
int m=(l+r)>>1,res=0;
if(ql<=m)res=(res+qry(u<<1,l,m,ql,qr,k))%P;
if(qr>m)res=(res+qry(u<<1|1,m+1,r,ql,qr,k))%P;
return res;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
init();
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(m--)
{
int op,l,r,x;
cin>>op>>l>>r>>x;
if(op==1)upd_assign(1,1,n,l,r,x);
else if(op==2)upd_mul(1,1,n,l,r,x);
else if(op==3)upd_add(1,1,n,l,r,x);
else cout<<qry(1,1,n,l,r,x)<<'\n';
}
return 0;
}
NTT 优化
观察区间加的公式:
\S_t' = \\sum_{j=0}\^{t} \\binom{t}{j} d\^{t-j} S_j \\
将其改写为指数生成函数形式:
\\\frac{S_t'}{t!} = \\sum_{j=0}\^{t} \\frac{S_j}{j!} \\cdot \\frac{d\^{t-j}}{(t-j)!} \\
令 \(A_j=\frac{S_j}{j!}\),\(B_i=\frac{d^i}{i!}\),则上式即为卷积形式:
\S_t' = t! \\cdot (A \* B)_t \\
因此区间加法等价于计算序列 \(A\) 和 \(B\) 的卷积。使用 NTT 可在 \(O(K \log K)\) 时间内完成卷积计算。
模数 \(P = 998244353\),原根为 \(G = 3\)。
时间复杂度
时间复杂度为 \(O(n\log n+qk\log k \log n)\)。
NTT 优化代码
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e5+5,MAXK=100,P=998244353,G=3;
int n,m,a[N],C[MAXK+1][MAXK+1];
int rev[MAXK*4];
int fact[MAXK+1],invf[MAXK+1];
int qpow(int a,int b)
{
int res=1;
while(b)
{
if(b&1)res=1LL*res*a%P;
a=1LL*a*a%P;
b>>=1;
}
return res;
}
void init_ntt(int n)
{
for(int i=0;i<n;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)?(n>>1):0);
}
void ntt(int *a,int n,int op)
{
for(int i=0;i<n;i++)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(int len=2;len<=n;len<<=1)
{
int wn=qpow(G,(P-1)/len);
if(op==-1)wn=qpow(wn,P-2);
for(int i=0;i<n;i+=len)
{
int w=1;
for(int j=i;j<i+len/2;j++)
{
int u=a[j],v=1LL*w*a[j+len/2]%P;
a[j]=(u+v)%P;
a[j+len/2]=(u-v+P)%P;
w=1LL*w*wn%P;
}
}
}
if(op==-1)
{
int inv=qpow(n,P-2);
for(int i=0;i<n;i++)a[i]=1LL*a[i]*inv%P;
}
}
struct Node
{
int lz1,lz2,lz3,len,s[MAXK+1];
}tr[N<<2];
void init()
{
for(int i=0;i<=MAXK;i++)
{
C[i][0]=C[i][i]=1;
for(int j=1;j<i;j++)C[i][j]=(C[i-1][j-1]+C[i-1][j])%P;
}
fact[0]=1;
for(int i=1;i<=MAXK;i++)fact[i]=1LL*fact[i-1]*i%P;
invf[MAXK]=qpow(fact[MAXK],P-2);
for(int i=MAXK-1;i>=0;i--)invf[i]=1LL*invf[i+1]*(i+1)%P;
}
void push_assign(int u,int x)
{
tr[u].lz1=x;
tr[u].lz2=1;
tr[u].lz3=0;
int pw=1;
for(int t=0;t<=MAXK;t++)
{
tr[u].s[t]=1LL*tr[u].len*pw%P;
pw=1LL*pw*x%P;
}
}
void push_mul(int u,int c)
{
tr[u].lz2=1LL*tr[u].lz2*c%P;
tr[u].lz3=1LL*tr[u].lz3*c%P;
int pw=1;
for(int t=0;t<=MAXK;t++)
{
tr[u].s[t]=1LL*tr[u].s[t]*pw%P;
pw=1LL*pw*c%P;
}
}
void push_add(int u,int d)
{
if(d==0)return;
static int A[MAXK*4],B[MAXK*4];
int k=MAXK;
int n=1;
while(n<=2*k)n<<=1;
memset(A,0,n*sizeof(int));
memset(B,0,n*sizeof(int));
for(int j=0;j<=k;j++)
A[j]=1LL*tr[u].s[j]*invf[j]%P;
int pw=1;
for(int i=0;i<=k;i++)
{
B[i]=1LL*pw*invf[i]%P;
pw=1LL*pw*d%P;
}
init_ntt(n);
ntt(A,n,1);
ntt(B,n,1);
for(int i=0;i<n;i++)A[i]=1LL*A[i]*B[i]%P;
ntt(A,n,-1);
for(int t=0;t<=k;t++)
tr[u].s[t]=1LL*fact[t]*A[t]%P;
tr[u].lz3=(tr[u].lz3+d)%P;
}
void pd(int u)
{
if(tr[u].lz1!=-1)
{
push_assign(u<<1,tr[u].lz1);
push_assign(u<<1|1,tr[u].lz1);
tr[u].lz1=-1;
}
if(tr[u].lz2!=1)
{
push_mul(u<<1,tr[u].lz2);
push_mul(u<<1|1,tr[u].lz2);
tr[u].lz2=1;
}
if(tr[u].lz3)
{
push_add(u<<1,tr[u].lz3);
push_add(u<<1|1,tr[u].lz3);
tr[u].lz3=0;
}
}
void up(int u)
{
for(int t=0;t<=MAXK;t++)
tr[u].s[t]=(tr[u<<1].s[t]+tr[u<<1|1].s[t])%P;
}
void build(int u,int l,int r)
{
tr[u].len=r-l+1;
tr[u].lz1=-1;
tr[u].lz2=1;
tr[u].lz3=0;
if(l==r)
{
int pw=1;
for(int t=0;t<=MAXK;t++)
{
tr[u].s[t]=pw;
pw=1LL*pw*a[l]%P;
}
return;
}
int m=(l+r)>>1;
build(u<<1,l,m);
build(u<<1|1,m+1,r);
up(u);
}
void upd_assign(int u,int l,int r,int ql,int qr,int x)
{
if(ql<=l&&r<=qr) { push_assign(u,x);return; }
pd(u);
int m=(l+r)>>1;
if(ql<=m)upd_assign(u<<1,l,m,ql,qr,x);
if(qr>m)upd_assign(u<<1|1,m+1,r,ql,qr,x);
up(u);
}
void upd_mul(int u,int l,int r,int ql,int qr,int c)
{
if(ql<=l&&r<=qr) { push_mul(u,c);return; }
pd(u);
int m=(l+r)>>1;
if(ql<=m)upd_mul(u<<1,l,m,ql,qr,c);
if(qr>m)upd_mul(u<<1|1,m+1,r,ql,qr,c);
up(u);
}
void upd_add(int u,int l,int r,int ql,int qr,int d)
{
if(ql<=l&&r<=qr) { push_add(u,d);return; }
pd(u);
int m=(l+r)>>1;
if(ql<=m)upd_add(u<<1,l,m,ql,qr,d);
if(qr>m)upd_add(u<<1|1,m+1,r,ql,qr,d);
up(u);
}
int qry(int u,int l,int r,int ql,int qr,int k)
{
if(ql<=l&&r<=qr)return tr[u].s[k];
pd(u);
int m=(l+r)>>1,res=0;
if(ql<=m)res=(res+qry(u<<1,l,m,ql,qr,k))%P;
if(qr>m)res=(res+qry(u<<1|1,m+1,r,ql,qr,k))%P;
return res;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
init();
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(m--)
{
int op,l,r,x;
cin>>op>>l>>r>>x;
if(op==1)upd_assign(1,1,n,l,r,x);
else if(op==2)upd_mul(1,1,n,l,r,x);
else if(op==3)upd_add(1,1,n,l,r,x);
else cout<<qry(1,1,n,l,r,x)<<'\n';
}
return 0;
}