最小二乘求解器lstsq,处理带权重和L2正则的线性回归

目录

代码注释版:

关键功能说明:

[torch.linalg.cholesky 的原理](#torch.linalg.cholesky 的原理)

代码示例

[Cholesky 分解的应用](#Cholesky 分解的应用)

[与 torch.cholesky 的区别](#与 torch.cholesky 的区别)

总结


代码注释版:

python 复制代码
from typing import Optional

import torch


def lstsq(
    matrix: torch.Tensor, 
    rhs: torch.Tensor, 
    weights: torch.Tensor, 
    l2_regularizer: Optional[torch.Tensor] = None,
    l2_regularizer_rhs: Optional[torch.Tensor] = None,
    shared: bool = False
) -> torch.Tensor:
    """带权重和L2正则化的最小二乘求解器,使用Cholesky分解
    
    解决形如 (A^T W A + λI) x = A^T W b 的线性系统
    支持多任务共享参数(通过shared参数合并Gram矩阵和右侧项)
    
    Args:
        matrix: 设计矩阵A,形状为 [batch_size, n_obs, n_params]
        rhs: 右侧项b,形状为 [batch_size, n_obs, n_outputs]
        weights: 权重矩阵W的对角元素,形状为 [batch_size, n_obs]
        l2_regularizer: L2正则化项λ的对角矩阵,形状为 [batch_size, n_params, n_params]
        l2_regularizer_rhs: 正则化项对右侧的修正,形状为 [batch_size, n_params, n_outputs]
        shared: 是否共享参数(将多个系统的Gram矩阵和右侧项求和)
    
    Returns:
        最小二乘解,形状为 [batch_size, n_params, n_outputs]
    """
    # 加权设计矩阵: W^(1/2) * A
    weighted_matrix = weights.unsqueeze(-1) * matrix
    
    # 计算正则化的Gram矩阵: A^T W A + λI
    regularized_gramian = weighted_matrix.mT @ matrix
    if l2_regularizer is not None:
        regularized_gramian += l2_regularizer  # 添加L2正则项
    
    # 计算右侧项: A^T W b + λ_rhs
    ATb = weighted_matrix.mT @ rhs
    if l2_regularizer_rhs is not None:
        ATb += l2_regularizer_rhs
    
    # 如果共享参数,合并所有batch的贡献
    if shared:
        regularized_gramian = regularized_gramian.sum(dim=0, keepdim=True)
        ATb = ATb.sum(dim=0, keepdim=True)
    
    # Cholesky分解求解
    chol = torch.linalg.cholesky(regularized_gramian)
    return torch.cholesky_solve(ATb, chol)


def lstsq_partial_share(
    matrix: torch.Tensor,
    rhs: torch.Tensor,
    weights: torch.Tensor,
    l2_regularizer: torch.Tensor,
    n_shared: int = 0
) -> torch.Tensor:
    """部分参数共享的最小二乘求解器
    
    将参数分为共享部分和独立部分:
    - 共享参数在所有样本间共享
    - 独立参数每个样本单独估计
    通过分块回归实现高效求解
    
    Args:
        matrix: 设计矩阵A,形状为 [batch_size, n_obs, n_params]
        rhs: 右侧项b,形状为 [batch_size, n_obs, n_outputs]
        weights: 权重矩阵的对角元素,形状为 [batch_size, n_obs]
        l2_regularizer: 正则化强度,形状为 [batch_size, n_params]
        n_shared: 共享参数的数量
    
    Returns:
        参数矩阵,前n_shared列为共享参数,其余为独立参数
        形状为 [batch_size, n_params, n_outputs]
    """
    n_params = matrix.shape[-1]
    n_rhs_outputs = rhs.shape[-1]
    n_indep = n_params - n_shared

    # 全共享情况直接返回广播结果
    if n_indep == 0:
        result = lstsq(matrix, rhs, weights, l2_regularizer, shared=True)
        return result.expand(matrix.shape[0], -1, -1)

    # 将正则化项转换为设计矩阵的扩展部分
    # 相当于添加 λI 的正则化项
    matrix = torch.cat([matrix, batch_eye(n_params, matrix.shape[0])], dim=1)
    rhs = torch.nn.functional.pad(rhs, (0, 0, 0, n_params))  # 右侧添加0
    weights = torch.cat([weights, l2_regularizer.unsqueeze(0).expand(matrix.shape[0], -1)], dim=1)

    # 分割共享和独立参数对应的设计矩阵
    matrix_shared, matrix_indep = torch.split(matrix, [n_shared, n_indep], dim=-1)

    # 步骤1:求解独立参数对共享参数和输出的影响
    indep_coeffs = lstsq(matrix_indep, torch.cat([matrix_shared, rhs], dim=-1), weights)
    coeff_indep2shared, coeff_indep2rhs = torch.split(indep_coeffs, [n_shared, n_rhs_outputs], dim=-1)

    # 步骤2:用残差求解共享参数
    shared_residual = matrix_shared - matrix_indep @ coeff_indep2shared
    rhs_residual = rhs - matrix_indep @ coeff_indep2rhs
    coeff_shared2rhs = lstsq(shared_residual, rhs_residual, weights, shared=True)

    # 步骤3:更新独立参数系数
    coeff_indep2rhs = coeff_indep2rhs - coeff_indep2shared @ coeff_shared2rhs

    # 合并结果:共享参数广播,独立参数保持独立
    coeff_shared2rhs = coeff_shared2rhs.expand(matrix.shape[0], -1, -1)
    return torch.cat([coeff_shared2rhs, coeff_indep2rhs], dim=1)


def batch_eye(n_params: int, batch_size: int) -> torch.Tensor:
    """生成批次对角矩阵
    
    Args:
        n_params: 矩阵维度
        batch_size: 批次大小
    
    Returns:
        形状为 [batch_size, n_params, n_params] 的单位矩阵批次
    """
    return torch.eye(n_params).reshape(1, n_params, n_params).expand(batch_size, -1, -1)

关键功能说明:

  1. lstsq:

    • 核心最小二乘求解器,处理带权重和L2正则的线性回归

    • 使用Cholesky分解提高数值稳定性

    • 支持多任务参数共享模式(shared=True时合并所有任务的贡献)

  2. lstsq_partial_share:

    • 处理部分参数共享的回归问题

    • 通过三步分块回归实现:

      1. 估计独立参数对共享参数和输出的影响

      2. 用残差估计共享参数

      3. 修正独立参数估计值

    • 通过矩阵拼接技巧将正则化转换为设计矩阵扩展

  3. batch_eye:

    • 生成批次单位矩阵,用于构建正则化项

    • 典型应用:将L2正则转换为扩展设计矩阵的伪观测

torch.linalg.cholesky 的原理

torch.linalg.cholesky(A) 用于对对称正定矩阵 AAA 进行 Cholesky 分解,即将其分解为:

A=LLTA = L L^TA=LLT

其中:

  • AAA 是 对称正定矩阵(必须满足 A=ATA = A^TA=AT 且所有特征值大于 0)。

  • LLL 是 下三角矩阵

计算 Cholesky 分解 的方式基于逐行计算 LLL:

  1. 计算对角元素:

    Lii=Aii−∑k=1i−1Lik2L_{ii} = \sqrt{ A_{ii} - \sum_{k=1}^{i-1} L_{ik}^2 }Lii=Aii−k=1∑i−1Lik2

  2. 计算非对角元素:

    Lji=1Lii(Aji−∑k=1i−1LjkLik),j>iL_{ji} = \frac{1}{L_{ii}} \left( A_{ji} - \sum_{k=1}^{i-1} L_{jk} L_{ik} \right), \quad j > iLji=Lii1(Aji−k=1∑i−1LjkLik),j>i

这个算法 只需要计算下三角部分 ,所以比 LU 分解 计算量更少,适用于 正定矩阵的快速求解


代码示例

python 复制代码
import torch

# 生成一个对称正定矩阵
A = torch.tensor([[4.0, 12.0, -16.0], 
                  [12.0, 37.0, -43.0], 
                  [-16.0, -43.0, 98.0]])

# Cholesky 分解
L = torch.linalg.cholesky(A)
print(L)

输出

tensor([[ 2.0000, 0.0000, 0.0000],

[ 6.0000, 1.0000, 0.0000],

[-8.0000, 5.0000, 3.0000]])

可以验证:

python 复制代码
print(torch.mm(L, L.T))

# 结果应当等于 A

Cholesky 分解的应用

  1. 解线性方程组 Ax=bAx = bAx=b:

    • 先求 L = torch.linalg.cholesky(A)

    • Ly = b(前代法)

    • L^T x = y(后代法)

  2. 生成多元正态分布

    • 如果协方差矩阵 Σ\SigmaΣ 进行 Cholesky 分解 Σ=LLT\Sigma = L L^TΣ=LLT,

    • 则可以用 L @ torch.randn(n, d) 生成符合协方差 Σ\SigmaΣ 的多元正态分布数据。


torch.cholesky 的区别

  • torch.cholesky(A) 旧版 API,不推荐使用。

  • torch.linalg.cholesky(A) 现代 API,支持 batch 计算,推荐使用。


总结

  • torch.linalg.cholesky(A) 计算 对称正定矩阵Cholesky 分解 ,分解成下三角矩阵 L,使得 A=LLTA = L L^TA=LLT。

  • 计算方式比 LU 分解更快,主要用于 正定矩阵的求解、统计学、多元正态分布 等。

  • 使用 Cholesky 分解求解线性方程组比直接求逆更稳定高效。

相关推荐
牧歌悠悠28 分钟前
【Python 算法】动态规划
python·算法·动态规划
Doris Liu.3 小时前
如何检测代码注入(Part 2)
windows·python·安全·网络安全·网络攻击模型
逢生博客3 小时前
阿里 FunASR 开源中文语音识别大模型应用示例(准确率比faster-whisper高)
人工智能·python·语音识别·funasr
噔噔噔噔@3 小时前
软件测试对于整个行业的重要性及必要性
python·单元测试·压力测试
赵谨言3 小时前
基于Python的Django框架的个人博客管理系统
经验分享·python·毕业设计
Guarding and trust3 小时前
python系统之综合案例:用python打造智能诗词生成助手
服务器·数据库·python
淮北4943 小时前
ros调试工具foxglove使用指南三:在3d空间写写画画(Panel->3D ->Scene entity)
python·学习·3d·机器人
mosquito_lover13 小时前
Python实现音频数字水印方法
python·音视频
苹果.Python.八宝粥3 小时前
Python第七章02:文件读取的练习
开发语言·python
Python之栈4 小时前
Python 3.13 正式支持 iOS:移动开发的新篇章
python·macos·objective-c·cocoa