题解:Kitamasa 算法板子

第八个黑题!

题意回顾

给定 k 阶齐次线性递推关系:

a_n=\\sum_{i=1}\^{k}f_i\\cdot a_{n-i}

于初始项 a_0, a_1, \\dots, a_{k-1},求 a_n \\bmod 998244353 的值。其中 n 最大可达 10\^9k 最大为 32000

算法概述

我第一眼以为就是一个简单的递推,但是直接算要 O(nk) 的复杂度,TLE 肯定会爬床抓你代码。还有一种做法是使用矩阵快速幂,将递推转化为矩阵乘幂形式,时间复杂度 O(k\^3 \\log n),But k 高达 32000,仍然鬼爬床。

需要用更高效的方法------多项式优化,利用递推多项式,通过多项式快速幂和取模在 O(k \\log k \\log n) 时间内求出答案。这种方法被称为 Kitamasa 算法,核心是计算 x\^n 关于特征多项式 P(x) 的余式,然后与初始项点乘。

理论讲解

递推的特征多项式

对于递推 a_n = \\sum_{i=1}\^k f_i a_{n-i},其特征多项式为:

P(x)=x\^k-f_1x\^{k-1}-f_2x\^{k-2}-\\cdots -f_k

转移矩阵与多项式的关系

定义状态向量 \\mathbf{v}n = (a{n-k+1}, a_{n-k+2}, \\dots, a_n)\^T,则 \\mathbf{v}n = M \\cdot \\mathbf{v}{n-1},其中 Mk \\times k 的伴随矩阵:

$$\begin{pmatrix}

0& 1& 0& \cdots& 0\\

0& 0& 1& \cdots& 0\\

\vdots & \vdots & \vdots& \ddots & \vdots\\

0& 0& 0& \cdots& 1\\

f_k& f_{k-1}& f_{k-2}& \cdots&f_{k-3}

\end{pmatrix}$$

矩阵 M 的特征多项式恰好是 P(x)。根据 Cayley-Hamilton 定理,M 满足 P(M)=0,即 M\^k = f_1 M\^{k-1} + f_2 M\^{k-2} + \\cdots + f_k I。何意味?这意味着任何 M 的幂都可以表示为 I, M, M\^2, \\dots, M\^{k-1} 的线性组合,系数由 P(x) 决定。

将问题转化为多项式幂

由于 a_n\\mathbf{v}_n 的最后一个分量,而 \\mathbf{v}_n = M\^n \\mathbf{v}_0,因此只需要求出 M\^n 的最后一行的第一个元素。但直接计算 M\^n 仍然是矩阵乘法,TLE 还是永生。

M\^n 可以表示为关于 M 的多项式 R(M),其中 R(x) = x\^n \\bmod P(x)。设最后一行为 E

a_n=E\\cdot\\mathbf{v}_n=\\sum_{k-1}\^{i=0}c_ia_i

其中 c_iR(x) = \\sum_{i=0}\^{k-1} c_i x\^i 的系数。所以问题就是这么个东东:**计算多项式 x\^nP(x) 的余式**,然后与初始项点乘就 AKIOI。

复杂度分析

多项式乘法和求逆:O(k \\log k)

  • 快速幂:需要 O(\\log n) 次乘法和取模,每次乘法 O(k \\log k),取模 O(k \\log k)

  • 总时间复杂度:O(k \\log k \\log n)

  • 空间复杂度:O(k)

  • 对于 k=32000, \\log n \\approx 30,可以 AKIOI 了。

AC code

::::successAC code

