文章目录
在 PyTorch 中,Module
是一个非常核心的概念,它是所有神经网络层和模型的基础类。torch.nn.Module
是构建所有神经网络的基类,在 PyTorch 中非常重要,因为它提供了网络的组织架构,并封装了权重、梯度的管理、模型参数的更新等功能。
PyTorch 中的 Linear 层
、ReLU 激活函数
以及大多数其他神经网络层和函数都返回 torch.Tensor
类型的对象。这些返回的张量包含了经过相应层或函数处理后的数据。在神经网络中,数据通常以张量的形式在各个层之间流动。
一、Module类介绍
所有神经网络层和模型的基础类,自定义神经网络时对其继承。
1、主要功能
-
封装参数:
Module
类在内部自动管理 层的参数 。每当你在Module
中定义一个层对象,如self.conv1 = nn.Conv2d(...)
, PyTorch 自动将这些层的参数加入到模型的参数列表中 。这些参数通过module.parameters()
方法访问。模型参数(定义在模型内部的层的权重和偏置)默认requires_grad=True
。
-
自动梯度计算:
- 每个
Module
可以使用 PyTorch 的自动微分(autograd)系统来自动计算和存储梯度。在forward
方法执行运算时,PyTorch 会跟踪这些运算产生的所有张量,对应的梯度在调用张量的backward()
方法后自动计算。由于模型参数默认requires_grad=True
,因此对这些参数的所有操作都将被进行自动梯度计算。
- 每个
-
前向传播定义:
- 在定义自己的网络时,需要覆盖
Module
的forward()
方法。这是模型接收输入数据并返回输出的地方 。forward()
方法定义了模型的前向传播路径。
- 在定义自己的网络时,需要覆盖
-
模型保存和加载 :
模型的保存和加载是在 PyTorch 中进行模型持久化和迁移学习的常用操作。模型可以保存为
.pt
或.pth
文件,包括其参数、优化器状态和其他任何相关的信息。
-
保存模型:
- 最简单的保存方法是使用
torch.save
来保存模型的state_dict
,这是一个包含模型参数的字典。
pythontorch.save(model.state_dict(), 'model_path.pth')
- 值得注意的是, 这条命令只是用来保存模型参数的,因此在加载参数时,需要使用同样的模型使用
load_state_dict()
才可。
- 最简单的保存方法是使用
-
加载模型:
- 加载模型时,首先需要实例化模型对象,然后使用
load_state_dict()
方法加载参数。
pythonmodel = MyModel() model.load_state_dict(torch.load('model_path.pth'))
- 加载模型时,首先需要实例化模型对象,然后使用
- 将模型移动到指定的设备 :
在 PyTorch 中,可以将模型和数据移动到不同的设备上(如 CPU 或 GPU),以支持不同的计算需求。
-
使用
.to()
方法可以将模型移动到指定的设备:pythondevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(device)
- 切换模型的训练和评估方式
torch.nn.Module
提供了.train()
和.eval()
方法,用于切换模型的训练和评估模式。
-
训练模式 (
train
):- 在训练模式下,所有的层都被通知模型正在训练 ,这对于某些特定层(如
Dropout
和BatchNorm
)非常重要,因为它们在训练和评估时的行为不同。
pythonmodel.train()
- 在训练模式下,所有的层都被通知模型正在训练 ,这对于某些特定层(如
-
评估模式 (
eval
):- 评估模式用于模型测试或验证阶段,确保所有层都处于评估状态。
pythonmodel.eval()
.parameters()
方法
-
parameters()
方法返回一个迭代器,包含模型中所有的参数(通常用于传递给优化器)。pythonfor param in model.parameters(): print(param.size())
.modules()
方法
-
modules()
方法返回一个迭代器,遍历模型中的所有模块(层)。这在分析模型结构或应用特定操作到每一层时非常有用。pythonfor module in model.modules(): print(module)
2、神经网络模型使用理解
白话:
损失函数和优化器都不是module
类中的方法,而是外部的方法 ,但是他们都能够作用于模型的权重:由于自动微分,损失函数接收的是结果张量,因此损失函数带来的梯度会被更新给权重的梯度。而优化器接受的是module
对象参数的迭代器,它能根据参数的梯度对参数进行更新。
自定义神经网络,实际上就是定义一个类 ,该类继承自torch.nn.Module
。在对这个类进行实例化时,是使用__init__
默认构造函数实例化的。实例化后得到一个神经网络对象,对该对象输入数据会被重载为输入forward
函数,而forward
函数就是对输入数据进行一层一层的网络层结构处理。forward
函数的输出一般是对输入进行了前向传播后的结果,为了对模型参数进行训练更新,我们一般还需要定义一个损失函数;这个损失函数是torch.tensor
类型的,可以调用其backward
函数,进行反向传播梯度;最后定义一个优化器,进行参数权重更新(实际上这里反向传播梯度 就 相当于损失函数对权重进行求导了,改变权重的方向就是让损失更小的方向。)
反向传播并不更新参数,优化器才是用来更新参数的,反向传播只是更新梯度。 这也是为什么优化器有一个学习率。教程:张量的梯度计算
非白话 自定义神经网络的流程:
-
定义一个类,继承自
torch.nn.Module
:- 这个类是您自定义神经网络的基础。通过继承
torch.nn.Module
,您的网络能够利用 PyTorch 提供的模块化、参数管理、梯度计算等强大功能。
- 这个类是您自定义神经网络的基础。通过继承
-
在
__init__
方法中初始化网络层:- 这是定义神经网络结构的地方。您可以添加诸如全连接层 (
nn.Linear
), 卷积层 (nn.Conv2d
), 激活函数 (nn.ReLU
) 等。这些层将被自动注册为模块的子项,使其参数也自动成为模型的一部分。
- 这是定义神经网络结构的地方。您可以添加诸如全连接层 (
-
定义
forward
方法:forward
方法描述了输入数据如何通过定义的层传播。这个方法是在模型训练和评估时自动被调用的,用于前向传播计算输出。
-
损失函数和反向传播:
- 在训练阶段,网络输出通过一个损失函数 (
loss function
) 评估其与真实标签的差异。常用的损失函数有nn.CrossEntropyLoss
(用于分类任务)和nn.MSELoss
(用于回归任务)。 - 调用损失张量的
.backward()
方法启动自动梯度计算,即反向传播。在这一过程中,PyTorch 根据损失函数自动计算每个参数的梯度,并存储在参数的.grad
属性中。 - 在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。清空梯度:
optimizer.zero_grad()
。
- 在训练阶段,网络输出通过一个损失函数 (
-
参数更新:
- 使用一个优化器(如
torch.optim.SGD
或torch.optim.Adam
)来调整网络参数,基于计算的梯度进行更新,以减少损失函数的值。这通常在调用.backward()
后进行。
- 使用一个优化器(如
a.前向传播示例代码
损失函数和优化器的例子请看:神经网络训练过程代码详解
下面是一个简单的自定义 Module
的例子,定义了一个包含两个全连接层的简单神经网络。
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
# 定义第一个全连接层
self.fc1 = nn.Linear(16, 12)
# 定义第二个全连接层
self.fc2 = nn.Linear(12, 10)
def forward(self, x):
# 第一个全连接层的激活函数使用ReLU
x = F.relu(self.fc1(x))
# 第二个全连接层的输出
x = self.fc2(x)
return x
# 实例化网络
net = SimpleNet()#__init__()里并不需要参数。默认构造函数不需要参数,net就是一个实例化对象。
# 创建一些随机输入数据
input = torch.randn(1, 16)
# 通过网络进行前向传播
output = net(input)#实际上直接使用对象名(),重载为:调用forward函数。
#input先经过一个nn.Linear(16,12),然后进行一次relu(),然后经过一个nn.Linear(12,10)
b.关键点
- 继承 :自定义的模型需要继承自
nn.Module
。 - 超类初始化 :使用
super()
初始化基类,这是在 Python 类中常见的做法,确保正确初始化父类部分。 - 定义层:在构造函数中定义网络所需的各种层。
- 前向传播 :在
forward
方法中定义数据如何通过网络。