Softmax 回归

softmax 回归是机器学习另外一个非常经典且重要的模型,是一个分类问题。
下面先解释一下分类和回归的区别:
简单来说,分类问题从回归的单输出变成了多输出,输出的个数等于类别的个数。
实际上,对于分类来说,我们不关心它们之间实际的值,我们关心的是:模型是否对正确类别的置信度特别的大
虽然上述没有要求 O i O_i Oi 是一个什么样的值,但是如果我们将值放在合适的区间,也会让后续的处理变得更加的简单,比如下面我们希望模型的输出是一个概率:
上述要是你使用了 o n e − h o t one-hot one−hot 编码的话,只有当 i = y i=y i=y时, y i = 1 y_i = 1 yi=1,否则就是0。
损失函数
损失函数是用来衡量预测值与真实值之间的区别,是机器学习里面一个非常重要的概念。
1. L2 Loss(均方损失)

蓝色的线表示 y = 0 y=0 y=0 时变换我的 预测值 y ′ y' y′ 所生成的函数,可以看出来是一个二次函数。绿色是一个似然函数,似然函数取得最大值表明取该参数模型最合理。橙色的表示的是损失函数的梯度,由于是一次函数,穿过原点。
由上述可以发现,当预测值与真实值距离比较远的时候,梯度比较的大,则对参数的更新是比较的多的,当越靠近原点的时候,梯度的绝对值就会越小,对参数的更新就会越来越小。但这可能并不是一件好事,因为在离原点越远的地方,我可能并不希望需要那么大的梯度来更新我的参数。因此也可以考虑下面的 L1 Loss
L1 Loss

当然也是可以提出新的损失函数来结合上述两种损失函数的好处。
上述损失函数定义的好处就是:当预测值与真实值差别比较大的时候,我可以以均匀的力度
往回拉。当两者越来越接近时,我可以使得拉的力度越来越小,从而不会出现数值上的问题。
图片分类数据集
下面使用 Fashion-MNIST 数据集,展示对数据集的一般操作:
首先导入所需的库:
python
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
# 使用svg来显示图片
接着我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。
python
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor() # 预处理,将图片转换成tensor
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
# transform=trans希望得到的是一个tensor而不是一张图片
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集 (train dataset)中的6000张图像和测试数据集 (test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。
每个输入图像的高度和宽度均为28像素。
数据集由灰度图像组成,其通道数为1。
为了简洁起见,将高度 h h h像素、宽度 w w w像素图像的形状记为 h × w h \times w h×w或( h h h, w w w)。
接着定义两个可视化数据集的函数
Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
以下函数用于在数字标签索引及其文本名称之间进行转换。
python
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
以下展示训练数据集中前几个样本的图像及其相应的标签。
为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。
在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size
。
通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。
python
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
timer = d2l.Timer() # 用来测试速度
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'

在模型训练之前,一般都是需要测试数据读取的速度,数据读取的速度需要比模型的训练速度更快才好。
基于上述内容,现在我们定义load_data_fashion_mnist
函数 ,用于获取和读取Fashion-MNIST数据集。这个函数返回训练集 和验证集的数据迭代器 。此外,这个函数还接受一个可选参数resize
,用来将图像大小调整为另一种形状。
python
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))

