用 FFT 和 NTT 解决多项式乘法

多项式乘法是一个很常见重要的问题。给两个多项式

\[A(x)=\sum_{i=0}^{n-1}a_i x^i,\qquad B(x)=\sum_{j=0}^{m-1}b_j x^j \]

那么它们的乘积 \(C(x)=A(x)B(x)\) 的第 \(k\) 项系数是

\[c_k=\sum_{i+j=k}a_i b_j \]

这就是序列 \(a\) 和 \(b\) 的卷积。直接按定义算,需要枚举所有 \(i,j\),复杂度是 \(O(nm)\)。假定 \(n,m\) 规模相当,这个复杂度就是平方的,很快就不够用了。

大整数乘法也是同一个问题。把一个整数按进制 \(\beta\) 拆成若干位:

\[X=\sum_i a_i \beta^i,\qquad Y=\sum_j b_j \beta^j \]

那么 \(XY\) 在处理进位之前,第 \(k\) 位的原始值就是 \(\sum_{i+j=k}a_i b_j\)。也就是说,多项式乘法、卷积、大整数乘法在核心计算上是同一个结构,只是最后解释结果的方式不同:多项式保留系数,卷积保留序列,大整数还要做一遍进位。

系数表示和点值表示

一个次数小于 \(n\) 的多项式,可以用它的 \(n\) 个系数表示:

\[A(x)=a_0+a_1x+\cdots+a_{n-1}x^{n-1} \]

这叫系数表示。它很适合做加法,因为对应系数相加即可;但相乘两个多项式会变成卷积,比较麻烦。

另一种表示方式是点值表示。选 \(n\) 个不同的点 \(x_0,x_1,\ldots,x_{n-1}\),记录

\[(x_i,A(x_i)) \]

这些点值也可以唯一确定多项式。用点值表示做乘法非常轻松:如果 \(C(x)=A(x)B(x)\),那么

\[C(x_i)=A(x_i)B(x_i) \]

所以只要逐点相乘就行。因此要避开平方的卷积,就要把系数表示转化成点值表示,进行点乘运算,然后把点值表示还原成系数表示。那么如何在系数表示和点值表示之间快速转换?

FFT 快速傅里叶变换

朴素求值要对每个点代入一次多项式,求一次是 \(O(n)\),一共 \(n\) 个点就是 \(O(n^2)\)。而 FFT 的关键是选择一组非常特殊的点,使求值在 \(O(n\log n)\) 内完成。

FFT 使用复数域里的单位根。长度为 \(n\) 时,取

\[\omega=e^{2\pi i/n} \]

取用它是因为他有一个性质: \(\omega^n=1\),并且 \(\omega^0,\omega^1,\ldots,\omega^{n-1}\) 两两不同。对一个系数序列 \(a_0,\ldots,a_{n-1}\),它的离散傅里叶变换可以写成

\[\hat a_k=\sum_{j=0}^{n-1}a_j\omega^{jk} \]

从信号与系统里傅里叶变换的意义来说,把 \(a_0,a_1...a_{n-1}\) 转化成 \(\hat a_0,\hat a_1...\hat a_{n-1}\) 是从时域到频域的转换,而从多项式的角度来说,这其实就是把多项式 \(A(x)=a_0+a_1x+\cdots+a_{n-1}x^{n-1}\) 分别求解 \(A(1),A(\omega),...A(\omega^{n-1})\)。相比随便选 \(n\) 个点求解来说,求解这些特殊点的值,可以利用其数学性质加速。

蝶形变换

单位根有很强的对称性。假设 \(n\) 是 2 的整次幂,把多项式按偶数次和奇数次拆开:

\[A(x)=A_0(x^2)+xA_1(x^2) \]

其中

\[A_0(x)=a_0+a_2x+a_4x^2+\cdots \]

\[A_1(x)=a_1+a_3x+a_5x^2+\cdots \]

因此

\[A(\omega^k)=A_0(\omega^{2k})+\omega^k A_1(\omega^{2k}) \]

\[A(\omega^{k+n/2})=A_0(\omega^{2k})-\omega^k A_1(\omega^{2k}) \]

而 \(\omega^{2}\) 也是一个单位根,所以可以递归求解。

设 \( u=A_0(\omega^{2k}), v=A_1(\omega^{2k}) \),那么原多项式在两个对应点上的值是

\[A(\omega^k)=u+\omega^k v \]

\[A(\omega^{k+n/2})=u-\omega^k v \]

