NTT / Schönhage-Strassen 大整数乘法

上一篇看了 Toom-Cook:把大整数拆成若干段,把乘法转成低次数多项式的求值、点乘和插值。它仍然属于"选有限个点,把乘法次数降下来"的思路。当整数继续变大,真正的瓶颈会变成另一个问题:能不能把所有交叉项一次性算出来,而不是一项一项枚举?

这条路会把大整数乘法带到卷积,再带到 FFT 和 NTT。Schönhage--Strassen 也在这条线上,只是它把普通 NTT 的算术环境换成了更适合二进制机器的环。

大整数乘法先变成卷积

把一个大整数按某个基数 \(B\) 分块,可以写成 \(A=\sum_{i=0}^{n-1}a_iB^i\)。另一个数写成 \(C=\sum_{j=0}^{m-1}c_jB^j\)。这里的 \(B\) 可以是 \(10\)、\(10^4\)、\(2^{16}\) 或其他方便的基数。程序里通常低位在前,例如十进制数 12345 可以表示成:

cpp 复制代码
vector<int> a = {5, 4, 3, 2, 1};

两个整数相乘时,展开得到

\[AC=\left(\sum_i a_iB^i\right)\left(\sum_j c_jB^j\right) =\sum_k\left(\sum_{i+j=k}a_ic_j\right)B^k \]

第 \(k\) 个位置上的临时系数是 \(d_k=\sum_{i+j=k}a_ic_j\),这正是两个数组的卷积。也就是说,大整数乘法可以拆成两步:先算分块数组的卷积,再做进位。

朴素卷积的代码很直接:

cpp 复制代码
for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
        d[i + j] += a[i] * c[j];
    }
}

时间复杂度是 \(O(nm)\)。如果两个数长度接近,就是 \(O(n^2)\)。整数很长时,真正慢的就是这段双重循环。

系数表示、点值表示和变换长度

多项式有两种常见表示方式。第一种是系数表示:\(A(x)=a_0+a_1x+a_2x^2+\cdots\)。这正好对应程序里的数组,存储方便,但直接相乘需要卷积。

另一种是点值表示。一个次数小于 \(n\) 的多项式,可以由 \(n\) 个互异点上的值唯一确定。若知道 \(A(x)\) 和 \(C(x)\) 在同一批点 \(x_0,x_1,\dots,x_{N-1}\) 上的值,那么乘积多项式 \(D(x)=A(x)C(x)\) 在这些点上的值就是 \(D(x_i)=A(x_i)C(x_i)\)。点值表示下,多项式乘法变成逐点相乘。

不过做乘法时有一个容易漏掉的长度问题。若 \(A\) 有 \(n\) 项,\(C\) 有 \(m\) 项,那么乘积最多有 \(n+m-1\) 项。实际做 FFT 或 NTT 时,需要选择变换长度 \(N\),满足 \(N\ge n+m-1\),并把两个系数数组补零到长度 \(N\)。为了方便使用二分蝶形结构,\(N\) 通常取不小于 \(n+m-1\) 的最小 \(2\) 的幂。

于是多项式乘法变成了这样一条流程:

text 复制代码
系数表示 -> 点值表示 -> 逐点相乘 -> 系数表示

普通地对很多点求值和插值并不便宜。如果点随便选,整体仍然可能是 \(O(N^2)\)。FFT 和 NTT 的关键是:点不能随便选,要选一组有结构的单位根。

FFT 为什么能把求值变快

设 \(\omega\) 是一个 \(N\) 次原始单位根,满足 \(\omega^N=1\),并且对 \(0<k<N\) 有 \(\omega^k\ne1\)。复数 FFT 里通常取 \(\omega=e^{2\pi i/N}\)。假设 \(N\) 是 \(2\) 的幂,把多项式按偶数项和奇数项拆开:

\[A(x)=A_0(x^2)+xA_1(x^2) \]

