Pytorch:模块(Module类)

文章目录


在 PyTorch 中,Module 是一个非常核心的概念,它是所有神经网络层和模型的基础类。torch.nn.Module 是构建所有神经网络的基类,在 PyTorch 中非常重要,因为它提供了网络的组织架构,并封装了权重、梯度的管理、模型参数的更新等功能。

PyTorch 中的 Linear 层ReLU 激活函数以及大多数其他神经网络层和函数都返回 torch.Tensor 类型的对象。这些返回的张量包含了经过相应层或函数处理后的数据。在神经网络中,数据通常以张量的形式在各个层之间流动。

一、Module类介绍

所有神经网络层和模型的基础类,自定义神经网络时对其继承。

1、主要功能

  1. 封装参数

    • Module 类在内部自动管理 层的参数 。每当你在 Module 中定义一个层对象,如 self.conv1 = nn.Conv2d(...), PyTorch 自动将这些层的参数加入到模型的参数列表中 。这些参数通过 module.parameters() 方法访问。模型参数(定义在模型内部的层的权重和偏置)默认 requires_grad=True
  2. 自动梯度计算

    • 每个 Module 可以使用 PyTorch 的自动微分(autograd)系统来自动计算和存储梯度。在 forward 方法执行运算时,PyTorch 会跟踪这些运算产生的所有张量,对应的梯度在调用张量的 backward()方法后自动计算。由于模型参数默认requires_grad=True,因此对这些参数的所有操作都将被进行自动梯度计算。
  3. 前向传播定义

    • 在定义自己的网络时,需要覆盖 Moduleforward() 方法。这是模型接收输入数据并返回输出的地方forward() 方法定义了模型的前向传播路径
  4. 模型保存和加载

    模型的保存和加载是在 PyTorch 中进行模型持久化和迁移学习的常用操作。模型可以保存为 .pt.pth 文件,包括其参数、优化器状态和其他任何相关的信息。

  • 保存模型:

    • 最简单的保存方法是使用 torch.save 来保存模型的 state_dict,这是一个包含模型参数的字典。
    python 复制代码
    torch.save(model.state_dict(), 'model_path.pth')
    • 值得注意的是, 这条命令只是用来保存模型参数的,因此在加载参数时,需要使用同样的模型使用load_state_dict()才可。
  • 加载模型:

    • 加载模型时,首先需要实例化模型对象,然后使用 load_state_dict() 方法加载参数。
    python 复制代码
    model = MyModel()
    model.load_state_dict(torch.load('model_path.pth'))
  1. 将模型移动到指定的设备 :
    在 PyTorch 中,可以将模型和数据移动到不同的设备上(如 CPU 或 GPU),以支持不同的计算需求。
  • 使用 .to() 方法可以将模型移动到指定的设备:

    python 复制代码
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
  1. 切换模型的训练和评估方式
    torch.nn.Module 提供了 .train().eval() 方法,用于切换模型的训练和评估模式。
  • 训练模式 (train):

    • 在训练模式下,所有的层都被通知模型正在训练 ,这对于某些特定层(如 DropoutBatchNorm)非常重要,因为它们在训练和评估时的行为不同。
    python 复制代码
    model.train()
  • 评估模式 (eval):

    • 评估模式用于模型测试或验证阶段,确保所有层都处于评估状态。
    python 复制代码
    model.eval()
  1. .parameters()方法
  • parameters() 方法返回一个迭代器,包含模型中所有的参数(通常用于传递给优化器)。

    python 复制代码
    for param in model.parameters():
        print(param.size())
  1. .modules()方法
  • modules() 方法返回一个迭代器,遍历模型中的所有模块(层)。这在分析模型结构或应用特定操作到每一层时非常有用。

    python 复制代码
    for module in model.modules():
        print(module)

2、神经网络模型使用理解

白话:

损失函数和优化器都不是module类中的方法,而是外部的方法 ,但是他们都能够作用于模型的权重:由于自动微分,损失函数接收的是结果张量,因此损失函数带来的梯度会被更新给权重的梯度。而优化器接受的是module对象参数的迭代器,它能根据参数的梯度对参数进行更新。

自定义神经网络,实际上就是定义一个 ,该类继承自torch.nn.Module。在对这个类进行实例化时,是使用__init__默认构造函数实例化的。实例化后得到一个神经网络对象,对该对象输入数据会被重载为输入forward函数,而forward函数就是对输入数据进行一层一层的网络层结构处理。forward函数的输出一般是对输入进行了前向传播后的结果,为了对模型参数进行训练更新,我们一般还需要定义一个损失函数;这个损失函数是torch.tensor类型的,可以调用其backward函数,进行反向传播梯度;最后定义一个优化器,进行参数权重更新(实际上这里反向传播梯度 就 相当于损失函数对权重进行求导了,改变权重的方向就是让损失更小的方向。)