这就是蝶形合并。它说明,要算 \(A(1),A(\omega),A(\omega^2),\dots,A(\omega^{n-1})\),可以先分别算偶数部分和奇数部分在 \(1,\omega^2,\omega^4,\dots,\omega^{n-2}\) 上的值,再用加减和乘以 \(\omega^k\) 合并。一次拆分把长度为 \(n\) 的问题变成两个长度为 \(n/2\) 的问题,每层合并花 \(O(n)\),一共有 \(\log n\) 层,所以整体是 \(O(n\log n)\)。在迭代版 FFT 中,通常先做 bit-reversal 重排,然后从长度 \(2\) 的小块开始合并,块长依次翻倍。每一层会遍历所有块,每个块里做若干个蝶形变换。

逆变换

正变换把系数变成点值,逆变换把点值插值回系数。DFT 的逆变换形式是

\[a_j=\frac{1}{n}\sum_{k=0}^{n-1}\hat a_k\omega^{-jk} \]

也就是说,逆变换和正变换几乎一样,只是把 \(\omega\) 换成 \(\omega^{-1}\),最后所有结果再除以 \(n\)。这个形式在 NTT 里也会原样保留,只是"除以 \(n\)"会变成乘上 \(n\) 在模意义下的逆元。

用 FFT 做卷积的流程很短:先把两个序列补零到长度 \(N\),其中 \(N\) 至少是 \(n+m-1\),通常取不小于它的最小二次幂。然后分别 FFT,逐点相乘,再逆 FFT。最终前 \(n+m-1\) 项就是卷积结果。

从 FFT 到 NTT 数论变换

FFT 用复数,速度快,但有浮点误差。对于整数卷积,尤其是竞赛、密码学或需要完全精确的场景,更常用 NTT。NTT 可以看作"把 FFT 搬到模质数意义下"。

现在,在模 \(p\) 的有限域里,如果存在一个元素 \(g\),它的幂能生成所有非零元素,那么 \(g\) 是模 \(p\) 的原根。换句话说,\(1, g,g^2...\ g^{p-2}\),刚好遍历了 \([1, p)\),而 \(g^{p-1}\) 再次回到 \(1\)。比如 \(3\) 是模 \(7\) 的一个原根(\(1→3→2→6→4→5→1\)),但 \(2\) 不是。

由于模质数 \(p\) 下的非零元素有 \(p-1\) 个,我们想做长度为 \(n\) 的 NTT,就需要 \(n\) 整除 \(p-1\)。

换一个稍大的例子。取 \(p=17\),可以验证 \(g=3\) 是一个原根。现在想做长度为 \(8\) 的 NTT,可以构造

\[\omega=3^{(17-1)/8}\equiv 9\pmod {17} \]

验证一下 \( 9^8\equiv 1\pmod {17} \),并且在 \(1,2,\ldots,7\) 次幂时都不会提前变成 \(1\)。

这时我们就发现,\(1,9,9^2,\ldots,9^7\) 就可以扮演 FFT 里 \(1,\omega,\omega^2,\ldots,\omega^7\) 的角色。同理,\(\omega=3\) 是一个长度 16 的 NTT 单位根,\(\omega=13\) 是一个长度 4 的 NTT 单位根。

更普遍的来说,若 \(g\) 是模 \(p\) 的原根,则

\[\omega=g^{(p-1)/n}\bmod p \]

就是一个 \(n\) 阶单位根。这样,FFT 里关于单位根的推导仍然成立,只是所有加法、减法、乘法都放在模 \(p\) 下。复数单位根 \(\omega=e^{2\pi i/n}\) 变成模意义下的单位根 \(\omega=g^{(p-1)/n}\);

在逆变换里,原理也完全相同,除以 \(n\) 变成乘 \(n^{-1}\bmod p\)(乘法逆元);

常用模数 998244353

\(\text{MOD}=998244353\) 很常用,首先他是质数,并且

\[998244353=119\cdot 2^{23}+1 \]

它足够大,最多可以生成长度 \(2^{23}\) 的 NTT 单位根。对多数算法题来说,这个长度已经足够大。同时这个数的大小刚好满足 \(2\text{MOD}\) 不超过 int32 的上限, \(\text{MOD}^2\) 没有超过 int64 的上限,利于实际实现。\(3\) 是他的一个原根。

通用情况

那么更大的情况呢?和 FFT 的区别是,NTT 下,卷积系数的运算也在模 \(p\) 意义下进行,若系数最终没有超过 \(p\) 就好说,但如果超过,就会在取模意义下绕回来,我们这时无法得知算出的 \(c_i\) 实际上是 \(c_i\) 还是 \(c_i + p\)。

