为什么 Cholesky 求逆比 Gauss-Jordan 快一倍——行列式溢出防护详

写在前面:从"会分解"到"会求逆"

上一篇我们拆了 Cholesky------30 行 C 代码把对称正定矩阵 AA A 分解成 A=RTRA = R^T R A=RTR, RR R 是上三角。当时只用半篇文章的篇幅扫过 Cholesky 的求逆部分,留了个悬念:所谓"两步法求逆"到底是怎么把 A−1A^{-1} A−1 算出来的?为什么它就比通用方法快一倍?求逆结果为什么是个"半成品"?

这一篇就把这件事讲透。聚焦点是 Cholesky 的后半段(求逆部分),而不是 分解部分。这是教科书最含糊其辞、而工程代码最值得玩味的一段。读完你会明白:

  • 为什么三角矩阵的逆只需要 O(n2)O(n^2) O(n2) 回代而不是 O(n3)O(n^3) O(n3) 消元;
  • 为什么 求出来的逆只填了上三角,调用方还得自己"对称化"一次;
  • 为什么这套两步法的总计算量是通用 Gauss-Jordan 的一半左右;
  • 为什么行列式可以"顺路"从求逆循环里捞出来。

一句话总结本篇主旨:通用的求逆算法把矩阵当成"一坨数"无差别地处理;而 Cholesky 求逆尊重矩阵的结构(对称、正定、已分解),把一个大问题拆成两个能用上结构的小问题。 结构信息用得越充分,省下的计算就越多。


钩子:教科书默认的求逆方法其实是"暴力法"

随便翻开一本数值分析或线性代数教材,"求矩阵的逆"那一节八成是这样教的:

构造增广矩阵 A∣IA \\mid I A∣I,对前 nn n 列做初等行变换把它化成单位阵,则后 nn n 列就是 A−1A^{-1} A−1。这就是 Gauss-Jordan 消元求逆

然后给你一个 3×3 的手算例子,让你在草稿纸上写满半个黑板。考试也考这个,于是你记住了:"求逆 = 增广 + 全选主元 + 消元"。

这个方法有一个致命问题:它无视矩阵的一切结构。 不管 AA A 是对称的、稀疏的、正定的、还是带状的,Gauss-Jordan 都把它当成 n×nn \times n n×n 的"一坨数",从左上角消到右下角,每个元素都参与消元。代价是实打实的 O(n3)O(n^3) O(n3) 浮点运算,且过程中要做主元选取、行交换,数值上还得提防大主元相消。

但如果你已经知道 AA A 是对称正定的呢?这时候"暴力法"就是暴殄天物------你手握一份关于矩阵结构的强信息(它有 Cholesky 分解、可以分解成两个三角矩阵的乘积),却完全没用上。

Cholesky 求逆的核心思想,就是把这份结构信息转化为计算量的节省。 先分解(上一篇讲过, n3/3n^3/3 n3/3 量级),再分别处理两个三角矩阵(这一篇要讲,各 n3/3n^3/3 n3/3 量级),合起来比 Gauss-Jordan 的 2n3/32n^3/3 2n3/3 少了一截,而且更稳。

下面正式拆 解程序。


数学原理:把求逆拆成两个"小一号"的问题

核心恒等式

A−1=(RTR)−1=R−1(R−1)TA^{-1} = (R^T R)^{-1} = R^{-1} (R^{-1})^T A−1=(RTR)−1=R−1(R−1)T

证明只要两行: (RTR)−1=R−1(RT)−1=R−1(R−1)T(R^T R)^{-1} = R^{-1} (R^T)^{-1} = R^{-1} (R^{-1})^T (RTR)−1=R−1(RT)−1=R−1(R−1)T,最后一步用了"转置的逆等于逆的转置"。这条恒等式看起来不起眼,却是整个两步法的根基。

它告诉我们: A−1A^{-1} A−1 不必碰 AA A,只要求 R−1R^{-1} R−1 然后做一次三角乘法即可。 RR R 是上三角矩阵,三角矩阵的求逆有专门的、便宜得多的算法------回代。

三角矩阵求逆的本质:回代,每个元素 O(n)O(n) O(n)

为什么三角矩阵求逆便宜?因为三角矩阵的逆"结构上"还是三角矩阵(上三角的逆仍是上三角),而且每个元素都可以用一次回代得到。

RR R 的逆记作 S=R−1S = R^{-1} S=R−1,满足 RS=IRS = I RS=I。展开 RS=IRS = I RS=I 的第 ii i 行第 jj j 列元素(约定 i≤ji \le j i≤j,因为 SS S 也是上三角, i>ji > j i>j 时 Sij =0 S_{ij} = 0 Sij=0):

∑k=ij Rik Skj = δij \sum_{k=i}^{j} R_{ik} S_{kj} = \delta_{ij} ∑k=ijRikSkj=δij

注意求和从 k=ik=i k=i 开始(因为 RR R 是上三角, Rik =0 R_{ik} = 0 Rik=0 当 k<ik < i k<i),到 k=jk=j k=j 结束(因为 SS S 是上三角, Skj =0 S_{kj} = 0 Skj=0 当 k>jk > j k>j)。这是一个只跨越 i,ji, j i,j 区间 的小型求和,而不是全跨度 1,n1, n 1,n

k=ik = i k=i 那一项拎出来:

Rii Sij + ∑k=i+1j Rik Skj = δij R_{ii} S_{ij} + \sum_{k=i+1}^{j} R_{ik} S_{kj} = \delta_{ij} RiiSij+∑k=i+1jRikSkj=δij

解出 Sij S_{ij} Sij:

