线性回归算法详解
🧠 算法思想
线性回归 是统计学和机器学习中最基础的预测建模技术之一,其核心思想是通过建立自变量(特征)与因变量(目标)之间的线性关系,来预测或解释因变量的变化。线性回归模型假设因变量是自变量的线性组合,再加上一个误差项。
数学表达式
线性回归模型的一般形式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Y = β 0 + β 1 X 1 + β 2 X 2 + ⋯ + β n X n + ϵ Y = \beta_0 + \beta_1 X_1 + \beta_2 X_2 + \dots + \beta_n X_n + \epsilon </math>Y=β0+β1X1+β2X2+⋯+βnXn+ϵ
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Y Y </math>Y 是因变量(目标值)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> X 1 , X 2 , ... , X n X_1, X_2, \dots, X_n </math>X1,X2,...,Xn 是自变量(特征)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> β 0 , β 1 , ... , β n \beta_0, \beta_1, \dots, \beta_n </math>β0,β1,...,βn 是模型参数(系数)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 是误差项(无法通过自变量解释的部分)
目标
线性回归的目标是通过数据估计参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β,使得模型能够最小化预测值与实际值之间的误差。最常用的方法是 最小二乘法(Ordinary Least Squares, OLS),即最小化残差平方和:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Loss = ∑ i = 1 m ( y ( i ) − ( β 0 + β 1 x 1 ( i ) + ⋯ + β n x n ( i ) ) ) 2 \text{Loss} = \sum_{i=1}^{m} (y^{(i)} - (\beta_0 + \beta_1 x_1^{(i)} + \dots + \beta_n x_n^{(i)}))^2 </math>Loss=i=1∑m(y(i)−(β0+β1x1(i)+⋯+βnxn(i)))2
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> m m </math>m 是样本数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> x ( i ) x^{(i)} </math>x(i) 是第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个样本的特征向量, <math xmlns="http://www.w3.org/1998/Math/MathML"> y ( i ) y^{(i)} </math>y(i) 是实际输出值。
🧮 数学原理:正规方程
核心公式
线性回归的闭式解(闭合解)通过 正规方程 直接求得最优参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β ^ = ( X T X ) − 1 X T y \hat{\beta} = (X^T X)^{-1} X^T y </math>β^=(XTX)−1XTy
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 是特征矩阵(形状为 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × f n \times f </math>n×f, <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 为样本数, <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 为特征数)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 是目标向量(形状为 <math xmlns="http://www.w3.org/1998/Math/MathML"> n × 1 n \times 1 </math>n×1)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ \hat{\beta} </math>β^ 是最优参数向量(形状为 <math xmlns="http://www.w3.org/1998/Math/MathML"> f × 1 f \times 1 </math>f×1)
该公式仅在 XᵀX
是满秩矩阵(即特征之间不存在完美的多重共线性)时才有效。如果 XᵀX
不可逆(奇异),通常意味着存在线性相关的特征或特征数量大于样本数量,此时需要使用岭回归等正则化方法或伪逆。
🛠️ 参数详解
在 scikit-learn
的 LinearRegression
中,核心参数如下:
参数名 | 说明 | 默认值/示例值 | 值的含义 |
---|---|---|---|
fit_intercept |
是否计算截距项 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 0 \beta_0 </math>β0。 | True |
- True :模型包含截距项(推荐) - False :模型不包含截距项 |
n_jobs |
并行计算使用的处理器数量。 | None |
- 1 :单线程 - -1 :使用所有处理器(推荐) |
⏱️ 时间复杂度分析
线性回归的计算复杂度主要取决于求解参数的方法(如最小二乘法或梯度下降)。以下是不同方法的复杂度分析:
1. 最小二乘法(Normal Equation)
- 训练时间复杂度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( f 2 n + f 3 ) O(f^2 n + f^3) </math>O(f2n+f3)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 是特征数, <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 是样本数。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> f 2 n f^2 n </math>f2n:矩阵乘法 <math xmlns="http://www.w3.org/1998/Math/MathML"> X T X X^T X </math>XTX 的复杂度。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> f 3 f^3 </math>f3:矩阵求逆 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( X T X ) − 1 (X^T X)^{-1} </math>(XTX)−1 的复杂度。
- 预测时间复杂度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( f ) O(f) </math>O(f)
- 每次预测只需计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> w T x + b w^T x + b </math>wTx+b,复杂度与特征数成正比。
✅ 示例代码
python
from sklearn.linear_model import LinearRegression
# 训练线性回归模型
model = LinearRegression( n_jobs=-1)
model.fit(X_train, y_train)
# 预测与评估
score = model.score(X_test, y_test)
print(f"模型 R² 分数: {score:.4f}")