《动手学深度学习(PyTorch版)》笔记7.5

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.5 Batch Normalization

批量规范化应用于单个可选层(也可以应用到所有层),其原理如下:

在每次训练迭代中,首先规范化输入,即通过减去其均值并除以其标准差(两者均基于当前小批量处理)。接下来应用比例系数和比例偏移。请注意,如果使用大小为1的小批量应用批量规范化,将无法学到任何东西,这是因为在减去均值之后,每个隐藏单元将为0,所以只有使用足够大的小批量,批量规范化才有效且稳定。用 x ∈ B \mathbf{x} \in \mathcal{B} x∈B表示一个来自小批量 B \mathcal{B} B的输入,批量规范化 B N \mathrm{BN} BN根据下式转换 x \mathbf{x} x:

B N ( x ) = γ ⊙ x − μ ^ B σ ^ B + β . \mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}\mathcal{B}}{\hat{\boldsymbol{\sigma}}\mathcal{B}} + \boldsymbol{\beta}. BN(x)=γ⊙σ^Bx−μ^B+β.

在上式中, μ ^ B \hat{\boldsymbol{\mu}}\mathcal{B} μ^B是小批量 B \mathcal{B} B的样本均值, σ ^ B \hat{\boldsymbol{\sigma}}\mathcal{B} σ^B是小批量 B \mathcal{B} B的样本标准差。应用标准化后,生成的小批量的平均值为0和方差为1。由于单位方差是一个主观选择的结果,因此通常包含拉伸参数 (scale) γ \boldsymbol{\gamma} γ和偏移参数 (shift) β \boldsymbol{\beta} β,它们的形状与 x \mathbf{x} x相同,且是需要与其他模型参数一起学习的参数。

μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x , σ ^ B 2 = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ . \begin{aligned} \hat{\boldsymbol{\mu}}\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\ \hat{\boldsymbol{\sigma}}\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned} μ^Bσ^B2=∣B∣1x∈B∑x,=∣B∣1x∈B∑(x−μ^B)2+ϵ.

我们在方差估计值中添加一个小的常量 ϵ > 0 \epsilon > 0 ϵ>0,以确保永远不会除以零。此外,估计值 μ ^ B \hat{\boldsymbol{\mu}}\mathcal{B} μ^B和 σ ^ B {\hat{\boldsymbol{\sigma}}\mathcal{B}} σ^B还会通过使用关于平均值和方差的噪声(noise)来抵消缩放问题。乍看起来,这种噪声是一个问题,而事实上它是有益的。

7.5.1 Batch Normalize Layer

批量规范化层和其他层之间的一个关键区别是:由于批量规范化在完整的小批量上运行,因此我们不能像以前在引入其他层时那样忽略批量大小。我们在下面讨论全连接层和卷积层的批量规范化实现。

7.5.1.1 Fully Connected Layer

我们通常将批量规范化层置于全连接层中的仿射变换和激活函数之间。设全连接层的输入为x,权重参数和偏置参数分别为 W \mathbf{W} W和 b \mathbf{b} b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N \mathrm{BN} BN,那么使用批量规范化的全连接层的输出的计算公式如下:
h = ϕ ( B N ( W x + b ) ) . \mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}) ). h=ϕ(BN(Wx+b)).

7.5.1.2 Convolutional Layer

对于卷积层,可以在卷积操作之后和非线性激活函数之前应用批量规范化。当卷积有多个输出通道时,我们需要对这些通道的每个输出执行批量规范化。每个通道都有自己的拉伸(scale)和偏移(shift)参数,两者都是标量。假设小批量包含 m m m个样本,并且对于每个通道,卷积的输出具有高度 p p p和宽度 q q q,那么对于卷积层,在每个输出通道的 m ⋅ p ⋅ q m \cdot p \cdot q m⋅p⋅q个元素上同时执行每个批量规范化。因此在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。

7.5.2 Batch Normalization During Prediction

批量规范化在训练模式和预测模式下的行为和结果通常不同。在训练过程中,我们无法得知使用整个数据集来估计平均值和方差,所以只能根据每个小批次的平均值和方差不断训练模型,一种常用的方法是通过移动平均(具体方法在batch_norm()函数中定义)估算整个训练数据集的样本均值和方差,并在预测时使用它们得到确定的输出。而在将训练好的模型用于预测时,我们可以根据整个数据集精确计算批量规范化所需的平均值和方差,不再需要样本均值中的噪声以及在微批次上估计每个小批次产生的样本方差了。

复制代码
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        #参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        #这些参数是nn.Parameter类型,因此会被PyTorch的优化器所管理,梯度更新是自动进行的
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,就将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y
    
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

print(net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,)))

#简洁实现
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:

相关推荐
芯智工坊1 分钟前
第6章 Mosquitto用户认证与访问控制
网络·人工智能·mqtt·开源
人工智能AI技术1 小时前
Gemma 4 C# 部署教程:手机/边缘设备离线运行
人工智能
青瓷程序设计2 小时前
基于YOLO的火灾烟雾检测系统~Python+目标检测+算法模型+2026原创
python·yolo·目标检测
立莹Sir2 小时前
Spring Bean 生命周期详解
java·python·spring
JHC0000008 小时前
基于Ollama,Milvus构建的建议知识检索系统
人工智能·python·milvus
mOok ONSC8 小时前
SpringBoot项目中读取resource目录下的文件(六种方法)
spring boot·python·pycharm
Flittly8 小时前
【SpringAIAlibaba新手村系列】(11)Embedding 向量化与向量数据库
java·笔记·spring·ai·springboot
ZPC82108 小时前
如何创建一个单例类 (Singleton)
开发语言·前端·人工智能
AppOS8 小时前
手把手教你 Openclaw 在 Mac 上本地化部署,保姆级教程!接入飞书打造私人 AI 助手
人工智能·macos·飞书
workflower9 小时前
AI制造-推荐初始步骤
java·开发语言·人工智能·软件工程·制造·需求分析·软件需求