PyTorch -- nn.Module 快速实践

  • 网络层父类 nn.Module:Evrey Layer is nn.Module (nn.Linear, nn,Conv2d ...)
  • 具体的,我们在定义自已的网络时:需要继承 nn.Module,并重新实现
    • __init__ 方法: 一般放置网络中具有可学习参数的层(如全连接层、卷积层等)
      • 也可放置不具有可学习参数的层(如ReLU、dropout 等);or 直接在 forward 方法 直接用 nn.functional 来代替
      • 除了基础模块,还可以用 nn.Sequential 来定义复合层
    • forward 方法:实现各个层之间的连接关系

经典代码示例:

python3 复制代码
import torch
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1=torch.nn.ReLU()
        self.max_pooling1=torch.nn.MaxPool2d(2,1)
 
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu2=torch.nn.ReLU()
        self.max_pooling2=torch.nn.MaxPool2d(2,1)
 
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)  					## 写法一
        # x = torch.relu(x)  			    ## 写法二 (__init__ 中不需要定义 relf.relu1) 
        # x = torch.nn.functional.relu(x)  	## 写法三 (__init__ 中不需要定义 relf.relu1) 
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
 
model = MyNet()  

print(model) 的结果如下所示:可参照结果对照理解

MyNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (dense2): Linear(in_features=128, out_features=10, bias=True)
)

Module 类的常见方法:
  • .children(), .named_children(): 返回模型的直接子模块,不包括嵌套的子模块

    • 适合快速查看模型的主要组成部分
    python3 复制代码
    >>> list(model.children())[0]
    Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))  # 对应 `conv1`
  • .modules(), .named_modules(): 递归遍历模型本身及其所有嵌套的子模块

    • 适合查看模型的完整结构。
    python3 复制代码
    >>> list(model.modules())[0]  
    MyNet(  # 对应模型本身
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU()
      (max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2): ReLU()
      (max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (dense1): Linear(in_features=288, out_features=128, bias=True)
      (dense2): Linear(in_features=128, out_features=10, bias=True)
    )

此外,使用 nn.Module 定义 model 可以使下列操作方便实现:

  • 可转移到指定 device = torch.device('cuda),使用 model.to(device)
  • 可加载和保存 model 参数:
    • 加载: model.load_state_dict(torch.load('xxx.mdl'))
    • 保存: torch.save(model.state_dict(), 'xxx.del')
  • 可进行状态转化 train/test:model.train()model.eval() 一键切换

相关推荐
GOTXX6 分钟前
基于Opencv的图像处理软件
图像处理·人工智能·深度学习·opencv·卷积神经网络
IT古董11 分钟前
【人工智能】Python在机器学习与人工智能中的应用
开发语言·人工智能·python·机器学习
CV学术叫叫兽26 分钟前
快速图像识别:落叶植物叶片分类
人工智能·分类·数据挖掘
湫ccc34 分钟前
《Python基础》之pip换国内镜像源
开发语言·python·pip
hakesashou36 分钟前
Python中常用的函数介绍
java·网络·python
菜鸟的人工智能之路1 小时前
极坐标气泡图:医学数据分析的可视化新视角
python·数据分析·健康医疗
菜鸟学Python1 小时前
Python 数据分析核心库大全!
开发语言·python·数据挖掘·数据分析
小白不太白9501 小时前
设计模式之 责任链模式
python·设计模式·责任链模式
WeeJot嵌入式1 小时前
卷积神经网络:深度学习中的图像识别利器
人工智能
喜欢猪猪1 小时前
Django:从入门到精通
后端·python·django