Sij = δij − ∑k=i+1j Rik Skj Rii S_{ij} = \frac{\delta_{ij} - \sum_{k=i+1}^{j} R_{ik} S_{kj}}{R_{ii}} Sij=Riiδij−∑k=i+1jRikSkj

分两种情况看:

  • 对角元 i=ji = j i=j): δii =1 \delta_{ii} = 1 δii=1,求和项空集,得到 Sii =1/ Rii S_{ii} = 1 / R_{ii} Sii=1/Rii------对角元就是直接取倒数。
  • 非对角元 i<ji < j i<j): δij =0 \delta_{ij} = 0 δij=0,得到 Sij =− ∑k=i+1j Rik Skj Rii S_{ij} = -\dfrac{\sum_{k=i+1}^{j} R_{ik} S_{kj}}{R_{ii}} Sij=−Rii∑k=i+1jRikSkj------一个内积再除以对角元。

注意求和 ∑k=i+1j \sum_{k=i+1}^{j} ∑k=i+1j 里: Rik R_{ik} Rik 是 RR R 第 ii i 行从 i+1i+1 i+1 到 jj j 的元素, Skj S_{kj} Skj 是 SS S 第 jj j 列从 i+1i+1 i+1 到 jj j 的元素。这两个片段在算法开始计算 Sij S_{ij} Sij 时必须已经算好。 这就规定了计算顺序:对每一列 jj j,从下往上( ii i 从 j−1j-1 j−1 递减到 00 0)算非对角元,最后处理对角元 Sjj S_{jj} Sjj。

这就是"三角矩阵求逆 = 回代"。每个 Sij S_{ij} Sij 的代价是 O(j−i)O(j - i) O(j−i),平均 O(n)O(n) O(n);总元素数 O(n2)O(n^2) O(n2),所以总代价 O(n3)O(n^3) O(n3)------但常数因子只是通用方法的一个零头(下面会精确算)。

合成 A−1A^{-1} A−1:三角乘三角

第二步:拿到 S=R−1S = R^{-1} S=R−1 后,算 B=S⋅STB = S \cdot S^T B=S⋅ST。这也是一个三角矩阵乘以三角矩阵的乘法------ SS S 上三角、 STS^T ST 下三角,乘积 BB B 是对称的,而且只需要算上三角 就够了(下三角由对称性可得,这就是 dpodi 只填上三角的伏笔)。

BB B 的第 ii i 行第 jj j 列( i≤ji \le j i≤j):

Bij =∑k Sik (ST )kj =∑k Sik Sjk = ∑k=max⁡(i,j)n−1 Sik Sjk B_{ij} = \sum_{k} S_{ik} (S^T){kj} = \sum{k} S_{ik} S_{jk} = \sum_{k=\max(i,j)}^{n-1} S_{ik} S_{jk} Bij=∑kSik(ST)kj=∑kSikSjk=∑k=max(i,j)n−1SikSjk

求和下界是 max⁡(i,j)\max(i, j) max(i,j)------因为 SS S 上三角, Sik =0 S_{ik} = 0 Sik=0 当 k<ik < i k<i、 Sjk =0 S_{jk} = 0 Sjk=0 当 k<jk < j k<j,非零项要求 k≥ik \ge i k≥i 且 k≥jk \ge j k≥j,即 k≥max⁡(i,j)k \ge \max(i, j) k≥max(i,j);上界到 n−1n-1 n−1。对 i≤ji \le j i≤j, max⁡(i,j)=j\max(i,j) = j max(i,j)=j,所以:

Bij = ∑k=jn−1 Sik Sjk B_{ij} = \sum_{k=j}^{n-1} S_{ik} S_{jk} Bij=∑k=jn−1SikSjk

这又是一个 O(n−j)O(n-j) O(n−j) 的内积。计算 BB B 上三角的所有元素总代价也是 O(n3)O(n^3) O(n3) 量级(带一个比较小的常数)。

整体计算量:和 Gauss-Jordan 精确对比

把所有阶段加起来(以下"flop"指一次浮点加法或乘法,是经典的计数单位):

阶段 flop 数(近似)
Cholesky 分解 n³/3
三角求逆(反演 R) n³/6(实际带常数更接近此值的一半)
三角乘法合成 A⁻¹ n³/6
Cholesky 求逆合计 ≈ n³/3 + n³/3 = 2n³/3 的下界
通用 Gauss-Jordan 求逆 ≈ 2n³

精确数字在不同教材里略有出入(取决于 flop 的定义、是否计入行交换),但结论稳定:Cholesky 两步法求逆的总计算量大约是 Gauss-Jordan 的 1/2 到 1/3。 这就是"快一倍"的来源。

更妙的是:如果调用方本来就需要分解 (比如回归里既算 β\beta β 又算 (XTX)−1(X^TX)^{-1} (XTX)−1),那 n3/3n^3/3 n3/3 的分解成本是无论如何都要付的,求逆的边际成本只剩 n3/3n^3/3 n3/3------和"再跑一遍 Gauss-Jordan 求 (XTX)−1(X^TX)^{-1} (XTX)−1"相比,省了不止一半

数值稳定性:条件数不平方放大

计算量只是一个维度,另一个维度是数值稳定性。Gauss-Jordan 在消元时经常引入大主元相消(一个大数减一个接近的大数,丢掉有效数字),数值误差会显著累积。

Cholesky 两步法则不然。Cholesky 分解本身不需要选主元 ------正定性天然保证了 Rii >0 R_{ii} > 0 Rii>0,每一步的除法分母都是正数,没有符号抵消。而三角求逆里的除法也永远除以 Rii R_{ii} Rii 或 Sii S_{ii} Sii,分母都不小。

