204_从回归到分类:Softmax 回归、损失函数与多分类实战

在处理分类任务时,我们不仅需要模型告诉我们"它是哪一类",还需要模型给出"它是每一类的概率有多大"。这就是 Softmax 回归大显身手的地方。

1. Softmax 回归:虽然叫回归,其实是分类

Softmax 回归通过将输出层映射到 (0, 1) 区间内,且保证所有类别的概率之和等于 1

核心公式:

对于输出向量
,Softmax 的计算方式为:

这意味着每个类别的得分都被转化成了概率,得分越高,概率越大。


2. 交叉熵损失函数 (Cross-Entropy Loss)

在分类问题中,我们不再使用均方误差(MSE),而是使用交叉熵。它专门用来衡量两个概率分布(预测分布与真实分布)之间的差异。

  • 逻辑:如果预测正确类的概率趋近于 1,损失就趋近于 0;如果预测概率趋近于 0,损失就会变得无穷大。

3. 代码实战:简洁实现 Fashion-MNIST 分类

文件展示了如何利用 PyTorch 的 nn.Sequential 快速搭建一个单层神经网络,并处理 28x28 像素的图像分类。

Python

复制代码
import torch
from torch import nn
from d2l import torch as d2l

# 1. 搭建网络
# nn.Flatten() 将 28*28 的矩阵展平为 784 维向量
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

# 2. 初始化参数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01) # 以正态分布初始化权重

net.apply(init_weights)

# 3. 定义损失函数与优化器
# 注意:PyTorch 的 CrossEntropyLoss 内部集成了 Softmax 计算,
# 因此网络最后一层不需要再手动加 Softmax 层,直接输出 logits 即可。
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

# 4. 训练模型 (使用 d2l 封装好的训练函数)
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

4. 关键细节解析

为什么使用 nn.Flatten()

输入的 Fashion-MNIST 图片是 的矩阵,但全连接层(Linear)只接收一维向量。Flatten 能够把二维矩阵"拉直",而不破坏数据。

数值稳定性问题

在手动实现 Softmax 时,指数运算

容易导致数值溢出(太大或太小)。PyTorch 的 nn.CrossEntropyLoss 在底层通过数学变形(Log-Sum-Exp 技巧)解决了这个问题,因此直接使用官方损失函数更加稳健


5. 总结:多分类任务的"标准姿势"

  1. 输入处理 :展平图像 (Flatten)。
  2. 输出设计:输出节点数等于类别总数。
  3. 损失选择 :必选 CrossEntropyLoss
  4. 评价标准:使用准确率(Accuracy)而非 Loss 值作为最终衡量模型好坏的指标。

💡 学习小结

Softmax 回归是理解更复杂的深度学习模型(如多层感知机 MLP、卷积神经网络 CNN)的必经之路。掌握了它,你就掌握了计算机"识物"的最基础逻辑。

相关推荐
人工智能AI技术2 小时前
字节开源 DeerFlow 2.0——登顶 GitHub Trending 1,让 AI 可做任何事情
人工智能
spider'2 小时前
系统的架构
人工智能
莱歌数字2 小时前
强化学习如何重构芯片热管理?
人工智能·重构·制造·cae·散热
光仔December2 小时前
【从0学习Spring AI Alibaba】2、Spring AI Alibaba版本选型及环境搭建
人工智能·大模型·saa·spring ai·ai alibaba
源码之家2 小时前
计算机毕业设计:基于Python的汽车数据可视化分析系统 Django框架 Scrapy爬虫 可视化 车辆 懂车帝大数据 数据分析 机器学习(建议收藏)✅
python·信息可视化·django·flask·汽车·课程设计·美食
我的xiaodoujiao2 小时前
API 接口自动化测试详细图文教程学习系列8--测试接口
python·学习·测试工具·pytest
凸头2 小时前
从“搜了就答”到“智能决策”:拥抱 RAG 2.0 时代的架构演进 ——Java 后端工程师视角下的 AI 应用工程化落地
java·人工智能·架构·rag
蓝色的杯子2 小时前
免费体验GPT5.4效果
python·chatgpt
逐渐会飞2 小时前
如何用python在word插入复选框
python·word