204_从回归到分类:Softmax 回归、损失函数与多分类实战

在处理分类任务时,我们不仅需要模型告诉我们"它是哪一类",还需要模型给出"它是每一类的概率有多大"。这就是 Softmax 回归大显身手的地方。

1. Softmax 回归:虽然叫回归,其实是分类

Softmax 回归通过将输出层映射到 (0, 1) 区间内,且保证所有类别的概率之和等于 1

核心公式:

对于输出向量
,Softmax 的计算方式为:

这意味着每个类别的得分都被转化成了概率,得分越高,概率越大。


2. 交叉熵损失函数 (Cross-Entropy Loss)

在分类问题中,我们不再使用均方误差(MSE),而是使用交叉熵。它专门用来衡量两个概率分布(预测分布与真实分布)之间的差异。

  • 逻辑:如果预测正确类的概率趋近于 1,损失就趋近于 0;如果预测概率趋近于 0,损失就会变得无穷大。

3. 代码实战:简洁实现 Fashion-MNIST 分类

文件展示了如何利用 PyTorch 的 nn.Sequential 快速搭建一个单层神经网络,并处理 28x28 像素的图像分类。

Python

复制代码
import torch
from torch import nn
from d2l import torch as d2l

# 1. 搭建网络
# nn.Flatten() 将 28*28 的矩阵展平为 784 维向量
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

# 2. 初始化参数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01) # 以正态分布初始化权重

net.apply(init_weights)

# 3. 定义损失函数与优化器
# 注意:PyTorch 的 CrossEntropyLoss 内部集成了 Softmax 计算,
# 因此网络最后一层不需要再手动加 Softmax 层,直接输出 logits 即可。
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

# 4. 训练模型 (使用 d2l 封装好的训练函数)
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

4. 关键细节解析

为什么使用 nn.Flatten()

输入的 Fashion-MNIST 图片是 的矩阵,但全连接层(Linear)只接收一维向量。Flatten 能够把二维矩阵"拉直",而不破坏数据。

数值稳定性问题

在手动实现 Softmax 时,指数运算

容易导致数值溢出(太大或太小)。PyTorch 的 nn.CrossEntropyLoss 在底层通过数学变形(Log-Sum-Exp 技巧)解决了这个问题,因此直接使用官方损失函数更加稳健


5. 总结:多分类任务的"标准姿势"

  1. 输入处理 :展平图像 (Flatten)。
  2. 输出设计:输出节点数等于类别总数。
  3. 损失选择 :必选 CrossEntropyLoss
  4. 评价标准:使用准确率(Accuracy)而非 Loss 值作为最终衡量模型好坏的指标。

💡 学习小结

Softmax 回归是理解更复杂的深度学习模型(如多层感知机 MLP、卷积神经网络 CNN)的必经之路。掌握了它,你就掌握了计算机"识物"的最基础逻辑。

相关推荐
一只专注api接口开发的技术猿9 小时前
OpenClaw 对接淘宝商品 API,低成本实现全天候选品监控|附可运行 Python 实操代码
大数据·开发语言·数据库·python
weixin_446260859 小时前
ACTS:代理链式思考 Steering 用于高效且可控的 LLM 推理
人工智能
xingpanvip9 小时前
星盘接口开发文档:马盘次限盘接口指南
android·开发语言·python·php·lua
FBI HackerHarry浩9 小时前
第二阶段Day07【Python生成器、yield关键字、property、正则表达式】
开发语言·python·正则表达式
梦想不只是梦与想9 小时前
Python 中的 4 种作用域
python·作用域
动物园猫9 小时前
外墙裂缝目标检测数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·目标检测
阿里云大数据AI技术10 小时前
MaxFrame 智驾数据处理 Pipeline Skill 正式发布:一句话生成智驾视频处理作业
人工智能
神奇小汤圆10 小时前
Hermes Agent 响应速度优化实战:从 15 秒到 2.6 秒
人工智能
TheRouter10 小时前
LLM 流式输出工程实践:SSE、背压、断流重连与JSON 流解析的 6 个生产陷阱
人工智能·json
AI浩10 小时前
OpenCV 检测流程中损坏 JPEG 图片的定位与清理
人工智能·opencv·计算机视觉