更重要的是:Cholesky 不平方放大条件数。 直接对 AA A 求逆时,相对误差大约正比于 κ(A)⋅εmach\kappa(A) \cdot \varepsilon_{\text{mach}} κ(A)⋅εmach;而如果我们走"最小二乘 → 构造 A=XTXA = X^TX A=XTX → 求逆"这条路,构造 XTXX^TX XTX 本身就把条件数平方了一次( κ(XTX)≈κ(X)2\kappa(X^TX) \approx \kappa(X)^2 κ(XTX)≈κ(X)2),这是另一个故事,但至少在 AA A 已经给定的情况下,Cholesky 求逆不会进一步放大误差。

一句话:Cholesky 求逆是"用结构换精度"的典范。 你告诉它"这是个正定矩阵",它就回报你一个又快又稳的逆。


逐段拆解 :真实的求逆代码

铺垫完了数学,现在对着源码看。函数分为两段:算行列式(上一篇已讲,这里略),以及求逆。求逆部分又分两步:

  • 第一步 :反演 RR R 得到 R−1R^{-1} R−1(原地存上三角)
  • 第二步 :合成 A−1=R−1(R−1)TA^{-1} = R^{-1} (R^{-1})^T A−1=R−1(R−1)T(填满上三角)

我们一段一段看。

第一步:反演上三角 R