Softmax 回归从0开始实现
python
import torch
from IPython import display
from d2l import torch as d2l
batch_size = 256 # 每次随机读取256张图片
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size) # 前面实现过
由于图像是12828的,但是对于softmax来说,输入的需要是一个向量。(但是这种操作会损失很多空间信息,卷积部分解决。)因此我们将展平每个图像,把它们看作长度为784的向量。数据集有十个类别,因此网络输出维度就是10。
python
num_inputs = 784 # 将空间拉长,28*28拉成784的一个向量
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
# 行数为输入的个数,列数等于输出的个数
b = torch.zeros(num_outputs, requires_grad=True)
# 对每一个输出,都需要有一个偏移
下面定义 softmax 操作:
实现softmax由三个步骤组成:
- 对每个项求幂(使用
exp
); - 对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;
- 将每一行除以其规范化常数,确保结果的和为1。
表达式如下:
s o f t m a x ( X ) i j = exp ( X i j ) ∑ k exp ( X i k ) . \mathrm{softmax}(\mathbf{X}){ij} = \frac{\exp(\mathbf{X}{ij})}{\sum_k \exp(\mathbf{X}_{ik})}. softmax(X)ij=∑kexp(Xik)exp(Xij).分母或规范化常数,有时也称为配分函数 (其对数称为对数-配分函数)。该名称来自统计物理学中一个模拟粒子群分布的方程。
python
def softmax(X):
X_exp = torch.exp(X) # 对X中的每个元素作指数运算
partition = X_exp.sum(1, keepdim=True) # 按照每一行进行求和
return X_exp / partition # 这里应用了广播机制

定义softmax操作后,可以实现softmax回归模型 。
下面的代码定义了输入如何通过网络映射到输出。
注意,将数据传递到模型之前,我们使用reshape
函数将每张原始图像展平为向量。
python
def net(X):
return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
首先回顾一下交叉熵:
交叉熵采用真实标签的预测概率的负对数似然。这里我们不使用Python的for循环迭代预测(这往往是低效的),而是通过一个运算符选择所有元素。
下面,**创建一个数据样本y_hat
,其中包含2个样本在3个类别的预测概率,以及它们对应的标签y
。**有了y
,我们知道在第一个样本中,第一类是正确的预测;而在第二个样本中,第三类是正确的预测。然后(使用y
作为y_hat
中概率的索引),我们选择第一个样本中第一个类的概率和第二个样本中第三个类的概率。
python
y = torch.tensor([0, 2]) # 表示两个样本的真实标签分别为0、2
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
# 对第0个样本,拿出y[0]对应的那个元素,对第一个样本,拿出y[1]对应的那个元素
# [0, 1] 是一个索引列表,表示要选取 y_hat 中的第一行和第二行。

基于上述,我们下面来实现交叉熵损失函数:
python
# 了解交叉熵公式和代码上述原理,一行代码即可完成。
def cross_entropy(y_hat, y):
return - torch.log(y_hat[range(len(y_hat)), y])
cross_entropy(y_hat, y)

由于上述是分类问题,因此需要将预测类别与真实 y y y 元素进行比较:
python
def accuracy(y_hat, y): #@save
"""计算预测正确的数量"""
# 要是 y_hat 是一个二维矩阵且列数也大于1
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1) # 按每一行来存最大值的下标
cmp = y_hat.type(y.dtype) == y # 将 y_hat 转换为 y 的数据类型,然后作比较
return float(cmp.type(y.dtype).sum()) # 返回预测正确的样本数
我们将继续使用之前定义的变量y_hat
和y
分别作为预测的概率分布和标签。可以看到,第一个样本的预测类别是2(该行的最大元素为0.6,索引为2),这与实际标签0不一致。第二个样本的预测类别是2(该行的最大元素为0.5,索引为2),这与实际标签2一致。因此,这两个样本的分类精度率为0.5。
同样,对于任意数据迭代器data_iter
可访问的数据集,可以评估在任意模型net
的精度。
python
def evaluate_accuracy(net, data_iter): #@save
"""计算在指定数据集上模型的精度"""
if isinstance(net, torch.nn.Module):
net.eval() # 将模型设置为评估模式
metric = Accumulator(2) # 正确预测数、预测总数
with torch.no_grad():
for X, y in data_iter:
metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]
这里定义一个实用程序类Accumulator
,用于对多个变量进行累加。在上面的evaluate_accuracy
函数中,我们在(Accumulator
实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量)。当我们遍历数据集时,两者都将随着时间的推移而累加。
python
class Accumulator: #@save
"""在n个变量上累加"""
def __init__(self, n):
self.data = [0.0] * n
def add(self, *args):
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self):
self.data = [0.0] * len(self.data)
def __getitem__(self, idx):
return self.data[idx]

