204_从回归到分类:Softmax 回归、损失函数与多分类实战

在处理分类任务时,我们不仅需要模型告诉我们"它是哪一类",还需要模型给出"它是每一类的概率有多大"。这就是 Softmax 回归大显身手的地方。

1. Softmax 回归:虽然叫回归,其实是分类

Softmax 回归通过将输出层映射到 (0, 1) 区间内,且保证所有类别的概率之和等于 1

核心公式:

对于输出向量
,Softmax 的计算方式为:

这意味着每个类别的得分都被转化成了概率,得分越高,概率越大。


2. 交叉熵损失函数 (Cross-Entropy Loss)

在分类问题中,我们不再使用均方误差(MSE),而是使用交叉熵。它专门用来衡量两个概率分布(预测分布与真实分布)之间的差异。

  • 逻辑:如果预测正确类的概率趋近于 1,损失就趋近于 0;如果预测概率趋近于 0,损失就会变得无穷大。

3. 代码实战:简洁实现 Fashion-MNIST 分类

文件展示了如何利用 PyTorch 的 nn.Sequential 快速搭建一个单层神经网络,并处理 28x28 像素的图像分类。

Python

复制代码
import torch
from torch import nn
from d2l import torch as d2l

# 1. 搭建网络
# nn.Flatten() 将 28*28 的矩阵展平为 784 维向量
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

# 2. 初始化参数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01) # 以正态分布初始化权重

net.apply(init_weights)

# 3. 定义损失函数与优化器
# 注意:PyTorch 的 CrossEntropyLoss 内部集成了 Softmax 计算,
# 因此网络最后一层不需要再手动加 Softmax 层,直接输出 logits 即可。
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

# 4. 训练模型 (使用 d2l 封装好的训练函数)
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

4. 关键细节解析

为什么使用 nn.Flatten()

输入的 Fashion-MNIST 图片是 的矩阵,但全连接层(Linear)只接收一维向量。Flatten 能够把二维矩阵"拉直",而不破坏数据。

数值稳定性问题

在手动实现 Softmax 时,指数运算

容易导致数值溢出(太大或太小)。PyTorch 的 nn.CrossEntropyLoss 在底层通过数学变形(Log-Sum-Exp 技巧)解决了这个问题,因此直接使用官方损失函数更加稳健


5. 总结:多分类任务的"标准姿势"

  1. 输入处理 :展平图像 (Flatten)。
  2. 输出设计:输出节点数等于类别总数。
  3. 损失选择 :必选 CrossEntropyLoss
  4. 评价标准:使用准确率(Accuracy)而非 Loss 值作为最终衡量模型好坏的指标。

💡 学习小结

Softmax 回归是理解更复杂的深度学习模型(如多层感知机 MLP、卷积神经网络 CNN)的必经之路。掌握了它,你就掌握了计算机"识物"的最基础逻辑。

相关推荐
是大强1 分钟前
NCNN简介
人工智能
数字游民95273 分钟前
gpt image 2怎么用?3个案例+使用方法
人工智能·ai·数字游民9527
Dxy12393102166 分钟前
Python 如何使用 XPath 定位元素:从入门到实战
python
minhuan9 分钟前
大模型反向优化传统算法:用大模型学习传统算法的缺陷,反向迭代算法逻辑.152
人工智能·大模型算法应用·大模型反向优化传统算法·算法优化方案
新缸中之脑17 分钟前
用Remotion构建AI生成视频
人工智能·音视频
belldeep18 分钟前
Blender + AI 全套工作流
人工智能·ai·blender
何陋轩19 分钟前
【重磅】悟空来了:国产AI编程助手深度测评,能否吊打Copilot?
人工智能·算法·面试
AI医影跨模态组学20 分钟前
如何将深度学习MRI表型与iCCA淋巴结转移的生物学机制(KRAS突变、MUC5AC、免疫抑制微环境、大导管亚型)关联,并解释其对治疗响应的意义
人工智能·深度学习·机器学习·论文·医学·医学影像
weixin_4249993622 分钟前
mysql行级锁失效的原因排查_检查查询条件与执行计划
jvm·数据库·python
GreenTea24 分钟前
DeepSeek-V4 技术报告深度分析:基础研究创新全景
前端·人工智能·后端