上一节我们已经把 Softmax 回归从零实现 跑通了:
手写了参数 (W,b)、手写了 softmax、手写了交叉熵、手写了 SGD。
这一节要做的事很简单:
把这些"手写部分"全部换成 PyTorch 官方封装。
你会发现,训练逻辑完全没变,只是代码更短、更稳、更适合工程开发。
1. 这节到底"简洁"在哪?
从零实现里你手写的是:
-
W, b -
softmax -
cross_entropy -
sgd
而简洁实现里分别对应:
-
nn.Linear -
nn.CrossEntropyLoss -
torch.optim.SGD
也就是说:
| 从零实现 | 简洁实现 |
|---|---|
X @ W + b |
nn.Linear(784, 10) |
softmax + cross_entropy |
nn.CrossEntropyLoss() |
手写 sgd |
torch.optim.SGD() |
2. 准备数据:Fashion-MNIST
和上一节一样,先加载图片分类数据集。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
batch_size = 256
trans = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(
root="./data", train=True, download=True, transform=trans
)
test_dataset = datasets.FashionMNIST(
root="./data", train=False, download=True, transform=trans
)
train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
如果你在 Windows 上跑卡住,把 num_workers=2 改成 0。
3. 定义模型:nn.Flatten + nn.Linear
Fashion-MNIST 每张图是 (1, 28, 28),而线性层需要二维输入 (batch, features)。
所以先拉平,再做线性分类。
net = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 10)
)
这里的含义就是:
\\text{logits} = XW + b
输出形状是 (batch, 10),表示 10 个类别的分数。
4. 参数初始化
李沐常用的初始化方式是:权重小随机数、偏置为 0。
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.01)
nn.init.zeros_(m.bias)
net.apply(init_weights)
5. 定义损失函数:CrossEntropyLoss
这是这一节最核心的地方。
loss = nn.CrossEntropyLoss()
为什么不用手写 softmax?
因为 CrossEntropyLoss 内部已经做了:
-
log_softmax -
负对数似然 NLL
也就是它等价于:
\\text{CrossEntropy} = -\\log(\\text{softmax}(\\text{logits})_y)
而且内部做了数值稳定处理,比你手写 softmax + log 更安全。
这里一定记住一句话
CrossEntropyLoss 吃的是 logits,不是 softmax 后的概率。
也就是说下面这样是对的:
logits = net(X)
l = loss(logits, y)
而下面这样是错的:
probs = torch.softmax(net(X), dim=1)
l = loss(probs, y) # 错
6. 定义优化器:SGD
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
这一步就替代了你从零实现里的手写 sgd([W,b], lr, batch_size)。
7. 计算准确率函数
分类任务不能只看 loss,还得看 accuracy。
def accuracy(y_hat, y):
pred = y_hat.argmax(dim=1)
return (pred == y).float().mean().item()
测试集准确率:
def evaluate_accuracy(data_iter, net):
net.eval()
total_acc, total_num = 0.0, 0
with torch.no_grad():
for X, y in data_iter:
y_hat = net(X)
total_acc += (y_hat.argmax(dim=1) == y).float().sum().item()
total_num += y.numel()
net.train()
return total_acc / total_num
8. 训练循环:PyTorch 工程标准写法
num_epochs = 10
for epoch in range(num_epochs):
total_loss, total_correct, total_num = 0.0, 0.0, 0
for X, y in train_iter:
logits = net(X)
l = loss(logits, y)
trainer.zero_grad()
l.backward()
trainer.step()
total_loss += l.item() * y.numel()
total_correct += (logits.argmax(dim=1) == y).float().sum().item()
total_num += y.numel()
train_loss = total_loss / total_num
train_acc = total_correct / total_num
test_acc = evaluate_accuracy(test_iter, net)
print(f"epoch {epoch+1}: loss {train_loss:.4f}, train acc {train_acc:.4f}, test acc {test_acc:.4f}")
正常情况下,Softmax 回归在 Fashion-MNIST 上的测试准确率大概在 0.80 ~ 0.84 左右。
9. 完整代码整理版(可以直接运行)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. 数据
batch_size = 256
trans = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(
root="./data", train=True, download=True, transform=trans
)
test_dataset = datasets.FashionMNIST(
root="./data", train=False, download=True, transform=trans
)
train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# 2. 模型
net = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 10)
)
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.01)
nn.init.zeros_(m.bias)
net.apply(init_weights)
# 3. 损失函数与优化器
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
# 4. 评估函数
def evaluate_accuracy(data_iter, net):
net.eval()
total_correct, total_num = 0.0, 0
with torch.no_grad():
for X, y in data_iter:
y_hat = net(X)
total_correct += (y_hat.argmax(dim=1) == y).float().sum().item()
total_num += y.numel()
net.train()
return total_correct / total_num
# 5. 训练
num_epochs = 10
for epoch in range(num_epochs):
total_loss, total_correct, total_num = 0.0, 0.0, 0
for X, y in train_iter:
logits = net(X)
l = loss(logits, y)
trainer.zero_grad()
l.backward()
trainer.step()
total_loss += l.item() * y.numel()
total_correct += (logits.argmax(dim=1) == y).float().sum().item()
total_num += y.numel()
train_loss = total_loss / total_num
train_acc = total_correct / total_num
test_acc = evaluate_accuracy(test_iter, net)
print(f"epoch {epoch+1}: loss {train_loss:.4f}, train acc {train_acc:.4f}, test acc {test_acc:.4f}")
10. 从零实现 vs 简洁实现:你现在应该能看穿什么?
你不能只会"调用 API",还要知道 API 背后做了什么。
从零实现时你做的是:
-
X.reshape(-1, 784) -
X @ W + b -
softmax -
cross_entropy -
sgd
简洁实现时 PyTorch 帮你做的是:
-
nn.Flatten():拉平图片 -
nn.Linear(784,10):参数矩阵和偏置 -
nn.CrossEntropyLoss():稳定版 softmax + 交叉熵 -
optim.SGD():自动更新参数
所以你现在看简洁代码,不该觉得它是"黑盒",而应该知道它只是把你手写过的那些步骤封装起来了。
11. 这一节最常踩的坑
1)把 softmax 后的结果喂给 CrossEntropyLoss
错。
✅ 正确:
logits = net(X)
l = loss(logits, y)
2)标签 y 的形状和类型不对
CrossEntropyLoss 需要:
-
y.shape = (batch,) -
y.dtype = torch.long
不能是 one-hot,也不能是 float。
3)忘记 zero_grad()
梯度默认累加,如果你不清零,训练会越来越离谱。
trainer.zero_grad()
4)训练时只看 loss,不看 acc
分类任务必须看准确率,不然 loss 降了你也不知道模型到底分对了多少。
12. 小结
这一节你真正掌握的不是"会写几行 PyTorch",而是:
-
知道 Softmax 回归就是一个线性分类器
-
知道
CrossEntropyLoss为什么直接吃 logits -
知道简洁实现和从零实现的逐一对应关系
-
能用 PyTorch 标准写法完成一个完整的图像分类训练循环