其中 \(A_0(x)=a_0+a_2x+a_4x^2+\cdots\),\(A_1(x)=a_1+a_3x+a_5x^2+\cdots\)。当我们在 \(\omega^k\) 处求值时,有

\[A(\omega^k)=A_0(\omega^{2k})+\omega^kA_1(\omega^{2k}) \]

而在相隔半圈的点 \(\omega^{k+N/2}\) 处,因为 \(\omega^{N/2}=-1\),所以 \(\omega^{k+N/2}=-\omega^k\),从而

\[A(\omega^{k+N/2})=A_0(\omega^{2k})-\omega^kA_1(\omega^{2k}) \]

这就是蝶形合并。它说明,要算 \(A(1),A(\omega),A(\omega^2),\dots,A(\omega^{N-1})\),可以先分别算偶数部分和奇数部分在 \(1,\omega^2,\omega^4,\dots,\omega^{N-2}\) 上的值,再用加减和乘以 \(\omega^k\) 合并。一次拆分把长度为 \(N\) 的问题变成两个长度为 \(N/2\) 的问题,每层合并花 \(O(N)\),一共有 \(\log N\) 层,所以整体是 \(O(N\log N)\)。

这段推导本身没有依赖"复数长什么样"。它只用到了加、减、乘、单位根,以及逆变换里需要的除以 \(N\)。这也是 NTT 能成立的原因:只要把单位根搬到合适的模整数环境里,蝶形结构仍然可以照用。

逆变换为什么几乎一样

FFT 做的是把系数数组 \(a_0,a_1,\dots,a_{N-1}\) 变成多项式在单位根上的取值:

\[\hat a_k=A(\omega^k)=\sum_{j=0}^{N-1}a_j\omega^{jk} \]

其中 \(k=0,1,\dots,N-1\)。从点值 \(\hat a_k\) 回到系数 \(a_j\) 时,公式只需要把单位根换成逆单位根,再除以 \(N\):

\[a_j=\frac{1}{N}\sum_{k=0}^{N-1}\hat a_k\omega^{-jk} \]

所以逆 FFT 和正 FFT 几乎是同一个算法:蝶形结构不变,正变换使用 \(\omega\),逆变换使用 \(\omega^{-1}\),最后每一项除以 \(N\)。这一步补上以后,FFT 乘法的流程才是完整的:先对两个系数数组做 FFT,在点值处逐点相乘,再做逆 FFT 回到卷积系数。

复数 FFT 的问题是浮点误差。工程上可以通过四舍五入、拆系数、长双精度等办法缓解,但如果希望整个过程完全精确,就会自然转向 NTT。

NTT:把单位根放进模整数

NTT 可以看成是把 FFT 中的单位根从复数域搬到模质数域。我们希望在模 \(p\) 下找到一个元素 \(\omega\),满足 \(\omega^N\equiv1\pmod p\),并且对 \(0<k<N\) 有 \(\omega^k\not\equiv1\pmod p\)。这样的 \(\omega\) 就是模 \(p\) 意义下的 \(N\) 次原始单位根。

在模 \(p\) 下,前面的偶奇拆分仍然成立:

\[A(\omega^k)\equiv A_0(\omega^{2k})+\omega^kA_1(\omega^{2k})\pmod p \]

又因为 \(\omega^{N/2}\equiv -1\pmod p\),所以

\[A(\omega^{k+N/2})\equiv A_0(\omega^{2k})-\omega^kA_1(\omega^{2k})\pmod p \]

算法结构没有变化,变化的是数字所在的环境。复数 FFT 里的加法、减法、乘法,在 NTT 中变成了模 \(p\) 的加法、减法、乘法。

单位根从哪里来?如果 \(p\) 是质数,那么模 \(p\) 的非零元素构成大小为 \(p-1\) 的乘法群。为了支持长度为 \(2^r\) 的 NTT,通常选择 \(p=k\cdot2^r+1\) 这样的质数,使 \(p-1\) 含有足够大的 \(2\) 的幂因子。若 \(g\) 是模 \(p\) 的原根,且 \(N\mid p-1\),那么可以取 \(\omega=g^{(p-1)/N}\bmod p\)。

