一、算法核心思想
线性回归(Linear Regression) 是统计学和机器学习中最基础的预测模型,核心思想是通过线性方程描述自变量(X)与因变量(y)之间的关系:
y = β₀ + β₁X₁ + β₂X₂ + ... + βₚXₚ + ε
β₀:截距项(y轴交点)
β₁...βₚ:特征系数(斜率)
ε:误差项(随机噪声)
二、数学原理与损失函数
目标:找到最优系数β,使预测值ŷ与实际值y的误差最小化
损失函数(均方误差 MSE):
J(\beta) = \frac{1}{2m}\sum_{i=1}^{m}(y^{(i)} - \hat{y}^{(i)})^2
求解方法:
-
最小二乘法(OLS):
\hat{\beta} = (X^TX)^{-1}X^Ty
-
梯度下降(大数据集首选):
\beta_j := \beta_j - \alpha \frac{\partial J(\beta)}{\partial \beta_j}
三、Python实现(NumPy手写版)
import numpy as np
class LinearRegression:
def __init__(self, learning_rate=0.01, n_iters=1000):
self.lr = learning_rate
self.n_iters = n_iters
self.weights = None
self.bias = None
def fit(self, X, y):
# 初始化参数
n_samples, n_features = X.shape
self.weights = np.zeros(n_features)
self.bias = 0
# 梯度下降
for _ in range(self.n_iters):
y_pred = np.dot(X, self.weights) + self.bias
# 计算梯度
dw = (1/n_samples) * np.dot(X.T, (y_pred - y))
db = (1/n_samples) * np.sum(y_pred - y)
# 更新参数
self.weights -= self.lr * dw
self.bias -= self.lr * db
def predict(self, X):
return np.dot(X, self.weights) + self.bias
四、Scikit-Learn实战应用
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
# 1. 数据预处理
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 2. 多项式特征扩展(解决非线性问题)
poly = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly.fit_transform(X_scaled)
# 3. 创建模型
model = LinearRegression()
# 4. 训练模型
model.fit(X_poly, y)
# 5. 预测与评估
y_pred = model.predict(X_poly)
print(f"R² Score: {r2_score(y, y_pred):.3f}")
print(f"RMSE: {np.sqrt(mean_squared_error(y, y_pred)):.2f}")
五、关键诊断指标
| 指标 | 公式 | 说明 |
|----------|-----------------------------|---------------------|------|-------------|
| R² | 1 - SS_res/SS_tot
| 模型解释的方差比例,0~1越近1越好 |
| 调整R² | 1 - [(1-R²)(n-1)/(n-p-1)]
| 考虑特征数量的修正R² |
| MSE | Σ(y-ŷ)²/n
| 均方误差,越小越好 |
| MAE | `Σ | y-ŷ | /n` | 平均绝对误差,鲁棒性强 |
| 系数P值 | t检验统计量 | <0.05表示特征显著 |
六、模型诊断可视化(Matplotlib)
import matplotlib.pyplot as plt
from statsmodels.graphics.gofplots import qqplot
# 1. 残差分析
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.scatter(y_pred, y - y_pred)
plt.axhline(y=0, color='r', linestyle='--')
plt.title('Residuals vs Fitted')
# 2. Q-Q图(正态性检验)
plt.subplot(132)
qqplot(y - y_pred, line='s', ax=plt.gca())
plt.title('Normal Q-Q')
# 3. 系数重要性
plt.subplot(133)
coefs = pd.Series(model.coef_, index=poly.get_feature_names_out())
coefs.sort_values().plot.barh()
plt.title('Feature Coefficients')
plt.tight_layout()
七、常见问题与解决方案
问题 | 症状 | 解决方案 |
---|---|---|
多重共线性 | 系数符号反常/VIF>10 | 1. 特征选择 2. 岭回归(L2正则) |
异方差性 | 残差呈漏斗状 | 1. 目标变量变换 2. 加权最小二乘法 |
非线性关系 | 残差呈曲线模式 | 1. 多项式特征 2. 样条回归 |
异常值影响 | 个别点远离趋势线 | 1. 鲁棒回归 2. 删除/转换异常值 |
八、正则化进阶:岭回归 vs Lasso
# 岭回归 (L2正则)
from sklearn.linear_model import Ridge
ridge = Ridge(alpha=0.5).fit(X_poly, y)
# Lasso回归 (L1正则)
from sklearn.linear_model import Lasso
lasso = Lasso(alpha=0.1).fit(X_poly, y) # 可产生稀疏解
# 弹性网络 (L1+L2)
from sklearn.linear_model import ElasticNet
en = ElasticNet(alpha=0.1, l1_ratio=0.5).fit(X_poly, y)
九、统计学视角 vs 机器学习视角
维度 | 统计学视角 | 机器学习视角 |
---|---|---|
目标 | 解释变量关系 | 预测未来结果 |
重点 | 系数显著性检验 | 预测准确性 |
数据要求 | 严格假设检验 | 更多关注特征工程 |
模型选择 | 基于p值/信息准则 | 交叉验证 |
正则化 | 较少使用 | 标准配置 |
十、典型应用场景
-
房价预测:基于面积/位置/房龄等特征
-
销售预测:历史销量与市场指标的关系
-
金融风控:用户特征与违约概率的关联
-
医疗分析:临床指标与疾病风险的量化
-
工业控制:工艺参数与产品质量的映射
黄金法则:
始终检查线性假设(散点图+残差图)
分类变量必须独热编码(One-Hot Encoding)
数值特征需标准化(StandardScaler)
使用交叉验证防止过拟合
正则化是处理高维数据的必备工具
线性回归虽简单,却是理解更复杂模型的基石。掌握其数学原理、实现方法、诊断技巧和优化策略,将为你构建稳健的预测模型打下坚实基础。