一、文档概述
本文档由孙靖钺编写,聚焦PyTorch 神经网络工具箱的核心应用,从神经网络基础组件出发,逐步讲解工具使用、模型构建、模块自定义及训练流程,通过代码示例与运行结果,清晰呈现 PyTorch 搭建神经网络的关键步骤与逻辑。
二、神经网络核心组件
组件 | 定义与作用 |
---|---|
层 | 神经网络的基本结构,负责将输入张量通过数据变换(如权重运算)转换为输出张量 |
模型 | 由多个层按特定逻辑组合而成的完整网络,实现从输入到预测输出的端到端流程 |
损失函数 | 参数学习的目标函数,用于衡量预测值(Y')与真实值(Y)的差异,需通过优化最小化 |
优化器 | 用于调整模型参数(如权重),使损失函数值达到最小的工具 |
三、构建神经网络的主要工具
1. 两大核心工具对比
对比维度 | nn.Module | nn.functional |
---|---|---|
本质特性 | 面向对象(继承自 nn.Module) | 纯函数(无类属性) |
适用场景 | 卷积层、全连接层、dropout 层 | 激活函数、池化层 |
写法格式 | nn.Xxx(如 nn.Linear、nn.Conv2d) | nn.functional.xxx(如 nn.functional.linear) |
参数管理 | 自动定义 / 管理 weight、bias,无需手动传入 | 需手动定义 weight、bias,调用时需传入 |
nn.Sequential 兼容性 | 支持,可直接组合使用 | 不支持,无法嵌入 nn.Sequential 容器 |
dropout 状态转换 | 调用 model.eval () 后自动切换为测试状态 | 无自动转换,需手动控制 |
2. 关键工具功能
- nn.Module :核心基类,继承后可自动追踪网络中的可学习参数,简化参数管理;例如
nn.Linear(in_features=784, out_features=300)
定义全连接层时,权重和偏置会自动初始化。 - nn.functional :提供纯函数式的操作,需依赖外部参数;例如
F.relu(x)
实现 ReLU 激活函数,仅对输入张量 x 进行计算,无额外参数存储。
四、模型构建的三种方式
方式 1:继承 nn.Module 基类构建
-
核心逻辑 :通过自定义类继承
nn.Module
,在__init__
方法中初始化网络层,在forward
方法中定义前向传播路径。 -
代码示例要点 :
python
class Model_Seq(nn.Module): def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim): super().__init__() self.flatten = nn.Flatten() # 展平层,处理28×28输入 self.linear1 = nn.Linear(in_dim, n_hidden_1) # 全连接层1 self.bn1 = nn.BatchNorm1d(n_hidden_1) # 批量归一化层1 self.linear2 = nn.Linear(n_hidden_1, n_hidden_2) # 全连接层2 self.bn2 = nn.BatchNorm1d(n_hidden_2) # 批量归一化层2 self.out = nn.Linear(n_hidden_2, out_dim) # 输出层 def forward(self, x): x = self.flatten(x) # 展平:28×28→784 x = F.relu(self.bn1(self.linear1(x))) # 层1+BN+激活 x = F.relu(self.bn2(self.linear2(x))) # 层2+BN+激活 x = F.softmax(self.out(x), dim=1) # 输出层+softmax return x
-
超参数 :
in_dim=28*28
(输入维度,对应 28×28 图像)、n_hidden_1=300
(第一隐藏层神经元数)、n_hidden_2=100
(第二隐藏层神经元数)、out_dim=10
(输出维度,对应 10 分类任务)。 -
运行结果 :可打印模型结构,显示各层名称、类型及参数(如
linear1: Linear(in_features=784, out_features=300, bias=True)
)。
方式 2:使用 nn.Sequential 按层顺序构建
nn.Sequential
是层的有序容器,按传入顺序执行前向传播,支持三种子方式:
- 可变参数方式 :直接传入层实例,无自定义层名称。
- 示例:
Seq_arg = nn.Sequential(nn.Flatten(), nn.Linear(in_dim, n_hidden_1), ..., F.softmax(dim=1))
- 示例:
- add_module 方法 :通过
add_module("层名称", 层实例)
自定义层名称。- 示例:
Seq_module.add_module("flatten", nn.Flatten())
、Seq_module.add_module("linear1", nn.Linear(in_dim, n_hidden_1))
- 示例:
- OrderedDict 方法 :通过
collections.OrderedDict
传入键值对(层名称:层实例),确保层顺序与名称对应。
- 运行结果 :三种方式均能生成有序层结构,带自定义名称的方式可在打印模型时清晰区分各层(如
(flatten): Flatten(start_dim=1, end_dim=-1)
)。
方式 3:继承 nn.Module + 模型容器构建
在自定义nn.Module
子类中,嵌入nn.Sequential
、nn.ModuleList
、nn.ModuleDict
等容器管理层,平衡灵活性与简洁性:
- nn.Sequential 容器 :将关联层组合为子模块,简化
forward
逻辑。- 示例:
self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))
- 示例:
- nn.ModuleList 容器 :以列表形式存储层,需通过循环调用各层。
- 示例:
self.layers = nn.ModuleList([nn.Flatten(), nn.Linear(in_dim, n_hidden_1), ...])
,forward
中for layer in self.layers: x = layer(x)
- 示例:
- nn.ModuleDict 容器 :以字典形式存储层(键为层名称),需按预设顺序调用。
- 示例:
self.layers_dict = nn.ModuleDict({"flatten": nn.Flatten(), "linear1": nn.Linear(...)})
,forward
中按列表["flatten", "linear1", ...]
循环调用
- 示例:
五、自定义网络模块(以残差块与 ResNet18 为例)
1. 两种残差块设计
- 类型 1:RestNetBasicBlock(正常残差块)
- 适用场景:输入与输出张量形状(通道数、分辨率)一致时。
- 结构:2 个 3×3 卷积层(
kernel_size=3
、padding=1
)+ 对应批量归一化层(nn.BatchNorm2d
),前向传播中输入x
直接与卷积输出相加后过 ReLU 激活。
- 类型 2:RestNetDownBlock(维度调整残差块)
- 适用场景:输入与输出形状不一致(需调整通道数或分辨率)时。
- 特殊设计:增加
extra
子模块(1×1 卷积层 + 批量归一化层),用于将输入x
转换为与卷积输出一致的形状,再进行相加激活。 - 关键参数:
stride=[2,1]
(第一卷积步长 2 用于降分辨率,第二卷积步长 1 保持尺寸)。
2. ResNet18 网络结构
通过组合上述残差块,构建经典 ResNet18 模型,结构如下:
- 初始层:
conv1
(3→64 通道,kernel_size=7
、stride=2
、padding=3
)→bn1
(批量归一化)→maxpool
(最大池化,kernel_size=3
、stride=2
) - 残差块层:
layer1
:2 个 RestNetBasicBlock(64→64 通道,stride=1
)layer2
:1 个 RestNetDownBlock(64→128 通道,stride=[2,1]
)+ 1 个 RestNetBasicBlock(128→128 通道)layer3
:1 个 RestNetDownBlock(128→256 通道)+ 1 个 RestNetBasicBlock(256→256 通道)layer4
:1 个 RestNetDownBlock(256→512 通道)+ 1 个 RestNetBasicBlock(512→512 通道)
- 输出层:
avgpool
(自适应平均池化,output_size=(1,1)
)→ 展平(reshape
)→fc
(全连接层,512→10 通道,对应 10 分类)
六、模型训练流程(6 个关键步骤)
- 加载预处理数据集:准备训练与测试数据,需提前完成数据标准化、划分等预处理操作。
- 定义损失函数:根据任务类型选择(如分类任务用交叉熵损失,回归任务用 MSE 损失)。
- 定义优化方法:选择合适的优化器(如 SGD、Adam),用于更新模型参数以最小化损失。
- 循环训练模型 :迭代训练数据,执行前向传播(计算预测值)、反向传播(计算梯度)、参数更新(
optimizer.step()
)。 - 循环测试 / 验证模型:在测试集 / 验证集上评估模型性能(如准确率),监控过拟合情况。
- 可视化结果:通过图表(如损失曲线、准确率曲线)展示训练过程与模型性能,辅助分析优化。
4. 关键问题与答案
问题 1:nn.Module 与 nn.functional 作为 PyTorch 构建神经网络的核心工具,二者的核心区别是什么?这对实际模型开发有何影响?
答案:二者核心区别有 3 点:
- 参数管理:nn.Xxx(nn.Module 子类)自动定义和管理 weight、bias 参数,无需手动传入;nn.functional.xxx 需手动定义参数且调用时必须传入,代码复用性低。
- nn.Sequential 兼容性:nn.Xxx 可直接嵌入 nn.Sequential 容器,简化层的有序组合;nn.functional.xxx 不支持 nn.Sequential,需手动串联层操作。
- dropout 状态转换 :nn.Dropout(nn.Module 子类)在调用
model.eval()
后会自动切换为测试状态(关闭 dropout);nn.functional.dropout 需手动通过training
参数控制状态,易遗漏导致测试误差异常。影响:开发中优先用 nn.Module 构建含可学习参数的层(如卷积、全连接),保证参数管理便捷性与容器兼容性;仅在无参数的操作(如 ReLU、池化)中用 nn.functional,平衡灵活性与代码简洁性。
问题 2:文档中提到基于 nn.Module 构建模型有三种方式,分别是什么?请对比三种方式的适用场景与灵活性差异。
答案:三种方式及对比如下:
-
纯继承 nn.Module 基类:
- 实现:
__init__
初始化单个层,forward
定义完整前向传播逻辑。 - 适用场景:需高度自定义前向传播(如多分支结构、动态层调用)的复杂模型(如残差网络、注意力网络)。
- 灵活性:最高,可自由设计层间交互逻辑,但代码冗余度较高。
- 实现:
-
使用 nn.Sequential 按层顺序:
- 实现:通过可变参数、add_module 或 OrderedDict 组合层,无需自定义
forward
(容器自动按顺序执行)。 - 适用场景:前向传播为 "线性串联" 的简单模型(如单分支全连接网络、基础 CNN)。
- 灵活性:最低,仅支持线性层顺序,无法实现分支或动态逻辑,但代码最简洁。
- 实现:通过可变参数、add_module 或 OrderedDict 组合层,无需自定义
-
继承 nn.Module + 模型容器(nn.Sequential/ModuleList/ModuleDict):
- 实现:在自定义 nn.Module 子类中,用容器管理一组关联层,
forward
中调用容器或循环调用层。 - 适用场景:中等复杂度模型(如多子模块的网络),需平衡代码简洁性与逻辑灵活性(如将残差块用 nn.Sequential 封装,再组合成 ResNet)。
- 灵活性:中等,既减少单一层定义的冗余,又支持一定程度的自定义逻辑(如 ModuleList 循环调用、ModuleDict 按条件调用)。
- 实现:在自定义 nn.Module 子类中,用容器管理一组关联层,
问题 3:文档中自定义的两种残差块(RestNetBasicBlock 与 RestNetDownBlock)设计目的有何不同?RestNet18 网络如何通过这两种残差块实现特征提取与维度调整?
答案:
-
残差块设计目的差异:
- RestNetBasicBlock:目的是在不改变特征图形状(通道数、分辨率) 的前提下,增强特征提取能力;通过 "输入直接与卷积输出相加" 的残差连接,缓解梯度消失问题,适用于同一特征维度的持续提取。
- RestNetDownBlock:目的是实现特征图维度调整(通道数增加、分辨率降低) ,同时保持残差连接;通过 1×1 卷积的
extra
模块将输入调整为与卷积输出一致的形状,满足网络深层对更大感受野、更多通道特征的需求。
-
RestNet18 的应用逻辑:
- 初始层(conv1+maxpool):将 3 通道输入(如 RGB 图像)转换为 64 通道、低分辨率特征图,奠定特征提取基础。
layer1
:2 个 RestNetBasicBlock(64→64 通道,stride=1),在相同维度下深化特征提取,不改变分辨率。layer2-layer4
:每一层以 1 个 RestNetDownBlock 开头(分别实现 64→128、128→256、256→512 通道提升,stride=[2,1] 实现分辨率减半),后续接 1 个 RestNetBasicBlock;通过这种组合,逐步增加通道数(提升特征表达能力)、降低分辨率(扩大感受野),最终通过 avgpool 和全连接层输出分类结果。