下面就可以进行 softmax 的回归训练了:
python
def train_epoch_ch3(net, train_iter, loss, updater): #@save
"""训练模型一个迭代周期(定义见第3章)"""
# 将模型设置为训练模式
if isinstance(net, torch.nn.Module):
net.train()
# 训练损失总和、训练准确度总和、样本数
metric = Accumulator(3)
for X, y in train_iter:
# 计算梯度并更新参数
y_hat = net(X)
l = loss(y_hat, y)
if isinstance(updater, torch.optim.Optimizer):
# 使用PyTorch内置的优化器和损失函数
updater.zero_grad()
l.mean().backward()
updater.step()
else:
# 使用定制的优化器和损失函数
l.sum().backward()
updater(X.shape[0])
metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
# 返回训练损失和训练精度
return metric[0] / metric[2], metric[1] / metric[2]
在展示训练函数的实现之前,我们[定义一个在动画中绘制数据的实用程序类 ]Animator
python
class Animator: #@save
"""在动画中绘制数据"""
def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
figsize=(3.5, 2.5)):
# 增量地绘制多条线
if legend is None:
legend = []
d2l.use_svg_display()
self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
if nrows * ncols == 1:
self.axes = [self.axes, ]
# 使用lambda函数捕获参数
self.config_axes = lambda: d2l.set_axes(
self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
self.X, self.Y, self.fmts = None, None, fmts
def add(self, x, y):
# 向图表中添加多个数据点
if not hasattr(y, "__len__"):
y = [y]
n = len(y)
if not hasattr(x, "__len__"):
x = [x] * n
if not self.X:
self.X = [[] for _ in range(n)]
if not self.Y:
self.Y = [[] for _ in range(n)]
for i, (a, b) in enumerate(zip(x, y)):
if a is not None and b is not None:
self.X[i].append(a)
self.Y[i].append(b)
self.axes[0].cla()
for x, y, fmt in zip(self.X, self.Y, self.fmts):
self.axes[0].plot(x, y, fmt)
self.config_axes()
display.display(self.fig)
display.clear_output(wait=True)
下面开始训练:
python
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save
"""训练模型(定义见第3章)"""
animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
legend=['train loss', 'train acc', 'test acc'])
for epoch in range(num_epochs):
train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
test_acc = evaluate_accuracy(net, test_iter)
animator.add(epoch + 1, train_metrics + (test_acc,))
train_loss, train_acc = train_metrics
assert train_loss < 0.5, train_loss
assert train_acc <= 1 and train_acc > 0.7, train_acc
assert test_acc <= 1 and test_acc > 0.7, test_acc
**小批量随机梯度下降来优化模型的损失函数**\],设置学习率为0.1 ```python lr = 0.1 def updater(batch_size): return d2l.sgd([W, b], lr, batch_size) ```  对图像进行预测: ```python def predict_ch3(net, test_iter, n=6): #@save """预测标签""" for X, y in test_iter: break trues = d2l.get_fashion_mnist_labels(y) preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1)) titles = [true +'\n' + pred for true, pred in zip(trues, preds)] d2l.show_images( X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n]) predict_ch3(net, test_iter) ```  ## Softmax 回归的简洁实现 通过深度学习框架的高级API也能更方便地实现softmax回归模型: ```python import torch from torch import nn from d2l import torch as d2l batch_size = 256 train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size) ``` Softmax 回归的输出层是一个全连接层 ```python # PyTorch不会隐式地调整输入的形状。因此, # 我们在线性层前定义了展平层(flatten),来调整网络输入的形状 net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) def init_weights(m): if type(m) == nn.Linear: nn.init.normal_(m.weight, std=0.01) net.apply(init_weights); ``` 在交叉熵损失函数中传递未规范化的预测,并同时计算softmax及其对数 ```python loss = nn.CrossEntropyLoss(reduction='none')# 不进行任何减少操作,返回每个样本的损失值。 ``` 使用学习率为0.1的小批量随机梯度下降作为优化算法 ```python trainer = torch.optim.SGD(net.parameters(), lr=0.1) ``` 训练,重用之前编写的函数:  ## QA思考 Q1:softlabel训练策略。 上述被称为软标签,旨在通过使用非硬性(即不是0或1的绝对分类结果)的目标标签来提高模型的泛化能力和鲁棒性。 传统的分类任务中,目标标签通常是**one-hot编码** 的形式,即对于每个样本,正确的类别标记为1,其他类别标记为0。但是实际上对于边界值是很难达到的,比如对于softmax函数而言: softmax ( z i ) = e z i ∑ j = 1 n e z j \\text{softmax}(z_i) = \\frac{e\^{z_i}}{\\sum_{j=1}\^{n} e\^{z_j}} softmax(zi)=∑j=1nezjezi 要想使其输出为 1 ,则需要某一个 z i z_i zi ,趋近于无穷才行。 而softlabel则允许这些标签值位于(0, 1)之间,并且所有类别的概率之和通常为1。这意味着即使是错误的类别也可能被赋予一定的概率,从而向模型传达"某种程度上的正确"。比如我可以认为0.9 就是正确,0.1 就是不正确。 Q2 : softmax 回归和 logistic 回归的联系。 可以认为logistic是softmax的特例,也就是logistic是一个两分类的问题,只需要输出一个类别的概率 P P P 即可,剩下的直接 1 − P 1-P 1−P 即可。但是在实际的分类问题中,两分类的问题很少。 Q3 : 在 Accuracy函数中为啥不把除以 len(y) 做完呢? 在 Accuracy 函数中,不能直接除以 len(y),因为最后一个 batch 的样本数量可能会少于设定的 batch size。为了确保准确率计算的正确性,应该根据当前 batch 实际包含的样本数量进行归一化,而不是固定地使用完整的 batch size。 补充: 考虑到李沐老师的视线中使用到了d2l,且是在jupyter上面进行实现的,但是我现在不想用d2l,以及需要再Pycharm上面编写,于是我根据上述代码编写了下面的代码,结果也能很好的复现李沐老师代码的结果。 ```python import torch import torchvision from torchvision import transforms from torch.utils import data import matplotlib.pyplot as plt class Animator: def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None, xscale='linear', yscale='linear', fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5)): if legend is None: legend = [] self.xlabel = xlabel self.ylabel = ylabel self.legend = legend self.xlim = xlim self.ylim = ylim self.xscale = xscale self.yscale = yscale self.fmts = fmts self.figsize = figsize self.X, self.Y = [], [] def add(self, x, y): if not hasattr(y, "__len__"): y = [y] n = len(y) if not hasattr(x, "__len__"): x = [x] * n if not self.X: self.X = [[] for _ in range(n)] if not self.Y: self.Y = [[] for _ in range(n)] for i, (a, b) in enumerate(zip(x, y)): if a is not None and b is not None: self.X[i].append(a) self.Y[i].append(b) def show(self): plt.figure(figsize=self.figsize) for x_data, y_data, fmt in zip(self.X, self.Y, self.fmts): plt.plot(x_data, y_data, fmt) plt.xlabel(self.xlabel) plt.ylabel(self.ylabel) if self.legend: plt.legend(self.legend) if self.xlim: plt.xlim(self.xlim) if self.ylim: plt.ylim(self.ylim) plt.xscale(self.xscale) plt.yscale(self.yscale) plt.grid() plt.show() def get_dataloader_workers(): return 0 # 禁用多进程加载 def load_data_fashion_mnist(batch_size, resize=None): trans = [transforms.ToTensor()] if resize: trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) mnist_train = torchvision.datasets.FashionMNIST("./data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST("./data", train=False, transform=trans, download=True) return ( data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()), data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()) ) # softmax 实现 def softmax(X): X_exp = torch.exp(X) partition = X_exp.sum(1, keepdim=True) return X_exp / partition # 回归模型 def net(X): return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b) # 交叉熵损失函数 def cross_entropy(y_hat, y): return -torch.log(y_hat[range(len(y_hat)), y]) # 预测正确的数量 def accuracy(y_hat, y): if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: y_hat = y_hat.argmax(axis=1) cmp = y_hat.type(y.dtype) == y return float(cmp.type(y.dtype).sum()) class Accumulator: def __init__(self, n): self.data = [0.0] * n def add(self, *args): self.data = [a + float(b) for a, b in zip(self.data, args)] def reset(self): self.data = [0.0] * len(self.data) def __getitem__(self, idx): return self.data[idx] def evaluate_accuracy(net, data_iter): if isinstance(net, torch.nn.Module): net.eval() metric = Accumulator(2) with torch.no_grad(): for X, y in data_iter: metric.add(accuracy(net(X), y), y.numel()) return metric[0] / metric[1] def train_epoch_ch3(net, train_iter, loss, updater): if isinstance(net, torch.nn.Module): net.train() metric = Accumulator(3) for X, y in train_iter: y_hat = net(X) l = loss(y_hat, y) if isinstance(updater, torch.optim.Optimizer): updater.zero_grad() l.mean().backward() updater.step() else: l.sum().backward() updater(X.shape[0]) metric.add(float(l.sum()), accuracy(y_hat, y), y.numel()) return metric[0] / metric[2], metric[1] / metric[2] def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9], legend=['train loss', 'train acc', 'test acc']) for epoch in range(num_epochs): train_metrics = train_epoch_ch3(net, train_iter, loss, updater) test_acc = evaluate_accuracy(net, test_iter) animator.add(epoch + 1, train_metrics + (test_acc,)) train_loss, train_acc = train_metrics assert train_loss < 0.5, train_loss assert train_acc <= 1 and train_acc > 0.7, train_acc assert test_acc <= 1 and test_acc > 0.7, test_acc animator.show() # 展示最终结果图 def sgd(params, lr, batch_size): with torch.no_grad(): for param in params: param -= lr * param.grad / batch_size param.grad.zero_() def updater(batch_size): return sgd([W, b], lr, batch_size) def get_fashion_mnist_labels(labels): text_labels = [ 't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot' ] return [text_labels[int(i)] for i in labels] def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): figsize = (num_cols * scale, num_rows * scale) _, axes = plt.subplots(num_rows, num_cols, figsize=figsize) axes = axes.flatten() for i, (ax, img) in enumerate(zip(axes, imgs)): if torch.is_tensor(img): ax.imshow(img.numpy(), cmap='gray') else: ax.imshow(img, cmap='gray') ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) if titles: ax.set_title(titles[i]) plt.show() def predict_ch3(net, test_iter, n=6): for X, y in test_iter: break trues = get_fashion_mnist_labels(y) preds = get_fashion_mnist_labels(net(X).argmax(axis=1)) titles = [true + '\n' + pred for true, pred in zip(trues, preds)] show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n]) if __name__ == "__main__": # 定义超参数 batch_size = 256 num_epochs = 10 lr = 0.1 # 加载数据 train_iter, test_iter = load_data_fashion_mnist(batch_size) # 初始化模型参数 num_inputs = 784 num_outputs = 10 W = torch.normal(0, 0.1, size=(num_inputs, num_outputs), requires_grad=True) b = torch.zeros(num_outputs, requires_grad=True) # 训练模型 train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater) # 测试模型并显示预测结果 predict_ch3(net, test_iter) ```