题解:Kitamasa 算法板子

第八个黑题!

题意回顾

给定 k 阶齐次线性递推关系:

a_n=\\sum_{i=1}\^{k}f_i\\cdot a_{n-i}

于初始项 a_0, a_1, \\dots, a_{k-1},求 a_n \\bmod 998244353 的值。其中 n 最大可达 10\^9k 最大为 32000

算法概述

我第一眼以为就是一个简单的递推,但是直接算要 O(nk) 的复杂度,TLE 肯定会爬床抓你代码。还有一种做法是使用矩阵快速幂,将递推转化为矩阵乘幂形式,时间复杂度 O(k\^3 \\log n),But k 高达 32000,仍然鬼爬床。

需要用更高效的方法------多项式优化,利用递推多项式,通过多项式快速幂和取模在 O(k \\log k \\log n) 时间内求出答案。这种方法被称为 Kitamasa 算法,核心是计算 x\^n 关于特征多项式 P(x) 的余式,然后与初始项点乘。

理论讲解

递推的特征多项式

对于递推 a_n = \\sum_{i=1}\^k f_i a_{n-i},其特征多项式为:

P(x)=x\^k-f_1x\^{k-1}-f_2x\^{k-2}-\\cdots -f_k

转移矩阵与多项式的关系

定义状态向量 \\mathbf{v}n = (a{n-k+1}, a_{n-k+2}, \\dots, a_n)\^T,则 \\mathbf{v}n = M \\cdot \\mathbf{v}{n-1},其中 Mk \\times k 的伴随矩阵:

\\begin{pmatrix} 0\& 1\& 0\& \\cdots\& 0\\\\ 0\& 0\& 1\& \\cdots\& 0\\\\ \\vdots \& \\vdots \& \\vdots\& \\ddots \& \\vdots\\\\ 0\& 0\& 0\& \\cdots\& 1\\\\ f_k\& f_{k-1}\& f_{k-2}\& \\cdots\&f_{k-3} \\end{pmatrix}

矩阵 M 的特征多项式恰好是 P(x)。根据 Cayley-Hamilton 定理,M 满足 P(M)=0,即 M\^k = f_1 M\^{k-1} + f_2 M\^{k-2} + \\cdots + f_k I。何意味?这意味着任何 M 的幂都可以表示为 I, M, M\^2, \\dots, M\^{k-1} 的线性组合,系数由 P(x) 决定。

将问题转化为多项式幂

由于 a_n\\mathbf{v}_n 的最后一个分量,而 \\mathbf{v}_n = M\^n \\mathbf{v}_0,因此只需要求出 M\^n 的最后一行的第一个元素。但直接计算 M\^n 仍然是矩阵乘法,TLE 还是永生。

M\^n 可以表示为关于 M 的多项式 R(M),其中 R(x) = x\^n \\bmod P(x)。设最后一行为 E

a_n=E\\cdot\\mathbf{v}_n=\\sum_{k-1}\^{i=0}c_ia_i

其中 c_iR(x) = \\sum_{i=0}\^{k-1} c_i x\^i 的系数。所以问题就是这么个东东:**计算多项式 x\^nP(x) 的余式**,然后与初始项点乘就 AKIOI。

复杂度分析

多项式乘法和求逆:O(k \\log k)

  • 快速幂:需要 O(\\log n) 次乘法和取模,每次乘法 O(k \\log k),取模 O(k \\log k)

  • 总时间复杂度:O(k \\log k \\log n)

  • 空间复杂度:O(k)

  • 对于 k=32000, \\log n \\approx 30,可以 AKIOI 了。

AC code

::::success[AC code]