常用模数 \(998244353=119\cdot2^{23}+1\),它的一个原根是 \(3\),因此最多支持长度为 \(2^{23}\) 的 NTT。若当前变换长度是 \(N\),可以使用 \(3^{(998244353-1)/N}\bmod998244353\) 作为这次变换的单位根。

逆 NTT 对应逆 FFT 的公式。正变换使用 \(\omega\),逆变换使用 \(\omega^{-1}\),最后除以 \(N\)。在模质数下,除以 \(N\) 等价于乘 \(N\) 的模逆元;只要 \(p\) 是质数且 \(N\) 不被 \(p\) 整除,这个逆元就存在,可以写成 \(N^{p-2}\bmod p\)。

用 NTT 做十进制大整数乘法

下面的实现输入和输出都是十进制低位在前数组。每个元素是一位十进制数字,也就是基数 \(B=10\)。这个基数看起来浪费,但它让单模数 NTT 的正确性更容易保证,后面再讨论原因。

cpp 复制代码
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;

namespace {
    // 998244353 = 119 * 2^23 + 1
    constexpr int MOD = 998244353;

    // 998244353 的一个原根是 3。
    constexpr int G = 3;

    constexpr int MAX_NTT_LEN = 1 << 23;

    long long modPow(long long a, long long e) {
        long long r = 1;
        while (e > 0) {
            if (e & 1) r = r * a % MOD;
            a = a * a % MOD;
            e >>= 1;
        }
        return r;
    }

    void ntt(vector<int>& a, bool invert) {
        int n = (int)a.size();
        if (n == 0 || (n & (n - 1))) {
            throw invalid_argument("NTT length must be a power of two");
        }
        if (n > MAX_NTT_LEN) {
            throw invalid_argument("NTT length exceeds 2^23 for MOD=998244353");
        }

        // bit-reversal permutation
        for (int i = 1, j = 0; i < n; i++) {
            int bit = n >> 1;
            while (j & bit) {
                j ^= bit;
                bit >>= 1;
            }
            j ^= bit;

            if (i < j) {
                swap(a[i], a[j]);
            }
        }

        // len 是当前蝶形合并的区间长度:2, 4, 8, ...
        for (int len = 2; len <= n; len <<= 1) {
            int wlen = (int)modPow(G, (MOD - 1) / len);

            // 逆变换使用单位根的逆元。
            if (invert) {
                wlen = (int)modPow(wlen, MOD - 2);
            }

            for (int i = 0; i < n; i += len) {
                long long w = 1;

                for (int j = 0; j < len / 2; j++) {
                    int u = a[i + j];
                    int v = (int)(a[i + j + len / 2] * w % MOD);

                    int x = u + v;
                    if (x >= MOD) x -= MOD;

                    int y = u - v;
                    if (y < 0) y += MOD;

                    a[i + j] = x;
                    a[i + j + len / 2] = y;

                    w = w * wlen % MOD;
                }
            }
        }

        // 逆变换最后要除以 n。
        if (invert) {
            int invN = (int)modPow(n, MOD - 2);
            for (int& x : a) {
                x = (int)(1LL * x * invN % MOD);
            }
        }
    }

    void trim(vector<int>& a) {
        while (a.size() > 1 && a.back() == 0) {
            a.pop_back();
        }
    }
}