c 复制代码
/* ---- 求逆:第一步,反演 R 得到 R⁻¹(仍存上三角)---- */
if (job % 10 != 0) {
    for (int k = 0; k < n; k++) {
        a[k + k * lda] = 1.0 / a[k + k * lda];        /* R⁻¹ 对角 */
        double t = -a[k + k * lda];
        dscal(k, t, &a[0 + k * lda], 1);              /* 缩放本列上方 */
        for (int j = k + 1; j < n; j++) {
            double tj = a[k + j * lda];
            a[k + j * lda] = 0.0;
            daxpy(k + 1, tj, &a[0 + k * lda], 1, &a[0 + j * lda], 1);
        }
    }

这段代码的循环顺序看起来有点奇怪------不是"对每一列 jj j 从下往上算",而是"对每个对角元 kk k 做一次结构化的更新"。这是工程实现和教科书推导之间的一道鸿沟。我们把它对到刚才推导的公式上。

外层循环:逐个对角元 kk k 处理

c 复制代码
for (int k = 0; k < n; k++) {
    a[k + k * lda] = 1.0 / a[k + k * lda];        /* R⁻¹ 对角 */

kk k 次迭代,先把 RR R 的第 kk k 个对角元 Rkk R_{kk} Rkk 原地替换成它的倒数 Skk =1/ Rkk S_{kk} = 1 / R_{kk} Skk=1/Rkk 。这对应我们推导的 Sii =1/ Rii S_{ii} = 1/R_{ii} Sii=1/Rii。

但这里有个微妙的"时机"问题。注意此时的 a[k + k*lda] 已经是 Skk S_{kk} Skk(取过倒数了),下面这一行:

c 复制代码
    double t = -a[k + k * lda];

tt t 设为 − Skk =−1/ Rkk -S_{kk} = -1/R_{kk} −Skk=−1/Rkk。然后------

c 复制代码
    dscal(k, t, &a[0 + k * lda], 1);              /* 缩放本列上方 */

dscal(n, alpha, x, incx) 是 BLAS 的"向量数乘": x←α⋅xx \leftarrow \alpha \cdot x x←α⋅x。这里它把 kk k 列的前 kk k 个元素 (即 a[0 + k*lda]a[k-1 + k*lda],共 kk k 个)整体乘以 t=− Skk t = -S_{kk} t=−Skk。

为什么?因为在求逆公式 Sij =− ∑k=i+1j Rik Skj Rii S_{ij} = -\dfrac{\sum_{k=i+1}^{j} R_{ik} S_{kj}}{R_{ii}} Sij=−Rii∑k=i+1jRikSkj 里,分母 Rii R_{ii} Rii 会被反复用到------把这一列的上半段预先乘以 −1/ Rii -1/R_{ii} −1/Rii,相当于把除法"折叠"进了乘法。这是一种典型的循环不变量外提:与其在每个内层迭代里都做一次除法,不如一次性把整列的公共因子提出来。

不过这里有个小细节需要解释------为什么缩放的是第 kk k 列的前 kk k 个元素? 这就要看 dscal 之后那一段 daxpy 的循环了。

内层循环:用第 kk k 列更新它右边的所有列

c 复制代码
    for (int j = k + 1; j < n; j++) {
        double tj = a[k + j * lda];
        a[k + j * lda] = 0.0;
        daxpy(k + 1, tj, &a[0 + k * lda], 1, &a[0 + j * lda], 1);
    }

daxpy(n, alpha, x, incx, y, incy) 是 BLAS 的"axpy"操作: y←α⋅x+yy \leftarrow \alpha \cdot x + y y←α⋅x+y。这里它把 jj j 列的前 k+1k+1 k+1 个元素 yy y = a[0 + j*lda] 起,共 k+1k+1 k+1 个)更新为:原值加上 tj t_j tj 倍的第 kk k 列前 k+1k+1 k+1 个元素( xx x = a[0 + k*lda] 起)。

注意循环开始前,第 kk k 列前 kk k 个元素已经被 dscal 乘了 − Skk -S_{kk} −Skk,第 kk k 个位置(即对角元位置)现在是 Skk =1/ Rkk S_{kk} = 1/R_{kk} Skk=1/Rkk。所以第 kk k 列前 k+1k+1 k+1 个元素此时是: (− Skk ⋅ R0k , − Skk ⋅ R1k , ..., − Skk ⋅ Rk−1,k , Skk ) (-S_{kk} \cdot R_{0k},\ -S_{kk} \cdot R_{1k},\ \ldots,\ -S_{kk} \cdot R_{k-1,k},\ S_{kk}) (−Skk⋅R0k, −Skk⋅R1k, ..., −Skk⋅Rk−1,k, Skk)。

而第 jj j 列( j>kj > k j>k)的前 k+1k+1 k+1 个元素,此时还是 RR R 的内容: ( R0j , R1j , ..., Rkj ) (R_{0j},\ R_{1j},\ \ldots,\ R_{kj}) (R0j, R1j, ..., Rkj)。

把它们对应位置相加(乘以 tj= Rkj t_j = R_{kj} tj=Rkj)后,第 ii i 行( 0≤i≤k0 \le i \le k 0≤i≤k)的新值是:

Rij + Rkj ⋅(− Skk ⋅ Rik ),i<k R_{ij} + R_{kj} \cdot (-S_{kk} \cdot R_{ik}), \quad i < k Rij+Rkj⋅(−Skk⋅Rik),i<k

Rkj + Rkj ⋅ Skk ,i=k R_{kj} + R_{kj} \cdot S_{kk}, \quad i = k Rkj+Rkj⋅Skk,i=k

注意 Rkk ⋅ Skk =1 R_{kk} \cdot S_{kk} = 1 Rkk⋅Skk=1,所以上面第二式可以写成 Rkj ⋅(1+ Skk ) R_{kj} \cdot (1 + S_{kk}) Rkj⋅(1+Skk) 吗?不,这里要小心------当 i=ki = k i=k 时,"第 kk k 列的第 kk k 个元素"已经是 Skk S_{kk} Skk,但 Rkk ⋅ Skk =1 R_{kk} \cdot S_{kk} = 1 Rkk⋅Skk=1,所以"第 kk k 列第 kk k 行的位置"对 tj⋅xk t_j \cdot x_k tj⋅xk 的贡献是 tj⋅ Skk = Rkj ⋅ Skk t_j \cdot S_{kk} = R_{kj} \cdot S_{kk} tj⋅Skk=Rkj⋅Skk。然后 yk= Rkj y_k = R_{kj} yk=Rkj,更新后 yk← Rkj + Rkj Skk y_k \leftarrow R_{kj} + R_{kj} S_{kk} yk←Rkj+RkjSkk?

读者如果跟着推到这里会卡住------真实的 LINPACK 风格代码在这里做了一件事:它不是直接维护 RR R 或 SS S,而是维护一个"中间矩阵" ,通过一系列rank-1 更新 逐步把 RR R 变成 SS S。每一轮 kk k 循环都完成了一部分反演。这种"原地 rank-1 累积"的实现,比"显式按公式 Sij =−∑/diag S_{ij} = -\sum/\text{diag} Sij=−∑/diag 一个一个算"更高效,因为所有运算都打包成了 BLAS 调用,能直接吃到向量化加速。

要害在于:这段代码不是数学公式的逐字翻译,而是等价但更高效的算法重排。 数学公式告诉你"每个 Sij S_{ij} Sij 怎么算",代码把它重排成"每个 Rkj R_{kj} Rkj 如何传播到其他位置"。重排前后数学等价,但内存访问模式天差地别------重排后每一步都是连续内存的 BLAS 调用,对 CPU 缓存友好。

对角元取倒数: Skk =1/ Rkk S_{kk} = 1 / R_{kk} Skk=1/Rkk

回到循环开头那一行:

c 复制代码
a[k + k * lda] = 1.0 / a[k + k * lda];

这是公式 Skk =1/ Rkk S_{kk} = 1 / R_{kk} Skk=1/Rkk 的直接翻译。一个除法,干脆利落。整个三角求逆里,除法只发生在对角元上 ------非对角元的"除以 Rii R_{ii} Rii"被合并到了 dscal 里,变成了乘法。这是性能的关键之一:除法比乘法慢 3~5 倍,能合并就合并。

内层循环把第 kk k 行清零

注意这一行:

c 复制代码
a[k + j * lda] = 0.0;

把第 jj j 列的第 kk k 行清零,然后 才调用 daxpy。为什么?因为 daxpy y←αx+yy \leftarrow \alpha x + y y←αx+y,会把 yy y 的原值加进来。但这里我们不想保留原值 Rkj R_{kj} Rkj(它已经被存到 tj 里了),所以要先把那个位置清零,让 daxpy 干净地写入新值。这是 LINPACK 代码里反复出现的模式:daxpy 做 rank-1 更新之前,先把要覆盖的位置清零

读完这一段,你应该能体会到 dpodi 第一步的精妙之处:它把"三角矩阵求逆"这个看似需要 n2n^2 n2 次独立回代的操作,重构成了 nn n 次"对角元取倒数 + 一列缩放 + 一连串 rank-1 更新",每次更新都是 BLAS 的连续内存操作。 这就是为什么实测性能远高于教科书伪代码。

第二步:合成 A⁻¹ = R⁻¹ · (R⁻¹)ᵀ

经过第一步,a 的上三角里存的是 S=R−1S = R^{-1} S=R−1(下三角还是原始 AA A 的下三角,未被使用)。第二步要算 B=S⋅STB = S \cdot S^T B=S⋅ST:

c 复制代码
/* ---- 第二步,合成 A⁻¹ = R⁻¹·(R⁻¹)ᵀ(填满上三角)---- */
for (int j = 0; j < n; j++) {
    for (int k = 0; k < j; k++) {
        double t = a[k + j * lda];
        daxpy(k + 1, t, &a[0 + j * lda], 1, &a[0 + k * lda], 1);
    }
    double t = a[j + j * lda];
    dscal(j + 1, t, &a[0 + j * lda], 1);
}

对应数学公式 Bij = ∑m=jn−1 Sim Sjm B_{ij} = \sum_{m=j}^{n-1} S_{im} S_{jm} Bij=∑m=jn−1SimSjm( i≤ji \le j i≤j)。这段代码的策略是:对每一列 jj j,把它的元素 Skj S_{kj} Skj 当作系数去更新它左边的所有列 kk k( k<jk<j k<j),最后再缩放第 jj j 列本身。

内层循环:用第 jj j 列更新它左边的所有列

c 复制代码
for (int k = 0; k < j; k++) {
    double t = a[k + j * lda];
    daxpy(k + 1, t, &a[0 + j * lda], 1, &a[0 + k * lda], 1);
}

jj j 列第 kk k 行的元素 Skj S_{kj} Skj 取出来作为标量 tt t(注意 k<jk < j k<j,所以这是上三角里的元素)。然后 daxpy(k+1, t, &a[0 + j*lda], 1, &a[0 + k*lda], 1) 按 BLAS 约定是 y←αx+yy \leftarrow \alpha x + y y←αx+y:这里 xx x 是 jj j 列 yy y 是 kk k 列 ,所以是把第 jj j 列的前 k+1k+1 k+1 个元素乘以 tt t、累加到第 kk k 列的前 k+1k+1 k+1 个元素上------被更新的是第 kk k 列

这一步在算什么?它把第 kk k 列前 k+1k+1 k+1 个位置更新为 Sik + Skj ⋅ Sij S_{ik} + S_{kj}\cdot S_{ij} Sik+Skj⋅Sij( i≤ki \le k i≤k;此时第 jj j 列仍存着原始 Sij S_{ij} Sij)。对照目标 Bik =∑m Sim Skm B_{ik} = \sum_{m} S_{im} S_{km} Bik=∑mSimSkm:每一项 Sim Skm S_{im}S_{km} SimSkm,把 Skm S_{km} Skm 当作标量去缩放第 mm m 列、再累加到第 kk k 列------而这里的 mm m 正是 jj j( j>kj>k j>k),所以这次累加贡献的是 Skj ⋅ Sij S_{kj}\cdot S_{ij} Skj⋅Sij。

由于外层 jj j 从 k+1k+1 k+1 跑到 n−1n-1 n−1,位置 (i,k)(i,k) (i,k)( i≤ki \le k i≤k)最终累加到 ∑j=k+1n−1 Skj Sij \sum_{j=k+1}^{n-1} S_{kj} S_{ij} ∑j=k+1n−1SkjSij------只差 m=km=k m=k 的对角项没补上。

收尾:处理对角贡献并缩放

c 复制代码
double t = a[j + j * lda];
dscal(j + 1, t, &a[0 + j * lda], 1);

取出第 jj j 列的对角元 Sjj S_{jj} Sjj 作为标量,把第 jj j 列前 j+1j+1 j+1 个元素整体乘以 Sjj S_{jj} Sjj。

这一步同时完成了两件事:

  1. 补上对角贡献 Sij ⋅ Sjj S_{ij} \cdot S_{jj} Sij⋅Sjj :目标 Bij = ∑m=jn−1 Sim Sjm B_{ij} = \sum_{m=j}^{n-1} S_{im} S_{jm} Bij=∑m=jn−1SimSjm 里, m=jm = j m=j 这一项正是 Sij Sjj S_{ij} S_{jj} SijSjj( i≤ji \le j i≤j)。此时第 jj j 列前 j+1j+1 j+1 个位置仍是原始 Sij S_{ij} Sij,dscal 整体乘以 Sjj S_{jj} Sjj 后,它们便从 Sij S_{ij} Sij 变成 Sij ⋅ Sjj S_{ij}\cdot S_{jj} Sij⋅Sjj。
  2. 把第 jj j 列定格为 BB B 的第 jj j 列 :把这项和内层循环的累加合起来,位置 (i,j)(i,j) (i,j)( i≤ji \le j i≤j)最终等于 Sij ⋅ Sjj + ∑m=j+1n−1 Sim Sjm = ∑m=jn−1 Sim Sjm = Bij S_{ij}\cdot S_{jj} + \sum_{m=j+1}^{n-1} S_{im}S_{jm} = \sum_{m=j}^{n-1} S_{im}S_{jm} = B_{ij} Sij⋅Sjj+∑m=j+1n−1SimSjm=∑m=jn−1SimSjm=Bij。核心思想是:用一次 dscal 同时完成"补对角贡献"和"乘上 Sjj S_{jj} Sjj"两件事,又一次体现了循环不变量外提的工程智慧。

读完这两步,你大概会感慨:真实的数值代码不是数学公式的直接翻译,而是一连串巧妙的算法重排,让每一步都落在 BLAS 这种高度优化的基本操作上。 这是教科书伪代码和 LINPACK 生产代码之间最深的一道鸿沟。


工程细节一:求逆结果只在上三角,必须 symmetrize

这是最容易被忽略、却最坑调用方的一点。看的函数头注释:

c 复制代码
/*   job%10 != 0 : 原地把 A 的上三角替换为 A⁻¹ 的上三角(下三角由对称性可得)
 ...
 *        A⁻¹ = R⁻¹ · (R⁻¹)ᵀ  ------ 先反演上三角 R,再做三角乘法合成
 */

注意那一句"下三角由对称性可得 "。意思是:跑完之后,a 的下三角里还是原始 AA A 的下三角(或者 留下的"未被使用"的下三角),而不是 A−1A^{-1} A−1 的下三角。 调用方必须自己把上三角的值镜像填到下三角,才能得到完整的 A−1A^{-1} A−1。

为什么这样设计?两个原因:

  1. 对称矩阵的逆仍是对称矩阵,下三角完全是上三角的镜像,没必要浪费计算量去算两遍。
  2. 内存零浪费 :原地存储意味着不需要额外的 n×nn \times n n×n 缓冲区。在 1978 年内存以 KB 计的年代,这是必须的;到今天依然有意义------一个 10000×1000010000 \times 10000 10000×10000 的 double 矩阵占 800 MB,能省一份就省一份。

但代价是:调用方必须知道这个约定。 如果直接拿程序的输出当 A−1A^{-1} A−1 用、不做对称化,下三角就全是错的(还是原始 AA A 的数据)。这是一个典型的"接口陷阱"------函数名说"求逆",但实际只给了"半个逆"。

在 工业软件底层代码里,调用 之后的标准模式长这样(伪代码):

c 复制代码
dpodi(a, lda, n, det, 11);    /* job=11 表示同时算行列式和求逆 */
/* 现在的上三角是 A⁻¹ 的上三角 */
for (int j = 0; j < n; j++) {
    for (int i = j + 1; i < n; i++) {
        a[i + j * lda] = a[j + i * lda];   /* 镜像:把上三角的值复制到下三角 */
    }
}
/* 现在的 a 才是完整的 A⁻¹ */

这一步叫 symmetrize(对称化) 。它不在内部做,而是留给调用方------为什么?因为有些调用方根本不需要下三角 。比如只算 A−1bA^{-1} b A−1b(解方程组),用上三角的 A−1A^{-1} A−1 配合 BLAS 的 dtrmv(三角矩阵-向量乘法)就够了,根本不用对称化。把对称化放到调用方,让"需要完整矩阵"和"只需要上三角"的两种调用方各取所需,是更灵活的设计。

这就是"半成品约定"的工程智慧:函数只做最小必要的工作,把可选的后续留给调用方。 教科书从来不讲这种约定,但每个用过 LINPACK/LAPACK 的人都踩过这个坑。


工程细节二:三角矩阵求逆的逐步代数(3×3 例子)

为了让你彻底信服"三角求逆 = 回代",我们手算一个 3×3 的例子。设:

R= ( r00 r01 r02 0 r11 r12 0 0 r22 ) ,S=R−1= ( s00 s01 s02 0 s11 s12 0 0 s22 ) R = \begin{pmatrix} r_{00} & r_{01} & r_{02} \\ 0 & r_{11} & r_{12} \\ 0 & 0 & r_{22} \end{pmatrix}, \quad S = R^{-1} = \begin{pmatrix} s_{00} & s_{01} & s_{02} \\ 0 & s_{11} & s_{12} \\ 0 & 0 & s_{22} \end{pmatrix} R= r0000r01r110r02r12r22 ,S=R−1= s0000s01s110s02s12s22

要求 RS=IRS = I RS=I。逐个元素算:

对角元(直接取倒数)

s00=1/r00,s11=1/r11,s22=1/r22 s_{00} = 1/r_{00}, \quad s_{11} = 1/r_{11}, \quad s_{22} = 1/r_{22} s00=1/r00,s11=1/r11,s22=1/r22

第 2 列的非对角元(从下往上):

s12 s_{12} s12:从 RS=IRS = I RS=I 的第 1 行第 2 列, ∑k r1k sk2 =0 \sum_k r_{1k} s_{k2} = 0 ∑kr1ksk2=0,即 r11s12+r12s22=0 r_{11} s_{12} + r_{12} s_{22} = 0 r11s12+r12s22=0,所以

s12=−r12s22/r11=−r12/(r11r22) s_{12} = -r_{12} s_{22} / r_{11} = -r_{12} / (r_{11} r_{22}) s12=−r12s22/r11=−r12/(r11r22)

s02 s_{02} s02:第 0 行第 2 列, r00s02+r01s12+r02s22=0 r_{00} s_{02} + r_{01} s_{12} + r_{02} s_{22} = 0 r00s02+r01s12+r02s22=0,所以

s02=−(r01s12+r02s22)/r00 s_{02} = -(r_{01} s_{12} + r_{02} s_{22}) / r_{00} s02=−(r01s12+r02s22)/r00

注意 s12 s_{12} s12 已经在上一步算好了,这里直接代入------这就是"回代"的含义。每个元素的计算依赖同列更靠下的元素,所以必须从下往上算。

第 1 列的非对角元

s01 s_{01} s01:第 0 行第 1 列, r00s01+r01s11=0 r_{00} s_{01} + r_{01} s_{11} = 0 r00s01+r01s11=0,所以

s01=−r01s11/r00=−r01/(r00r11) s_{01} = -r_{01} s_{11} / r_{00} = -r_{01} / (r_{00} r_{11}) s01=−r01s11/r00=−r01/(r00r11)

最终

R−1= ( 1/r00 −r01/(r00r11) (r01r12−r02r11)/(r00r11r22) 0 1/r11 −r12/(r11r22) 0 0 1/r22 ) R^{-1} = \begin{pmatrix} 1/r_{00} & -r_{01}/(r_{00} r_{11}) & (r_{01} r_{12} - r_{02} r_{11})/(r_{00} r_{11} r_{22}) \\ 0 & 1/r_{11} & -r_{12}/(r_{11} r_{22}) \\ 0 & 0 & 1/r_{22} \end{pmatrix} R−1= 1/r0000−r01/(r00r11)1/r110(r01r12−r02r11)/(r00r11r22)−r12/(r11r22)1/r22

数一下运算量:3 个对角元(3 次除法)+ 3 个非对角元(每个 1 次内积 + 1 次除法)。对一般的 nn n,三角求逆大约 12n(n−1) \frac{1}{2} n(n-1) 21n(n−1) 个非对角元,每个 O(n)O(n) O(n) 内积------总 flop 约 n3/6n^3/6 n3/6 量级。

对比通用求逆 :对一个 3×33 \times 3 3×3 的稠密矩阵做 Gauss-Jordan 消元,每个元素都要参与行变换,总 flop 约 2n3=542n^3 = 54 2n3=54。而三角求逆只要 n3/6≈4.5n^3/6 \approx 4.5 n3/6≈4.5。差距随 nn n 增大而放大。这就是结构带来的红利。


工程细节三:行列式作为求逆循环的副产品

现在看 函数最前面那段算行列式的代码。注意它的位置------在求逆之前,且单独由 job / 10 != 0 控制

c 复制代码
/* ---- 行列式 ---- */
if (job / 10 != 0) {
    det[0] = 1.0; det[1] = 0.0;
    const double ten = 10.0;
    for (int i = 0; i < n; i++) {
        det[0] *= a[i + i * lda] * a[i + i * lda];   /* Π R(i,i)² */
        if (det[0] == 0.0) break;
        while (det[0] <  1.0) { det[0] *= ten; det[1] -= 1.0; }
        while (det[0] >= ten) { det[0] /= ten; det[1] += 1.0; }
    }
}

这里的关键观察:这段代码读的是 a[i + i*lda],即 RR R 的对角元 Rii R_{ii} Rii。它必须在求逆之前跑------因为求逆的第一步就会把 Rii R_{ii} Rii 替换成 1/ Rii 1/R_{ii} 1/Rii,到时候就再也读不到原始的 Rii R_{ii} Rii 了。 这就是为什么行列式代码放在求逆代码之前。

数学上 det⁡(A)=det⁡(RTR)=det⁡(R)2=∏i Rii2 \det(A) = \det(R^T R) = \det(R)^2 = \prod_i R_{ii}^2 det(A)=det(RTR)=det(R)2=∏iRii2,所以代码就是读 RR R 的对角元,平方,累乘。几乎零额外成本 ------分解已经把 Rii R_{ii} Rii 都算好了,这里只是顺着对角线读一遍。

手工科学计数法防溢出(简述)

det[0] 存尾数,永远保持在 1,10)\[1, 10) \[1,10) 区间;`det[1]` 存 10 的指数。每乘一次就立刻归一化:小于 1 就乘 10、指数减 1;大于等于 10 就除以 10、指数加 1。于是 det⁡(A)=det\[0×10det1\det(A) = \text{det}0 \times 10^{\text{det}1} det(A)=det0×10det1,永远不溢出。

为什么必须这么做?一个 100×100100 \times 100 100×100 的矩阵, ∏ Rii2 \prod R_{ii}^2 ∏Rii2 轻松达到 1020010^{200} 10200,中间过程还可能更高。double 的上限约 1.8×103081.8 \times 10^{308} 1.8×10308,但只要矩阵规模再大一点或者 Rii R_{ii} Rii 略大于 1,就会爆。手工科学计数法是"教科书数学"和"生产代码"之间最典型的鸿沟 ------数学上 ∏ Rii2 \prod R_{ii}^2 ∏Rii2 是个干净的公式,计算机上直接乘会炸。工业软件用两行 while 解决了它。

(这部分逻辑上一篇 #1 已经详细讲过,这里不展开。本文要强调的是:行列式计算是求逆过程的"顺路产物" ------调用方既能拿到 A−1A^{-1} A−1,又能几乎免费地拿到 det⁡(A)\det(A) det(A),一次调用两份收获。这在多元统计里特别有用:判别分析要算 ∣W∣|W| ∣W∣(组内协方差行列式)、贝叶斯模型选择要算 ∣Σ∣|\Sigma| ∣Σ∣,都是和协方差求逆一起做的。job = 11 就是一次拿到逆和行列式的"组合套餐"。)

job 参数的设计:位运算复用一个 int

dpodi 的接口用 一个 int 的两个十进制位同时控制"算不算行列式"和"算不算逆":

job 值 job/10(算行列式?) job%10(算逆?) 行为
0 0 0 啥也不算(合法但无用)
1 0 1 只算逆
10 1 0 只算行列式
11 1 1 都算

这是 1970 年代 Fortran 接口的典型风格------用一个 int 的十进制位打包多个布尔开关,省下函数参数 。现代 C 代码会用位掩码(JOB_DET | JOB_INV),但 LINPACK 选择了十进制位。读这种老代码时要注意:job / 10job % 10 不是在做数学除法,是在拆开关。


把整条链串起来:求逆在统计计算里出现在哪里

理解了两步法,回头看统计软件的底层,到处都是它的影子:

线性回归:算 (XTX)−1(X^TX)^{-1} (XTX)−1

最小二乘回归 β^=(XTX)−1XTy \hat\beta = (X^TX)^{-1} X^Ty β^=(XTX)−1XTy,要算 (XTX)−1(X^TX)^{-1} (XTX)−1。 XTXX^TX XTX 是对称正定,走 Cholesky: 分解、求逆。回归系数的标准误、置信区间、t 统计量,全都依赖 (XTX)−1(X^TX)^{-1} (XTX)−1 的对角元。 这个矩阵的精度直接决定了统计推断的可靠性------所以"又快又稳"的 Cholesky 求逆在这里是刚需。

协方差矩阵求逆:多元统计的常客

判别分析、多元回归、PCA、因子分析、结构方程......都要算协方差矩阵 Σ\Sigma Σ 的逆。 Σ\Sigma Σ 对称正定(假设非奇异),又是 Cholesky 求逆的舞台。判别分析里的马氏距离 d2=(x−μ)TΣ−1(x−μ)d^2 = (x - \mu)^T \Sigma^{-1} (x - \mu) d2=(x−μ)TΣ−1(x−μ),每判一个样本都要算一次 Σ−1x\Sigma^{-1} x Σ−1x,预计算 Σ−1\Sigma^{-1} Σ−1 用 dpodi,运行时只做矩阵-向量乘法。

贝叶斯推断:从多元正态后验采样

贝叶斯回归的后验往往是多元正态 N(μ,Σ)N(\mu, \Sigma) N(μ,Σ),采样需要算 μ+Lz\mu + L z μ+Lz,其中 LLT=ΣL L^T = \Sigma LLT=Σ。这里的 LL L 就是 Cholesky 因子 ,采样过程不需要显式求 Σ−1\Sigma^{-1} Σ−1,但如果要做后验预测、算协方差调整,就得让求逆出马。MCMC 的每一步迭代都可能触发一次分解+求逆,性能至关重要。

优化:牛顿法的 Hessian 求逆

牛顿法更新 θ←θ−H−1g\theta \leftarrow \theta - H^{-1} g θ←θ−H−1g,其中 HH H 是 Hessian。如果 HH H 正定(凸优化场景),又是 Cholesky 求逆。信任域方法、内点法、L-BFGS 的初始化,底层都藏着 Cholesky。


与 Gauss-Jordan 的全面对比

最后做一个总账对比,把性能、稳定性、适用性都摆出来:

维度 Cholesky 两步法求逆 Gauss-Jordan 消元求逆
适用矩阵 对称正定 任意非奇异
计算量(flop) ≈ n³/3 + n³/3 ≈ 2n³/3 ≈ 2n³
实测速度 快一倍左右 基准
主元选取 不需要(正定性保证) 必须(部分主元或全主元)
行交换 有(破坏原矩阵布局)
数值稳定性 极好(无大主元相消) 中等(依赖主元策略)
内存 原地,无额外分配 原地,但需要 pivot 数组
副产品 行列式(几乎免费) 行列式(要单独算)
结果存储 只填上三角,需 symmetrize 填满全矩阵

一句话总结:Cholesky 求逆是"用适用性换性能"的典范------你告诉它矩阵是正定的,它就回报你 2 倍速度和更高的精度。 如果你的矩阵恰好是对称正定(而统计计算里一半的矩阵都是),没有理由用 Gauss-Jordan。


总结:教科书不讲的 5 个点

  1. 核心恒等式 A−1=(RTR)−1=R−1(R−1)TA^{-1} = (R^T R)^{-1} = R^{-1} (R^{-1})^T A−1=(RTR)−1=R−1(R−1)T。把稠密求逆拆成两个三角求逆+一次三角乘法,是两步法的数学根基。
  2. 三角矩阵求逆 = 回代 :每个元素用一个内积加一次除法得到,对角元直接取倒数,非对角元从下往上回代,总代价 O(n3/6)O(n^3/6) O(n3/6)。
  3. 真实代码不是公式的直接翻译 : 把求逆重排成"对角元取倒数 + 列缩放 + rank-1 更新"的循环结构,每一步都落在 BLAS 的 dscal/daxpy 上,吃满向量化加速。
  4. 求逆结果只在上三角,必须 symmetrize :函数只算"半个逆"------上三角是 A−1A^{-1} A−1 的上三角,下三角还是原始数据。调用方要自己镜像填充,这是 LINPACK 风格的"半成品约定",让"需要完整矩阵"和"只需要上三角"的两种调用方各取所需。
  5. 行列式是求逆的副产品 det⁡(A)=∏ Rii2 \det(A) = \prod R_{ii}^2 det(A)=∏Rii2,几乎零成本(顺着对角线读一遍),还用手工科学计数法防溢出。job = 11 一次拿到逆和行列式,是多元统计的"组合套餐"。

如果只记一句话,那就是:通用的求逆算法把矩阵当成一坨无差别的数;Cholesky 求逆尊重矩阵的结构,用结构信息换计算量和精度。 工程上凡是对称正定,第一选择永远是 Cholesky,没有例外。

下一篇,我们会拆 dtrsl------三角方程组求解,也就是 Rx=bRx = b Rx=b 怎么解。那是个看似平凡、但藏着"前向代入 vs 后向代入"和"右端向量批处理"两个工程细节的故事。同样是教科书一句带过、生产代码精妙绝伦。

相关推荐
To_OC12 小时前
LC 994 腐烂的橘子:人人都说是 BFS 入门题,我却写了三遍才过
javascript·算法·leetcode
金銀銅鐵16 小时前
[Python] 扩展欧几里得算法
python·数学·算法
To_OC18 小时前
LC 200 岛屿数量:经典 DFS 入门题,我第一次写居然连方向都搞错了
javascript·算法·leetcode
To_OC1 天前
LC 128 最长连续序列:别上来就排序,O (n) 解法才是这题的灵魂
javascript·算法·leetcode
05Kevin2 天前
lk每日冒险题--数据结构6.27
算法
To_OC2 天前
从一次栈溢出报错说起,我把递归彻底扒明白了
javascript·算法·程序员
千纸鹤安安3 天前
千问Qwen-AgentWorld来了:一个语言模型搞定七大Agent场景,GPT-5.4都输了
算法