第八个黑题!
题意回顾
给定 k 阶齐次线性递推关系:
a_n=\\sum_{i=1}\^{k}f_i\\cdot a_{n-i}
于初始项 a_0, a_1, \\dots, a_{k-1},求 a_n \\bmod 998244353 的值。其中 n 最大可达 10\^9,k 最大为 32000。
算法概述
我第一眼以为就是一个简单的递推,但是直接算要 O(nk) 的复杂度,TLE 肯定会爬床抓你代码。还有一种做法是使用矩阵快速幂,将递推转化为矩阵乘幂形式,时间复杂度 O(k\^3 \\log n),But k 高达 32000,仍然鬼爬床。
需要用更高效的方法------多项式优化,利用递推多项式,通过多项式快速幂和取模在 O(k \\log k \\log n) 时间内求出答案。这种方法被称为 Kitamasa 算法,核心是计算 x\^n 关于特征多项式 P(x) 的余式,然后与初始项点乘。
理论讲解
递推的特征多项式
对于递推 a_n = \\sum_{i=1}\^k f_i a_{n-i},其特征多项式为:
P(x)=x\^k-f_1x\^{k-1}-f_2x\^{k-2}-\\cdots -f_k
转移矩阵与多项式的关系
定义状态向量 \\mathbf{v}n = (a{n-k+1}, a_{n-k+2}, \\dots, a_n)\^T,则 \\mathbf{v}n = M \\cdot \\mathbf{v}{n-1},其中 M 是 k \\times k 的伴随矩阵:
\\begin{pmatrix} 0\& 1\& 0\& \\cdots\& 0\\\\ 0\& 0\& 1\& \\cdots\& 0\\\\ \\vdots \& \\vdots \& \\vdots\& \\ddots \& \\vdots\\\\ 0\& 0\& 0\& \\cdots\& 1\\\\ f_k\& f_{k-1}\& f_{k-2}\& \\cdots\&f_{k-3} \\end{pmatrix}
矩阵 M 的特征多项式恰好是 P(x)。根据 Cayley-Hamilton 定理,M 满足 P(M)=0,即 M\^k = f_1 M\^{k-1} + f_2 M\^{k-2} + \\cdots + f_k I。何意味?这意味着任何 M 的幂都可以表示为 I, M, M\^2, \\dots, M\^{k-1} 的线性组合,系数由 P(x) 决定。
将问题转化为多项式幂
由于 a_n 是 \\mathbf{v}_n 的最后一个分量,而 \\mathbf{v}_n = M\^n \\mathbf{v}_0,因此只需要求出 M\^n 的最后一行的第一个元素。但直接计算 M\^n 仍然是矩阵乘法,TLE 还是永生。
M\^n 可以表示为关于 M 的多项式 R(M),其中 R(x) = x\^n \\bmod P(x)。设最后一行为 E:
a_n=E\\cdot\\mathbf{v}_n=\\sum_{k-1}\^{i=0}c_ia_i
其中 c_i 是 R(x) = \\sum_{i=0}\^{k-1} c_i x\^i 的系数。所以问题就是这么个东东:**计算多项式 x\^n 模 P(x) 的余式**,然后与初始项点乘就 AKIOI。
复杂度分析
多项式乘法和求逆:O(k \\log k)
-
快速幂:需要 O(\\log n) 次乘法和取模,每次乘法 O(k \\log k),取模 O(k \\log k)。
-
总时间复杂度:O(k \\log k \\log n)。
-
空间复杂度:O(k)。
-
对于 k=32000, \\log n \\approx 30,可以 AKIOI 了。
AC code
::::success[AC code]
```cpp
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
const int g = 3;
int qpow(int a, int b) {
int res = 1;
while (b) {
if (b & 1) res = 1LL * res * a % MOD;
a = 1LL * a * a % MOD;
b >>= 1;
}
return res;
}
void ntt(vector<int>& a, bool invert) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1) j ^= bit;
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
for (int len = 2; len <= n; len <<= 1) {
int wlen = qpow(g, (MOD - 1) / len);
if (invert) wlen = qpow(wlen, MOD - 2);
for (int i = 0; i < n; i += len) {
int w = 1;
for (int j = 0; j < len / 2; j++) {
int u = a[i + j];
int v = 1LL * a[i + j + len / 2] * w % MOD;
a[i + j] = (u + v) % MOD;
a[i + j + len / 2] = (u - v + MOD) % MOD;
w = 1LL * w * wlen % MOD;
}
}
}
if (invert) {
int inn = qpow(n, MOD - 2);
for (int& x : a) x = 1LL * x * inn % MOD;
}
}
vector<int> multiply(const vector<int>& a, const vector<int>& b) {
int sz = a.size() + b.size() - 1;
int n = 1;
while (n < sz) n <<= 1;
vector<int> fa = a, fb = b;
fa.resize(n);
fb.resize(n);
ntt(fa, false);
ntt(fb, false);
for (int i = 0; i < n; i++) fa[i] = 1LL * fa[i] * fb[i] % MOD;
ntt(fa, true);
fa.resize(sz);
return fa;
}
vector<int> pyi(const vector<int>& a, int m) {
assert(a[0] != 0);
vector<int> inv = { qpow(a[0], MOD - 2) };
int cur = 1;
while (cur < m) {
int nxt = min(cur * 2, m);
vector<int> apx(a.begin(), a.begin() + min((int)a.size(), nxt));
vector<int> prod = multiply(apx, inv);
prod.resize(nxt);
for (int i = 0; i < nxt; i++) {
if (i == 0) prod[i] = (2 - prod[i] + MOD) % MOD;
else prod[i] = (-prod[i] + MOD) % MOD;
}
inv = multiply(inv, prod);
inv.resize(nxt);
cur = nxt;
}
return inv;
}
vector<int> pmd(const vector<int>& c, const vector<int>& p, const vector<int>& ire, int k) {
if (c.size() <= k) return c;
int tyu = c.size() - 1;
int efb = tyu - k + 1;
vector<int> iot = c;
reverse(iot.begin(), iot.end());
vector<int> crn(iot.begin(), iot.begin() + efb);
vector<int> sac(ire.begin(), ire.begin() + efb);
vector<int> jhn = multiply(crn, sac);
if (jhn.size() > efb) jhn.resize(efb);
vector<int> q = jhn;
reverse(q.begin(), q.end());
vector<int> qp = multiply(q, p);
if (qp.size() < c.size()) qp.resize(c.size(), 0);
vector<int> r(k, 0);
for (int i = 0; i < k; i++) {
int val = c[i] - qp[i];
val %= MOD;
if (val < 0) val += MOD;
r[i] = val;
}
return r;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
long long n;
int k;
cin >> n >> k;
vector<int> f(k), a(k);
for (int i = 0; i < k; i++) {
cin >> f[i];
f[i] %= MOD;
if (f[i] < 0) f[i] += MOD;
}
for (int i = 0; i < k; i++) {
cin >> a[i];
a[i] %= MOD;
if (a[i] < 0) a[i] += MOD;
}
vector<int> p(k + 1, 0);
p[k] = 1;
for (int i = 1; i <= k; i++) {
p[k - i] = (MOD - f[i - 1]) % MOD;
}
vector<int> tnz(k, 0);
tnz[0] = 1;
for (int i = 1; i < k; i++) {
tnz[i] = (MOD - f[i - 1]) % MOD;
}
vector<int> ire = pyi(tnz, k);
vector<int> res = {1};
vector<int> base = {0, 1};
long long exp = n;
while (exp) {
if (exp & 1) {
res = pmd(multiply(res, base), p, ire, k);
}
base = pmd(multiply(base, base), p, ire, k);
exp >>= 1;
}
int ans = 0;
for (int i = 0; i < k; i++) {
ans = (ans + 1LL * res[i] * a[i]) % MOD;
}
cout << ans << endl;
return 0;
}
```
::::