// 输入:十进制低位在前数组,例如 12345 表示为 {5,4,3,2,1}
// 输出:乘积的十进制低位在前数组
vector<int> multiplyBigInt(const vector<int>& A, const vector<int>& C) {
    if (A.empty() || C.empty()) {
        throw invalid_argument("input must not be empty");
    }

    vector<int> a = A;
    vector<int> c = C;

    trim(a);
    trim(c);

    for (int x : a) {
        if (x < 0 || x >= 10) {
            throw invalid_argument("A 中每个元素必须是 0~9");
        }
    }
    for (int x : c) {
        if (x < 0 || x >= 10) {
            throw invalid_argument("C 中每个元素必须是 0~9");
        }
    }

    if (a.size() == 1 && a[0] == 0) return {0};
    if (c.size() == 1 && c[0] == 0) return {0};

    int need = (int)a.size() + (int)c.size() - 1;

    int n = 1;
    while (n < need) n <<= 1;
    if (n > MAX_NTT_LEN) {
        throw invalid_argument("input is too large for one 998244353 NTT");
    }

    vector<int> fa(a.begin(), a.end());
    vector<int> fc(c.begin(), c.end());

    fa.resize(n);
    fc.resize(n);

    ntt(fa, false);
    ntt(fc, false);

    for (int i = 0; i < n; i++) {
        fa[i] = (int)(1LL * fa[i] * fc[i] % MOD);
    }

    ntt(fa, true);

    // fa 此时是卷积结果,但还不是合法的十进制表示。
    vector<int> res;
    res.reserve(need + 10);

    long long carry = 0;
    for (int i = 0; i < need; i++) {
        long long cur = fa[i] + carry;
        res.push_back((int)(cur % 10));
        carry = cur / 10;
    }

    while (carry > 0) {
        res.push_back((int)(carry % 10));
        carry /= 10;
    }

    trim(res);
    return res;
}

int main() {
    vector<int> a = {5, 4, 3, 2, 1}; // 12345
    vector<int> c = {9, 8, 7, 6};    // 6789

    vector<int> product = multiplyBigInt(a, c);

    for (int i = (int)product.size() - 1; i >= 0; --i) {
        cout << product[i];
    }
    cout << '\n'; // 83810205
}

这份代码刻意保持了比较朴素的接口。真实工程里通常不会让一个 int 只存一位十进制数字,但教学版这么写有一个好处:卷积系数不容易超过单个模数的可还原范围。

基数、模数、CRT 与 limb

基数越大,数组长度越短,NTT 的规模越小;但基数越大,单个卷积系数也越大。假设两个数组长度相近,某个卷积位置最多大约累加 \(L\) 项,每一项最大接近 \((B-1)^2\),于是卷积系数的上界大约是 \(L(B-1)^2\)。如果只使用一个模数 \(\text{MOD}\),想在逆 NTT 后得到真实整数系数,而不是模 \(\text{MOD}\) 后的值,就需要 \(L(B-1)^2<\text{MOD}\)。

这条不等式解释了为什么上面的代码使用 BASE = 10。在 \(998244353\) 下,最大 NTT 长度是 \(2^{23}\)。如果两个输入长度接近,单边长度最多约为 \(2^{22}\),每个卷积项最多累加约 \(2^{22}\) 项;十进制单项乘积最大为 \(9\cdot9=81\),所以上界约为 \(81\cdot2^{22}=339738624\),仍小于 \(998244353\)。在这份代码的长度限制内,单模数结果可以直接解释为真实卷积系数。

如果把基数改成 10000,情况马上变得不同。即使只累加 \(10\) 项,也可能达到 \(10\cdot9999^2=999800010\),已经超过 \(998244353\)。一旦卷积系数在模意义下绕回,后面的进位处理再正确也无济于事,因为拿到的已经不是原来的整数系数。

外部 limb 和 NTT 内部块不一定要相同。工程实现里,外部可以用 uint32_tuint64_t 存大整数,进入 NTT 前再拆成更小的块,例如 \(15\) 位、\(16\) 位或 \(10^4\)。直接把 uint32_t 的上限当作单模数 NTT 的基数通常不合适,因为卷积项会接近甚至超过 \(2^{64}\),普通 30 位模数根本承接不了。

解决这个问题有两条常见路线。一条是使用更大的 NTT 友好模数,例如接近 \(2^{60}\) 的质数,并用 __int128 承接模乘中间结果。另一条是多模数 NTT 加 CRT:用几个 30 位左右的 NTT 友好质数分别计算卷积,再通过中国剩余定理还原真实系数。多模数会增加常数,但实现仍然相对稳健,也更容易写出跨平台代码。

