《动手学深度学习(PyTorch版)》笔记7.5

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.5 Batch Normalization

批量规范化应用于单个可选层(也可以应用到所有层),其原理如下:

在每次训练迭代中,首先规范化输入,即通过减去其均值并除以其标准差(两者均基于当前小批量处理)。接下来应用比例系数和比例偏移。请注意,如果使用大小为1的小批量应用批量规范化,将无法学到任何东西,这是因为在减去均值之后,每个隐藏单元将为0,所以只有使用足够大的小批量,批量规范化才有效且稳定。用 x ∈ B \mathbf{x} \in \mathcal{B} x∈B表示一个来自小批量 B \mathcal{B} B的输入,批量规范化 B N \mathrm{BN} BN根据下式转换 x \mathbf{x} x:

B N ( x ) = γ ⊙ x − μ ^ B σ ^ B + β . \mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}\mathcal{B}}{\hat{\boldsymbol{\sigma}}\mathcal{B}} + \boldsymbol{\beta}. BN(x)=γ⊙σ^Bx−μ^B+β.

在上式中, μ ^ B \hat{\boldsymbol{\mu}}\mathcal{B} μ^B是小批量 B \mathcal{B} B的样本均值, σ ^ B \hat{\boldsymbol{\sigma}}\mathcal{B} σ^B是小批量 B \mathcal{B} B的样本标准差。应用标准化后,生成的小批量的平均值为0和方差为1。由于单位方差是一个主观选择的结果,因此通常包含拉伸参数 (scale) γ \boldsymbol{\gamma} γ和偏移参数 (shift) β \boldsymbol{\beta} β,它们的形状与 x \mathbf{x} x相同,且是需要与其他模型参数一起学习的参数。

μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x , σ ^ B 2 = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ . \begin{aligned} \hat{\boldsymbol{\mu}}\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\ \hat{\boldsymbol{\sigma}}\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned} μ^Bσ^B2=∣B∣1x∈B∑x,=∣B∣1x∈B∑(x−μ^B)2+ϵ.

我们在方差估计值中添加一个小的常量 ϵ > 0 \epsilon > 0 ϵ>0,以确保永远不会除以零。此外,估计值 μ ^ B \hat{\boldsymbol{\mu}}\mathcal{B} μ^B和 σ ^ B {\hat{\boldsymbol{\sigma}}\mathcal{B}} σ^B还会通过使用关于平均值和方差的噪声(noise)来抵消缩放问题。乍看起来,这种噪声是一个问题,而事实上它是有益的。

7.5.1 Batch Normalize Layer

批量规范化层和其他层之间的一个关键区别是:由于批量规范化在完整的小批量上运行,因此我们不能像以前在引入其他层时那样忽略批量大小。我们在下面讨论全连接层和卷积层的批量规范化实现。

7.5.1.1 Fully Connected Layer

我们通常将批量规范化层置于全连接层中的仿射变换和激活函数之间。设全连接层的输入为x,权重参数和偏置参数分别为 W \mathbf{W} W和 b \mathbf{b} b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N \mathrm{BN} BN,那么使用批量规范化的全连接层的输出的计算公式如下:
h = ϕ ( B N ( W x + b ) ) . \mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}) ). h=ϕ(BN(Wx+b)).

7.5.1.2 Convolutional Layer

对于卷积层,可以在卷积操作之后和非线性激活函数之前应用批量规范化。当卷积有多个输出通道时,我们需要对这些通道的每个输出执行批量规范化。每个通道都有自己的拉伸(scale)和偏移(shift)参数,两者都是标量。假设小批量包含 m m m个样本,并且对于每个通道,卷积的输出具有高度 p p p和宽度 q q q,那么对于卷积层,在每个输出通道的 m ⋅ p ⋅ q m \cdot p \cdot q m⋅p⋅q个元素上同时执行每个批量规范化。因此在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。

7.5.2 Batch Normalization During Prediction

批量规范化在训练模式和预测模式下的行为和结果通常不同。在训练过程中,我们无法得知使用整个数据集来估计平均值和方差,所以只能根据每个小批次的平均值和方差不断训练模型,一种常用的方法是通过移动平均(具体方法在batch_norm()函数中定义)估算整个训练数据集的样本均值和方差,并在预测时使用它们得到确定的输出。而在将训练好的模型用于预测时,我们可以根据整个数据集精确计算批量规范化所需的平均值和方差,不再需要样本均值中的噪声以及在微批次上估计每个小批次产生的样本方差了。

复制代码
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        #参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        #这些参数是nn.Parameter类型,因此会被PyTorch的优化器所管理,梯度更新是自动进行的
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,就将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y
    
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

print(net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,)))

#简洁实现
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:

相关推荐
amhjdx3 小时前
星巽短剧以科技赋能影视创新,构建全球短剧新生态!
人工智能·科技
听风南巷3 小时前
机器人全身控制WBC理论及零空间原理解析(数学原理解析版)
人工智能·数学建模·机器人
美林数据Tempodata4 小时前
“双新”指引,AI驱动:工业数智应用生产性实践创新
大数据·人工智能·物联网·实践中心建设·金基地建设
电科_银尘4 小时前
【大语言模型】-- 私有化部署
人工智能·语言模型·自然语言处理
惊讶的猫5 小时前
LSTM论文解读
开发语言·python
潇冉沐晴5 小时前
div2 1052 个人补题笔记
笔记
测试老哥5 小时前
软件测试之单元测试知识总结
自动化测试·软件测试·python·测试工具·职场和发展·单元测试·测试用例
翔云 OCR API5 小时前
人工智能驱动下的OCR API技术演进与实践应用
人工智能·ocr
buvsvdp50059ac5 小时前
如何在VSCode中设置Python解释器?
ide·vscode·python
南方者6 小时前
重磅升级!文心 ERNIE-5.0 新一代原生全模态大模型,这你都不认可它吗?!
人工智能·aigc