```cpp

#include <bits/stdc++.h>

using namespace std;

const int MOD = 998244353;

const int g = 3;

int qpow(int a, int b) {

int res = 1;

while (b) {

if (b & 1) res = 1LL * res * a % MOD;

a = 1LL * a * a % MOD;

b >>= 1;

}

return res;

}

void ntt(vector<int>& a, bool invert) {

int n = a.size();

for (int i = 1, j = 0; i < n; i++) {

int bit = n >> 1;

for (; j & bit; bit >>= 1) j ^= bit;

j ^= bit;

if (i < j) swap(ai, aj);

}

for (int len = 2; len <= n; len <<= 1) {

int wlen = qpow(g, (MOD - 1) / len);

if (invert) wlen = qpow(wlen, MOD - 2);

for (int i = 0; i < n; i += len) {

int w = 1;

for (int j = 0; j < len / 2; j++) {

int u = ai + j;

int v = 1LL * ai + j + len / 2 * w % MOD;

ai + j = (u + v) % MOD;

ai + j + len / 2 = (u - v + MOD) % MOD;

w = 1LL * w * wlen % MOD;

}

}

}

if (invert) {

int inn = qpow(n, MOD - 2);

for (int& x : a) x = 1LL * x * inn % MOD;

}

}

vector<int> multiply(const vector<int>& a, const vector<int>& b) {

int sz = a.size() + b.size() - 1;

int n = 1;

while (n < sz) n <<= 1;

vector<int> fa = a, fb = b;

fa.resize(n);

fb.resize(n);

ntt(fa, false);

ntt(fb, false);

for (int i = 0; i < n; i++) fai = 1LL * fai * fbi % MOD;

ntt(fa, true);

fa.resize(sz);

return fa;

}

vector<int> pyi(const vector<int>& a, int m) {

assert(a0 != 0);

vector<int> inv = { qpow(a0, MOD - 2) };

int cur = 1;

while (cur < m) {

int nxt = min(cur * 2, m);

vector<int> apx(a.begin(), a.begin() + min((int)a.size(), nxt));

vector<int> prod = multiply(apx, inv);

prod.resize(nxt);

for (int i = 0; i < nxt; i++) {

if (i == 0) prodi = (2 - prodi + MOD) % MOD;

else prodi = (-prodi + MOD) % MOD;

}

inv = multiply(inv, prod);

inv.resize(nxt);

cur = nxt;

}

return inv;

}

vector<int> pmd(const vector<int>& c, const vector<int>& p, const vector<int>& ire, int k) {

if (c.size() <= k) return c;

int tyu = c.size() - 1;

int efb = tyu - k + 1;

vector<int> iot = c;

reverse(iot.begin(), iot.end());

vector<int> crn(iot.begin(), iot.begin() + efb);

vector<int> sac(ire.begin(), ire.begin() + efb);

vector<int> jhn = multiply(crn, sac);

if (jhn.size() > efb) jhn.resize(efb);

vector<int> q = jhn;

reverse(q.begin(), q.end());

vector<int> qp = multiply(q, p);

if (qp.size() < c.size()) qp.resize(c.size(), 0);

vector<int> r(k, 0);

for (int i = 0; i < k; i++) {

int val = ci - qpi;

val %= MOD;

if (val < 0) val += MOD;

ri = val;

}

return r;

}

int main() {

ios::sync_with_stdio(false);

cin.tie(nullptr);

long long n;

int k;

cin >> n >> k;

vector<int> f(k), a(k);

for (int i = 0; i < k; i++) {

cin >> fi;

fi %= MOD;

if (fi < 0) fi += MOD;

}

for (int i = 0; i < k; i++) {

cin >> ai;

ai %= MOD;

if (ai < 0) ai += MOD;

}

vector<int> p(k + 1, 0);

pk = 1;

for (int i = 1; i <= k; i++) {

pk - i = (MOD - fi - 1) % MOD;

}

vector<int> tnz(k, 0);

tnz0 = 1;

for (int i = 1; i < k; i++) {

tnzi = (MOD - fi - 1) % MOD;

}

vector<int> ire = pyi(tnz, k);

vector<int> res = {1};

vector<int> base = {0, 1};

long long exp = n;

while (exp) {

if (exp & 1) {

res = pmd(multiply(res, base), p, ire, k);

}

base = pmd(multiply(base, base), p, ire, k);

exp >>= 1;

}

int ans = 0;

for (int i = 0; i < k; i++) {

ans = (ans + 1LL * resi * ai) % MOD;

}

cout << ans << endl;

return 0;

}

```

::::

相关推荐
智者知已应修善业35 分钟前
【51单片机8位数码管同时倒计时从9999】2024-1-25
c++·经验分享·笔记·算法·51单片机
洛水水38 分钟前
【力扣100题】86.柱状图中最大的矩形
算法·leetcode·职场和发展
渡之1 小时前
GRiM-Net 深度解析 | 无人机 GNSS 拒止场景下两阶段跨视角视觉定位框架
深度学习·算法·动态规划·无人机
测试仪器廖生135902563851 小时前
罗德与施瓦茨 FSP13频谱分析仪FSP30
网络·人工智能·算法
happymaker06261 小时前
LeetCodeHot100——560.和为K的子数组
算法
dtq04241 小时前
C语言刷题数组5,6(求平均值,求最大值)
c语言·数据结构·算法
郭梧悠2 小时前
Hash算法入门Hash冲突解决方案
算法·哈希算法
洛水水2 小时前
【力扣100题】81.寻找两个正序数组的中位数
数据结构·算法·leetcode
happymaker06263 小时前
LeetCodeHot100——155.最小栈
算法
洛水水3 小时前
【力扣100题】85.每日温度
算法·leetcode·职场和发展