Schönhage--Strassen 升级了什么

Schönhage--Strassen 也从同一条主线出发:把大整数分块,转成卷积,用快速变换计算卷积,再进位还原结果。它不是推翻 NTT,而是把"快速卷积"推进到更适合超大整数的算术环境里。

普通 NTT 通常工作在有限域 \(\mathbb{F}_p\) 中,模数选成 \(p=k\cdot2^r+1\)。这样做的好处是数学结构干净,原根、单位根和逆元都好处理。Schönhage--Strassen 选择的环境是 \(\mathbb{Z}/(2^m+1)\mathbb{Z}\)。这个模数有一个对二进制机器非常友好的性质:\(2^m\equiv-1\pmod{2^m+1}\),于是 \(2^{2m}\equiv1\pmod{2^m+1}\)。这意味着可以用 \(2\) 的幂构造单位根。

这个选择改变了蝶形运算里的成本。普通 NTT 的旋转因子乘法通常是一般模乘,例如代码里的 a[i + j + len / 2] * w % MOD。而在 \(\mathbb{Z}/(2^m+1)\mathbb{Z}\) 中,很多旋转因子可以取成 \(2\) 的幂。乘以 \(2^k\) 在机器上接近移位操作;如果超过 \(2^m\),还能利用 \(2^m\equiv-1\) 折回并改变符号。对超大整数来说,这比频繁做一般大整数模乘更有吸引力。

一个很小的例子可以看出这种结构。取模数 \(17=2^4+1\),有 \(2^4=16\equiv-1\pmod{17}\),所以 \(2^8\equiv1\pmod{17}\)。计算 \(7\cdot2^6\bmod17\) 时,可以写成

\[7\cdot2^6=7\cdot2^4\cdot2^2\equiv7\cdot(-1)\cdot4=-28\equiv6\pmod{17} \]

直接算 \(7\cdot64=448\),余数也是 \(6\)。这个例子不等于完整的 Schönhage--Strassen,但它展示了关键味道:旋转因子如果是 \(2\) 的幂,乘法就可以和移位、折回、加减联系起来。

不过,Schönhage--Strassen 不是把前面的 MOD 改成 2^m+1 就能运行。普通 NTT 代码依赖的是模质数域,G^((MOD - 1) / len)、费马小定理求逆元、原根生成单位根,这些都建立在质数模数上。\(2^m+1\) 一般不是质数,算法需要重新处理单位根、逆变换、参数选择,以及循环卷积或负循环卷积的细节。它的复杂度通常写作 \(O(n\log n\log\log n)\),这里的 \(n\) 是输入整数的 bit 数。

如果把普通 NTT 大整数乘法和 Schönhage--Strassen 放在一起看,最清楚的差异不在"是否使用变换",而在"变换放在哪里做"。普通 NTT 放在模质数域里,优点是实现清晰、适合中等规模;Schönhage--Strassen 放在 \(\mathbb{Z}/(2^m+1)\mathbb{Z}\) 这样的环里,牺牲了实现简单性,换来更适合超大规模二进制整数的旋转因子和模约简结构。

乘法里的问题被搬到了哪里

从朴素乘法到 Toom-Cook,再到 FFT/NTT,变化的不是"乘法展开式"本身,而是处理交叉项的方式。朴素算法直接枚举所有 \(a_ic_j\);Toom-Cook 用少量点值和插值减少乘法次数;FFT/NTT 选择一整组结构化单位根,把大规模求值和插值降到 \(O(N\log N)\),从而一次性得到所有卷积系数。

NTT 的工程边界也很清楚:模数决定最大变换长度,基数决定卷积系数大小,单模数不够时就要拆块、换模数或上 CRT。Schönhage--Strassen 继续沿着这条路走,把快速卷积放进更贴近二进制表示的环里。它难在参数和边界,不难在主线;主线仍然是分块、变换、点乘、逆变换和进位。