反向传播并不更新参数,优化器才是用来更新参数的,反向传播只是更新梯度。 这也是为什么优化器有一个学习率。教程:张量的梯度计算


非白话 自定义神经网络的流程:

  1. 定义一个类,继承自 torch.nn.Module

    • 这个类是您自定义神经网络的基础。通过继承 torch.nn.Module,您的网络能够利用 PyTorch 提供的模块化、参数管理、梯度计算等强大功能。
  2. __init__ 方法中初始化网络层

    • 这是定义神经网络结构的地方。您可以添加诸如全连接层 (nn.Linear), 卷积层 (nn.Conv2d), 激活函数 (nn.ReLU) 等。这些层将被自动注册为模块的子项,使其参数也自动成为模型的一部分。
  3. 定义 forward 方法

    • forward 方法描述了输入数据如何通过定义的层传播。这个方法是在模型训练和评估时自动被调用的,用于前向传播计算输出。
  4. 损失函数和反向传播

    • 在训练阶段,网络输出通过一个损失函数 (loss function) 评估其与真实标签的差异。常用的损失函数有 nn.CrossEntropyLoss(用于分类任务)和 nn.MSELoss(用于回归任务)。
    • 调用损失张量的 .backward() 方法启动自动梯度计算,即反向传播。在这一过程中,PyTorch 根据损失函数自动计算每个参数的梯度,并存储在参数的 .grad 属性中。
    • 在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。清空梯度:optimizer.zero_grad()
  5. 参数更新

    • 使用一个优化器(如 torch.optim.SGDtorch.optim.Adam)来调整网络参数,基于计算的梯度进行更新,以减少损失函数的值。这通常在调用 .backward() 后进行。

a.前向传播示例代码

损失函数和优化器的例子请看:神经网络训练过程代码详解

下面是一个简单的自定义 Module 的例子,定义了一个包含两个全连接层的简单神经网络。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # 定义第一个全连接层
        self.fc1 = nn.Linear(16, 12)
        # 定义第二个全连接层
        self.fc2 = nn.Linear(12, 10)
	
    def forward(self, x):
        # 第一个全连接层的激活函数使用ReLU
        x = F.relu(self.fc1(x))
        # 第二个全连接层的输出
        x = self.fc2(x)
        return x

# 实例化网络
net = SimpleNet()#__init__()里并不需要参数。默认构造函数不需要参数,net就是一个实例化对象。
# 创建一些随机输入数据
input = torch.randn(1, 16)

# 通过网络进行前向传播
output = net(input)#实际上直接使用对象名(),重载为:调用forward函数。
#input先经过一个nn.Linear(16,12),然后进行一次relu(),然后经过一个nn.Linear(12,10)

b.关键点

  • 继承 :自定义的模型需要继承自 nn.Module
  • 超类初始化 :使用 super() 初始化基类,这是在 Python 类中常见的做法,确保正确初始化父类部分。
  • 定义层:在构造函数中定义网络所需的各种层。
  • 前向传播 :在 forward 方法中定义数据如何通过网络。
相关推荐
昨日之日200633 分钟前
Moonshine - 新型开源ASR(语音识别)模型,体积小,速度快,比OpenAI Whisper快五倍 本地一键整合包下载
人工智能·whisper·语音识别
浮生如梦_35 分钟前
Halcon基于laws纹理特征的SVM分类
图像处理·人工智能·算法·支持向量机·计算机视觉·分类·视觉检测
深度学习lover36 分钟前
<项目代码>YOLOv8 苹果腐烂识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·苹果腐烂识别
热爱跑步的恒川2 小时前
【论文复现】基于图卷积网络的轻量化推荐模型
网络·人工智能·开源·aigc·ai编程
API快乐传递者2 小时前
淘宝反爬虫机制的主要手段有哪些?
爬虫·python
阡之尘埃4 小时前
Python数据分析案例61——信贷风控评分卡模型(A卡)(scorecardpy 全面解析)
人工智能·python·机器学习·数据分析·智能风控·信贷风控
孙同学要努力6 小时前
全连接神经网络案例——手写数字识别
人工智能·深度学习·神经网络
Eric.Lee20216 小时前
yolo v5 开源项目
人工智能·yolo·目标检测·计算机视觉
其实吧37 小时前
基于Matlab的图像融合研究设计
人工智能·计算机视觉·matlab
丕羽7 小时前
【Pytorch】基本语法
人工智能·pytorch·python