逻辑回归学习笔记-数学直接解回归方程

逻辑回归(Linear Regression)学习笔记

一、模型公式

y^=w0+∑i=1nwixi\hat{y} = w_0 + \sum_{i=1}^{n} w_i x_iy^=w0+i=1∑nwixi

  • 时间复杂度: O(n) --- 只需遍历一次所有特征
  • w0w_0w0: 偏置项(截距)
  • wiw_iwi: 各特征的权重参数

核心问题:给定一堆 (x,y)(x, y)(x,y) 样本,如何求解 w0w_0w0 和 wiw_iwi?


二、损失函数选择

损失函数衡量模型预测值与真实值的差距,选择合适的损失函数至关重要。

2.1 四大类损失函数

任务类型 常用损失函数 适用场景
回归 (连续值预测) MSE、MAE、Huber Loss 房价、销量、温度预测
分类 (离散标签) Binary CrossEntropy、Categorical CrossEntropy、Focal Loss 二分类、多分类、样本不均衡
排序 (推荐/搜索) BPR Loss、RankNet Loss 推荐系统、个性化排序
自监督/对比学习 InfoNCE、Triplet Loss 特征提取、相似度学习

2.2 快速选择指南

三问法则

  1. 回归还是分类?

    • 回归 → MSE(默认)/ MAE(有异常值)/ Huber(两者结合)
    • 分类 → CrossEntropy(标准)/ Focal(样本不均衡)
  2. 数据是否均衡?

    • 均衡 → 标准损失
    • 不均衡 → 加权损失 / Focal Loss / Dice Loss
  3. 是否需要概率输出?

    • 需要概率 → CrossEntropy
    • 只需分数 → MSE / MAE

2.3 损失函数对比

损失函数 优点 缺点 典型场景
MSE 梯度平滑、易优化 对异常点敏感 房价预测
MAE 抗异常值强 零点梯度不平滑 噪声数据
BCE 输出概率、稳定 样本不均衡效果差 CTR预估
Focal 专注难样本 调参略麻烦 检测、推荐

注意: MSE(均方误差) ≠ 方差。方差描述数据波动,MSE 描述模型误差。


三、参数求解方法

方法一:正规方程(数学直接解)

适用场景:数据量小、特征少。

推导过程

将 y=w0+w1xy = w_0 + w_1 xy=w0+w1x 改写为矩阵形式:

y=XWy = XWy=XW

其中:
X=[1x11x2⋮⋮],W=[w0w1]X = \begin{bmatrix} 1 & x_1 \\ 1 & x_2 \\ \vdots & \vdots \end{bmatrix}, \quad W = \begin{bmatrix} w_0 \\ w_1 \end{bmatrix}X= 11⋮x1x2⋮ ,W=[w0w1]

求解公式:

W=(XTX)−1XTyW = (X^T X)^{-1} X^T yW=(XTX)−1XTy

前提条件 : XXX 必须是列满秩矩阵(即 rank(X)=d\text{rank}(X) = drank(X)=d,特征之间线性无关)。

手工计算示例

样本数据:

i x y
1 1 2
2 2 4
3 3 6

X=[111213],y=[246]X = \begin{bmatrix} 1 & 1 \\ 1 & 2 \\ 1 & 3 \end{bmatrix}, \quad y = \begin{bmatrix} 2 \\ 4 \\ 6 \end{bmatrix}X= 111123 ,y= 246

计算步骤:

  1. XTX=[36614]X^T X = \begin{bmatrix} 3 & 6 \\ 6 & 14 \end{bmatrix}XTX=[36614]
  2. (XTX)−1=[14/6−1−10.5](X^T X)^{-1} = \begin{bmatrix} 14/6 & -1 \\ -1 & 0.5 \end{bmatrix}(XTX)−1=[14/6−1−10.5]
  3. XTy=[1228]X^T y = \begin{bmatrix} 12 \\ 28 \end{bmatrix}XTy=[1228]
  4. W=[02]W = \begin{bmatrix} 0 \\ 2 \end{bmatrix}W=[02]

结果:w0=0w_0 = 0w0=0,w1=2w_1 = 2w1=2,即 y=2xy = 2xy=2x(完美拟合)

Python 实现

python 复制代码
import numpy as np

X = np.array([[1, 1], [1, 2], [1, 3]])  # 第一列全1(偏置项)
y = np.array([2, 4, 6])

# 正规方程求解
W = np.linalg.inv(X.T @ X) @ X.T @ y
print(f"最优参数:w0={W[0]}, w1={W[1]}")
print(f"MSE:{np.mean((y - X @ W) ** 2)}")

四、两种方法对比

方法 正规方程 梯度下降
计算方式 一次性求解 多次迭代
适用规模 小规模数据 大规模数据
矩阵要求 必须可逆 无要求
时间复杂度 O(n的3次方)O(n的3次方)O(n的3次方) O(kn)O(kn)O(kn),k为迭代次数
特征数限制 不适合高维 适合高维

五、关键知识点补充

矩阵乘法口诀

前列后行要相等,前行后列定形状。

(m×n)×(n×p)=m×p(m \times n) \times (n \times p) = m \times p(m×n)×(n×p)=m×p

列满秩判断

rank(X)=d\text{rank}(X) = drank(X)=d(特征数)→ 列满秩 → XTXX^T XXTX 可逆


相关推荐
weixin_qq_163951362 小时前
hypermill五轴后处理制作需要学习哪些知识点
学习·ug
wsjsf3 小时前
智能代码审查助手的搭建
java·学习·ai编程
xuhaoyu_cpp_java3 小时前
MyBatis学习(二)
java·经验分享·笔记·学习·mybatis
我是发哥哈3 小时前
主流AI视频生成方案商用化能力横向评测
大数据·人工智能·学习·机器学习·chatgpt·音视频
楼田莉子3 小时前
CMake学习:CMake语法
c++·后端·学习·软件构建
nashane3 小时前
HarmonyOS 6学习:加密一致性与安全存储——AES GCM排查与SaveButton实践
学习·安全·harmonyos·harmony app
周末也要写八哥3 小时前
编程初学者学习:句柄(二)
学习
我登哥MVP4 小时前
【SpringMVC笔记】 - 11 - SpringMVC 执行流程
java·spring boot·笔记·spring·tomcat·intellij-idea
道长爱睡懒觉4 小时前
蓝牙,导航,仪表,TBOX,OTA
笔记