```cpp

#include <bits/stdc++.h>

using namespace std;

const int MOD = 998244353;

const int g = 3;

int qpow(int a, int b) {

int res = 1;

while (b) {

if (b & 1) res = 1LL * res * a % MOD;

a = 1LL * a * a % MOD;

b >>= 1;

}

return res;

}

void ntt(vector<int>& a, bool invert) {

int n = a.size();

for (int i = 1, j = 0; i < n; i++) {

int bit = n >> 1;

for (; j & bit; bit >>= 1) j ^= bit;

j ^= bit;

if (i < j) swap(a[i], a[j]);

}

for (int len = 2; len <= n; len <<= 1) {

int wlen = qpow(g, (MOD - 1) / len);

if (invert) wlen = qpow(wlen, MOD - 2);

for (int i = 0; i < n; i += len) {

int w = 1;

for (int j = 0; j < len / 2; j++) {

int u = a[i + j];

int v = 1LL * a[i + j + len / 2] * w % MOD;

a[i + j] = (u + v) % MOD;

a[i + j + len / 2] = (u - v + MOD) % MOD;

w = 1LL * w * wlen % MOD;

}

}

}

if (invert) {

int inn = qpow(n, MOD - 2);

for (int& x : a) x = 1LL * x * inn % MOD;

}

}

vector<int> multiply(const vector<int>& a, const vector<int>& b) {

int sz = a.size() + b.size() - 1;

int n = 1;

while (n < sz) n <<= 1;

vector<int> fa = a, fb = b;

fa.resize(n);

fb.resize(n);

ntt(fa, false);

ntt(fb, false);

for (int i = 0; i < n; i++) fa[i] = 1LL * fa[i] * fb[i] % MOD;

ntt(fa, true);

fa.resize(sz);

return fa;

}

vector<int> pyi(const vector<int>& a, int m) {

assert(a[0] != 0);

vector<int> inv = { qpow(a[0], MOD - 2) };

int cur = 1;

while (cur < m) {

int nxt = min(cur * 2, m);

vector<int> apx(a.begin(), a.begin() + min((int)a.size(), nxt));

vector<int> prod = multiply(apx, inv);

prod.resize(nxt);

for (int i = 0; i < nxt; i++) {

if (i == 0) prod[i] = (2 - prod[i] + MOD) % MOD;

else prod[i] = (-prod[i] + MOD) % MOD;

}

inv = multiply(inv, prod);

inv.resize(nxt);

cur = nxt;

}

return inv;

}

vector<int> pmd(const vector<int>& c, const vector<int>& p, const vector<int>& ire, int k) {

if (c.size() <= k) return c;

int tyu = c.size() - 1;

int efb = tyu - k + 1;

vector<int> iot = c;

reverse(iot.begin(), iot.end());

vector<int> crn(iot.begin(), iot.begin() + efb);

vector<int> sac(ire.begin(), ire.begin() + efb);

vector<int> jhn = multiply(crn, sac);

if (jhn.size() > efb) jhn.resize(efb);

vector<int> q = jhn;

reverse(q.begin(), q.end());

vector<int> qp = multiply(q, p);

if (qp.size() < c.size()) qp.resize(c.size(), 0);

vector<int> r(k, 0);

for (int i = 0; i < k; i++) {

int val = c[i] - qp[i];

val %= MOD;

if (val < 0) val += MOD;

r[i] = val;

}

return r;

}

int main() {

ios::sync_with_stdio(false);

cin.tie(nullptr);

long long n;

int k;

cin >> n >> k;

vector<int> f(k), a(k);

for (int i = 0; i < k; i++) {

cin >> f[i];

f[i] %= MOD;

if (f[i] < 0) f[i] += MOD;

}

for (int i = 0; i < k; i++) {

cin >> a[i];

a[i] %= MOD;

if (a[i] < 0) a[i] += MOD;

}

vector<int> p(k + 1, 0);

p[k] = 1;

for (int i = 1; i <= k; i++) {

p[k - i] = (MOD - f[i - 1]) % MOD;

}

vector<int> tnz(k, 0);

tnz[0] = 1;

for (int i = 1; i < k; i++) {

tnz[i] = (MOD - f[i - 1]) % MOD;

}

vector<int> ire = pyi(tnz, k);

vector<int> res = {1};

vector<int> base = {0, 1};

long long exp = n;

while (exp) {

if (exp & 1) {

res = pmd(multiply(res, base), p, ire, k);

}

base = pmd(multiply(base, base), p, ire, k);

exp >>= 1;

}

int ans = 0;

for (int i = 0; i < k; i++) {

ans = (ans + 1LL * res[i] * a[i]) % MOD;

}

cout << ans << endl;

return 0;

}

```

::::

相关推荐
筱昕~呀1 小时前
冲刺蓝桥杯-DFS板块(第一天)
算法·蓝桥杯·深度优先
We་ct2 小时前
LeetCode 637. 二叉树的层平均值:BFS层序遍历实战解析
前端·数据结构·算法·leetcode·typescript·宽度优先
I_LPL2 小时前
day36 代码随想录算法训练营 动态规划专题4
java·算法·leetcode·动态规划·hot100
ab1515172 小时前
2.24完成129、134、135
数据结构·算法
2301_816997882 小时前
虚拟DOM与Diff算法
前端·vue.js·算法
闻缺陷则喜何志丹2 小时前
P8153 「PMOI-5」送分题/Yet Another Easy Strings Merging|普及+
c++·数学·算法·洛谷
tankeven2 小时前
HJ102 字符统计
c++·算法
升讯威在线客服系统2 小时前
从 GC 抖动到稳定低延迟:在升讯威客服系统中实践 Span 与 Memory 的高性能优化
java·javascript·python·算法·性能优化·php·swift
We་ct2 小时前
LeetCode 199. 二叉树的右视图:层序遍历解题详解
前端·算法·leetcode·typescript·广度优先