动手学深度学习(李沐)笔记:Softmax 回归简洁实现(PyTorch 版)

上一节我们已经把 Softmax 回归从零实现 跑通了:

手写了参数 (W,b)、手写了 softmax、手写了交叉熵、手写了 SGD。

这一节要做的事很简单:
把这些"手写部分"全部换成 PyTorch 官方封装。

你会发现,训练逻辑完全没变,只是代码更短、更稳、更适合工程开发。


1. 这节到底"简洁"在哪?

从零实现里你手写的是:

  • W, b

  • softmax

  • cross_entropy

  • sgd

而简洁实现里分别对应:

  • nn.Linear

  • nn.CrossEntropyLoss

  • torch.optim.SGD

也就是说:

从零实现 简洁实现
X @ W + b nn.Linear(784, 10)
softmax + cross_entropy nn.CrossEntropyLoss()
手写 sgd torch.optim.SGD()

2. 准备数据:Fashion-MNIST

和上一节一样,先加载图片分类数据集。

复制代码
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

batch_size = 256
trans = transforms.ToTensor()

train_dataset = datasets.FashionMNIST(
    root="./data", train=True, download=True, transform=trans
)
test_dataset = datasets.FashionMNIST(
    root="./data", train=False, download=True, transform=trans
)

train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

如果你在 Windows 上跑卡住,把 num_workers=2 改成 0


3. 定义模型:nn.Flatten + nn.Linear

Fashion-MNIST 每张图是 (1, 28, 28),而线性层需要二维输入 (batch, features)

所以先拉平,再做线性分类。

复制代码
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10)
)

这里的含义就是:

\\text{logits} = XW + b

输出形状是 (batch, 10),表示 10 个类别的分数。


4. 参数初始化

李沐常用的初始化方式是:权重小随机数、偏置为 0。

复制代码
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
        nn.init.zeros_(m.bias)

net.apply(init_weights)

5. 定义损失函数:CrossEntropyLoss

这是这一节最核心的地方。

复制代码
loss = nn.CrossEntropyLoss()

为什么不用手写 softmax?

因为 CrossEntropyLoss 内部已经做了

  • log_softmax

  • 负对数似然 NLL

也就是它等价于:

\\text{CrossEntropy} = -\\log(\\text{softmax}(\\text{logits})_y)

而且内部做了数值稳定处理,比你手写 softmax + log 更安全。

这里一定记住一句话

CrossEntropyLoss 吃的是 logits,不是 softmax 后的概率。

也就是说下面这样是对的:

复制代码
logits = net(X)
l = loss(logits, y)

而下面这样是错的:

复制代码
probs = torch.softmax(net(X), dim=1)
l = loss(probs, y)   # 错

6. 定义优化器:SGD

复制代码
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

这一步就替代了你从零实现里的手写 sgd([W,b], lr, batch_size)


7. 计算准确率函数

分类任务不能只看 loss,还得看 accuracy。

复制代码
def accuracy(y_hat, y):
    pred = y_hat.argmax(dim=1)
    return (pred == y).float().mean().item()

测试集准确率:

复制代码
def evaluate_accuracy(data_iter, net):
    net.eval()
    total_acc, total_num = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            y_hat = net(X)
            total_acc += (y_hat.argmax(dim=1) == y).float().sum().item()
            total_num += y.numel()
    net.train()
    return total_acc / total_num

8. 训练循环:PyTorch 工程标准写法

复制代码
num_epochs = 10

for epoch in range(num_epochs):
    total_loss, total_correct, total_num = 0.0, 0.0, 0

    for X, y in train_iter:
        logits = net(X)
        l = loss(logits, y)

        trainer.zero_grad()
        l.backward()
        trainer.step()

        total_loss += l.item() * y.numel()
        total_correct += (logits.argmax(dim=1) == y).float().sum().item()
        total_num += y.numel()

    train_loss = total_loss / total_num
    train_acc = total_correct / total_num
    test_acc = evaluate_accuracy(test_iter, net)

    print(f"epoch {epoch+1}: loss {train_loss:.4f}, train acc {train_acc:.4f}, test acc {test_acc:.4f}")

正常情况下,Softmax 回归在 Fashion-MNIST 上的测试准确率大概在 0.80 ~ 0.84 左右。


9. 完整代码整理版(可以直接运行)

复制代码
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. 数据
batch_size = 256
trans = transforms.ToTensor()

