PyTorch -- nn.Module 快速实践

  • 网络层父类 nn.Module:Evrey Layer is nn.Module (nn.Linear, nn,Conv2d ...)
  • 具体的,我们在定义自已的网络时:需要继承 nn.Module,并重新实现
    • __init__ 方法: 一般放置网络中具有可学习参数的层(如全连接层、卷积层等)
      • 也可放置不具有可学习参数的层(如ReLU、dropout 等);or 直接在 forward 方法 直接用 nn.functional 来代替
      • 除了基础模块,还可以用 nn.Sequential 来定义复合层
    • forward 方法:实现各个层之间的连接关系

经典代码示例:

python3 复制代码
import torch
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1=torch.nn.ReLU()
        self.max_pooling1=torch.nn.MaxPool2d(2,1)
 
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu2=torch.nn.ReLU()
        self.max_pooling2=torch.nn.MaxPool2d(2,1)
 
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)  					## 写法一
        # x = torch.relu(x)  			    ## 写法二 (__init__ 中不需要定义 relf.relu1) 
        # x = torch.nn.functional.relu(x)  	## 写法三 (__init__ 中不需要定义 relf.relu1) 
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
 
model = MyNet()  

print(model) 的结果如下所示:可参照结果对照理解

复制代码
MyNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (dense2): Linear(in_features=128, out_features=10, bias=True)
)

Module 类的常见方法:
  • .children(), .named_children(): 返回模型的直接子模块,不包括嵌套的子模块

    • 适合快速查看模型的主要组成部分
    python3 复制代码
    >>> list(model.children())[0]
    Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))  # 对应 `conv1`
  • .modules(), .named_modules(): 递归遍历模型本身及其所有嵌套的子模块

    • 适合查看模型的完整结构。
    python3 复制代码
    >>> list(model.modules())[0]  
    MyNet(  # 对应模型本身
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU()
      (max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2): ReLU()
      (max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (dense1): Linear(in_features=288, out_features=128, bias=True)
      (dense2): Linear(in_features=128, out_features=10, bias=True)
    )

此外,使用 nn.Module 定义 model 可以使下列操作方便实现:

  • 可转移到指定 device = torch.device('cuda),使用 model.to(device)
  • 可加载和保存 model 参数:
    • 加载: model.load_state_dict(torch.load('xxx.mdl'))
    • 保存: torch.save(model.state_dict(), 'xxx.del')
  • 可进行状态转化 train/test:model.train()model.eval() 一键切换

相关推荐
吴佳浩1 小时前
什么?有人手写 Skill?Agent Skill?Skill?
人工智能·llm·agent
俊哥V5 小时前
每日 AI 研究简报 · 2026-05-21
人工智能·ai
biter down5 小时前
14:pytest-order 插件 顺序控制案例
开发语言·python·pytest
测试开发-学习笔记6 小时前
从0开始搭建自动化(一)-appium+python
python·自动化
2601_957884846 小时前
深度拆解:大模型RAG架构下,GEO优化的技术实现路径
人工智能·架构
这个DBA有点耶6 小时前
DBA的AI助手:向量检索与NL2SQL入门
数据库·人工智能·postgresql·学习方法·dba
㳺三才人子6 小时前
初探 Flask
后端·python·flask·html
YOLO数据集集合6 小时前
无人机航拍林业树种分割|单木树冠检测|三维点云|遥感影像数据集10059期
人工智能·yolo·目标检测·无人机
Pocker_Spades_A6 小时前
工业智能化的时序选型指南:当数据底座遇见机器学习
人工智能·机器学习
2601_955781986 小时前
飞书远程控机:OpenClaw配置全攻略
人工智能·开源·github·飞书·open claw安装·open claw部署