假设输入系数非负,长度为 \(L\),每个系数不超过 \(M\),那么卷积中单个系数的上界大约是 \(LM^2\)。如果这个值可能超过模数,就要么选择更大的可用模数,要么使用多个 NTT 质数分别计算一次,再通过中国剩余定理合并。

具体来说,选几个互质模数 \(p_1,p_2,\ldots,p_r\),分别算出

\[c_k\bmod p_1,\quad c_k\bmod p_2,\quad \ldots,\quad c_k\bmod p_r \]

只要真实的 \(c_k\) 小于 \(P=p_1p_2\cdots p_r\),那么用 CRT 就可以唯一恢复 \(c_k\)。实际工程里常见做法是使用多个 NTT-friendly prime,例如 \(998244353\)、\(1004535809\)、\(469762049\) 。

对于大整数乘法,还可以通过控制进制来降低单个卷积系数的上界。例如把十进制字符串拆成 \(10^3\) 或 \(10^4\) 进制,而不是直接用 \(10^9\) 进制。这个情况下,进制越大,卷积长度越短,但中间系数越容易溢出模数;进制越小,长度更长,但结果更安全。这是一个很实际的 trade-off。

cpp 复制代码
#include <bits/stdc++.h>
using namespace std;

const int MOD = 998244353;
const int G = 3;

long long mod_pow(long long a, long long e) {
    long long r = 1;
    while (e) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return r;
}

void ntt(vector<int>& a, bool invert) {
    int n = (int)a.size();

    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        int wlen = (int)mod_pow(G, (MOD - 1) / len);
        if (invert) wlen = (int)mod_pow(wlen, MOD - 2);

        for (int i = 0; i < n; i += len) {
            long long w = 1;
            int half = len >> 1;

            for (int j = 0; j < half; j++) {
                int u = a[i + j];
                int v = (int)(a[i + j + half] * w % MOD);

                int x = u + v;
                if (x >= MOD) x -= MOD;

                int y = u - v;
                if (y < 0) y += MOD;

                a[i + j] = x;
                a[i + j + half] = y;
                w = w * wlen % MOD;
            }
        }
    }

    if (invert) {
        int inv_n = (int)mod_pow(n, MOD - 2);
        for (int& x : a) {
            x = (int)(1LL * x * inv_n % MOD);
        }
    }
}

vector<int> multiply(vector<int> a, vector<int> b) {
    if (a.empty() || b.empty()) return {};

    int need = (int)a.size() + (int)b.size() - 1;
    int n = 1;
    while (n < need) n <<= 1;

    if ((MOD - 1) % n != 0) {
        throw runtime_error("NTT length is not supported by this modulus");
    }

    a.resize(n);
    b.resize(n);

    ntt(a, false);
    ntt(b, false);

    for (int i = 0; i < n; i++) {
        a[i] = (int)(1LL * a[i] * b[i] % MOD);
    }

    ntt(a, true);
    a.resize(need);
    return a;
}

这份实现使用迭代版 NTT(更快)。先用 bit-reversal 把数组重排成递归到底后的顺序,然后令块长 len 从 \(2\) 开始不断翻倍。每一层把两个长度为 len / 2 的相邻块合并成一个长度为 len 的块。

在一个块内,第 j 个蝶形变换取

\[u=a_{i+j},\qquad v=\omega^j a_{i+j+\text{len}/2} \]

然后写回

\[a_{i+j}=u+v,\qquad a_{i+j+\text{len}/2}=u-v \]

这里的 \(\omega\) 是当前块长对应的单位根。正变换使用 \(\omega\),逆变换使用 \(\omega^{-1}\)。因此代码虽然没有显式递归,但它执行的合并顺序和递归 FFT/NTT 是一样的,只是把递归树按层展开了。

复杂度

为了方便,一般都会将两个初始序列长度相加,再补齐到最近的 2 的整数幂(这最多是原来的两倍长度)。设补零后的长度为 \(N\)。直接卷积需要 \(O(N^2)\),FFT 或 NTT 的流程是两次正变换、一次逐点相乘、一次逆变换。正变换和逆变换的复杂度都是

\[O(N\log N) \]

空间复杂度通常是 \(O(N)\)。迭代实现会原地完成变换,除了输入数组和少量临时变量,不需要递归栈和额外的大数组。

在实际使用中,FFT 和 NTT 的选择并不是单纯的速度问题。FFT 可以处理更自由的长度和数值范围,但要处理浮点误差;NTT 结果精确,适合整数和模意义下的卷积,但受模数和可用变换长度限制。对于程序员来说,最常见的路径是:需要模卷积时优先 NTT;需要精确整数卷积且结果可能很大时,用多模 NTT 加 CRT;只需要近似或能容忍舍入误差时,FFT 也很自然。