train_dataset = datasets.FashionMNIST(
    root="./data", train=True, download=True, transform=trans
)
test_dataset = datasets.FashionMNIST(
    root="./data", train=False, download=True, transform=trans
)

train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# 2. 模型
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10)
)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
        nn.init.zeros_(m.bias)

net.apply(init_weights)

# 3. 损失函数与优化器
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

# 4. 评估函数
def evaluate_accuracy(data_iter, net):
    net.eval()
    total_correct, total_num = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            y_hat = net(X)
            total_correct += (y_hat.argmax(dim=1) == y).float().sum().item()
            total_num += y.numel()
    net.train()
    return total_correct / total_num

# 5. 训练
num_epochs = 10

for epoch in range(num_epochs):
    total_loss, total_correct, total_num = 0.0, 0.0, 0

    for X, y in train_iter:
        logits = net(X)
        l = loss(logits, y)

        trainer.zero_grad()
        l.backward()
        trainer.step()

        total_loss += l.item() * y.numel()
        total_correct += (logits.argmax(dim=1) == y).float().sum().item()
        total_num += y.numel()

    train_loss = total_loss / total_num
    train_acc = total_correct / total_num
    test_acc = evaluate_accuracy(test_iter, net)

    print(f"epoch {epoch+1}: loss {train_loss:.4f}, train acc {train_acc:.4f}, test acc {test_acc:.4f}")

10. 从零实现 vs 简洁实现:你现在应该能看穿什么?

你不能只会"调用 API",还要知道 API 背后做了什么。

从零实现时你做的是:

  • X.reshape(-1, 784)

  • X @ W + b

  • softmax

  • cross_entropy

  • sgd

简洁实现时 PyTorch 帮你做的是:

  • nn.Flatten():拉平图片

  • nn.Linear(784,10):参数矩阵和偏置

  • nn.CrossEntropyLoss():稳定版 softmax + 交叉熵

  • optim.SGD():自动更新参数

所以你现在看简洁代码,不该觉得它是"黑盒",而应该知道它只是把你手写过的那些步骤封装起来了。


11. 这一节最常踩的坑

1)把 softmax 后的结果喂给 CrossEntropyLoss

错。

✅ 正确:

复制代码
logits = net(X)
l = loss(logits, y)

2)标签 y 的形状和类型不对

CrossEntropyLoss 需要:

  • y.shape = (batch,)

  • y.dtype = torch.long

不能是 one-hot,也不能是 float。


3)忘记 zero_grad()

梯度默认累加,如果你不清零,训练会越来越离谱。

复制代码
trainer.zero_grad()

4)训练时只看 loss,不看 acc

分类任务必须看准确率,不然 loss 降了你也不知道模型到底分对了多少。


12. 小结

这一节你真正掌握的不是"会写几行 PyTorch",而是:

  • 知道 Softmax 回归就是一个线性分类器

  • 知道 CrossEntropyLoss 为什么直接吃 logits

  • 知道简洁实现和从零实现的逐一对应关系

  • 能用 PyTorch 标准写法完成一个完整的图像分类训练循环


相关推荐
低调小一2 小时前
OpenClaw 模型配置与火山 Coding Plan 支持清单(实践笔记)
java·前端·笔记·openclaw
renhongxia12 小时前
面向开放世界的具身智能泛化能力探索
人工智能·深度学习·机器学习·架构·transformer
陈辛chenxin2 小时前
【零基础学Web-Day1】HTML 基础标签 + CSS 样式规范,附实战案例
css·经验分享·笔记·html·课程设计
Naisu Xu2 小时前
数学笔记:最小二乘法(直线拟合)
笔记·算法·最小二乘法
飞Link2 小时前
进阶时序建模:门控递归单元 (GRU) 深度解析与实战
开发语言·人工智能·rnn·深度学习·gru
猹叉叉(学习版)2 小时前
【ASP.NET CORE】 7. Identity标识框架
笔记·后端·c#·asp.net·.netcore
智者知已应修善业2 小时前
【输入矩阵将其按副对角线交换后输出】2024-11-27
c语言·c++·经验分享·笔记·线性代数·算法·矩阵
大江东去浪淘尽千古风流人物2 小时前
【claw】 OpenClaw 的架构设计探索
深度学习·算法·3d·机器人·slam
xx24062 小时前
RN学习笔记
笔记·学习