逻辑回归 (Logistic Regression)
核心思想
逻辑回归用于二分类问题 ,输出一个概率值 p∈(0,1)p \in (0, 1)p∈(0,1),表示样本属于正类的概率。
核心公式
1. Sigmoid 函数
将线性组合映射到 (0,1)(0, 1)(0,1):
σ(z)=11+e−z\sigma(z) = \frac{1}{1 + e^{-z}}σ(z)=1+e−z1
2. 预测概率
y^=σ(wTx+b)=11+e−(wTx+b)\hat{y} = \sigma(w^T x + b) = \frac{1}{1 + e^{-(w^T x + b)}}y^=σ(wTx+b)=1+e−(wTx+b)1
- xxx:输入特征向量
- www:权重向量
- bbb:偏置项
- y^\hat{y}y^:预测为正类的概率
3. 决策规则
预测类别={1y^≥0.50y^<0.5 \text{预测类别} = \begin{cases} 1 & \hat{y} \geq 0.5 \\ 0 & \hat{y} < 0.5 \end{cases} 预测类别={10y^≥0.5y^<0.5
简单例子
问题:根据学习时长(小时)预测是否通过考试(1=通过,0=不通过)
| 学习时长 xxx | 是否通过 yyy |
|---|---|
| 1 | 0 |
| 2 | 0 |
| 4 | 1 |
| 5 | 1 |
假设训练后得到参数:w=2w = 2w=2,b=−6b = -6b=−6
预测学习 3 小时的结果:
z=2×3+(−6)=0z = 2 \times 3 + (-6) = 0z=2×3+(−6)=0
y^=11+e0=12=0.5\hat{y} = \frac{1}{1 + e^{0}} = \frac{1}{2} = 0.5y^=1+e01=21=0.5
边界情况,刚好 0.5,预测为通过。
预测学习 1 小时的结果:
z=2×1−6=−4z = 2 \times 1 - 6 = -4z=2×1−6=−4
y^=11+e4≈0.018\hat{y} = \frac{1}{1 + e^{4}} \approx 0.018y^=1+e41≈0.018
概率极低,预测为不通过。
训练过程
第一步:定义损失函数(对数损失)
对单个样本:
L(y^,y)=−[ylog(y^)+(1−y)log(1−y^)]L(\hat{y}, y) = -\left[ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \right]L(y^,y)=−[ylog(y^)+(1−y)log(1−y^)]
对整个数据集(nnn 个样本)取平均:
J(w,b)=−1n∑i=1n[y(i)log(y^(i))+(1−y(i))log(1−y^(i))]J(w, b) = -\frac{1}{n} \sum_{i=1}^{n} \left[ y^{(i)} \log(\hat{y}^{(i)}) + (1 - y^{(i)}) \log(1 - \hat{y}^{(i)}) \right]J(w,b)=−n1i=1∑n[y(i)log(y^(i))+(1−y(i))log(1−y^(i))]
为什么用对数损失而不用均方误差?
均方误差在 sigmoid 输出上会导致梯度消失,对数损失是凸函数,保证梯度下降收敛。
第二步:梯度下降更新参数
计算梯度:
∂J∂w=1n∑i=1n(y^(i)−y(i))x(i)\frac{\partial J}{\partial w} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}^{(i)} - y^{(i)}) x^{(i)}∂w∂J=n1i=1∑n(y^(i)−y(i))x(i)
∂J∂b=1n∑i=1n(y^(i)−y(i))\frac{\partial J}{\partial b} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}^{(i)} - y^{(i)})∂b∂J=n1i=1∑n(y^(i)−y(i))
更新规则(α\alphaα 为学习率):
w←w−α⋅∂J∂ww \leftarrow w - \alpha \cdot \frac{\partial J}{\partial w}w←w−α⋅∂w∂J
b←b−α⋅∂J∂bb \leftarrow b - \alpha \cdot \frac{\partial J}{\partial b}b←b−α⋅∂b∂J
第三步:迭代直到收敛
初始化 w = 0, b = 0
repeat:
计算所有样本的 ŷ
计算损失 J
计算梯度 ∂J/∂w, ∂J/∂b
更新 w, b
until 损失不再明显下降
整体流程图
输入特征 x
↓
线性计算: z = wᵀx + b
↓
Sigmoid: ŷ = 1/(1+e⁻ᶻ)
↓
对比标签 y → 计算损失 J
↓
反向传播 → 计算梯度
↓
梯度下降 → 更新 w, b
↓
重复迭代直到收敛
关键要点总结
| 要素 | 内容 |
|---|---|
| 输出 | 概率值 (0,1)(0, 1)(0,1) |
| 激活函数 | Sigmoid |
| 损失函数 | 对数损失(Binary Cross-Entropy) |
| 优化方法 | 梯度下降 |
| 适用场景 | 二分类(垃圾邮件、疾病诊断等) |