【PyTorch】(三)模型的创建、参数初始化、保存和加载

文章目录

  • [1. 模型的创建](#1. 模型的创建)
    • [1.1. 模型组件](#1.1. 模型组件)
      • [1.1.1. 网络层](#1.1.1. 网络层)
      • [1.1.2. 激活函数](#1.1.2. 激活函数)
      • [1.1.3. 函数包](#1.1.3. 函数包)
      • [1.1.4. 容器](#1.1.4. 容器)
    • [1.2. 创建方法](#1.2. 创建方法)
      • [1.1.1. 通过使用模型组件](#1.1.1. 通过使用模型组件)
      • [1.1.2. 通过继承nn.Module类](#1.1.2. 通过继承nn.Module类)
    • [1.3. 将模型转移到GPU](#1.3. 将模型转移到GPU)
  • [2. 模型参数初始化](#2. 模型参数初始化)
  • [3. 模型的保存与加载](#3. 模型的保存与加载)
    • [3.1. 只保存参数](#3.1. 只保存参数)
    • [3.2. 保存模型和参数](#3.2. 保存模型和参数)

1. 模型的创建

1.1. 模型组件

1.1.1. 网络层

1.1.2. 激活函数

1.1.3. 函数包

1.1.4. 容器

1.2. 创建方法

1.1.1. 通过使用模型组件

可以直接使用模型组件快速创建模型。

python 复制代码
import torch.nn as nn

model =	nn.Linear(10, 10)
print(model)

输出结果:

bash 复制代码
Linear(in_features=10, out_features=10, bias=True)

1.1.2. 通过继承nn.Module类

在__init__方法中使用模型组件定义模型各层。必须重写forward方法实现前向传播。

python 复制代码
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
        self.layer3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

model = Model()
print(model)

输出结果:

bash 复制代码
Model(
  (layer1): Linear(in_features=10, out_features=10, bias=True)
  (layer2): Linear(in_features=10, out_features=10, bias=True)
  (layer3): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=10, bias=True)
  )
)

1.3. 将模型转移到GPU

方法与将数据转移到GPU类似,都有两种方法:

  1. model.to(device)
  2. mode.cuda()
python 复制代码
import torch
import torch.nn as nn

# 创建模型实例
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)

# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 也可以
model = model.cuda()

2. 模型参数初始化

torch.nn.init提供了许多初始化参数的函数:

函数名 作用 参数
uniform_ 从均匀分布 U ( a , b ) U(a,b) U(a,b)中生成值,填充输入的张量 tensor, a = 0, b = 1
normal_ 从正态分布 N ( m e a n , s t d 2 ) N(mean, std^2) N(mean,std2)中生成值,填充输入的张量 tensor, mean = 0, std = 1
constant_ 用常数 v a l val val,填充输入的张量 tensor, val
eye_ 用单位矩阵,填充二维输入张量 tensor(二维)
dirac_ 用狄拉克函数,填充{3, 4, 5}维输入张量 tensor({3, 4, 5}维), groups = 1
xavier_uniform_ 从xavier均匀分布中生成值,填充输入张量 tensor, gain = 1
xavier_normal_ 从xavier正态分布中生成值,填充输入张量 tensor, gain = 1
kaiming_uniform_ 从kaiming均匀分布中生成值,填充输入张量 tensor, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu'
kaiming_normal_ 从kaiming正态分布中生成值,填充输入张量 tensor, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu'
orthogonal_ 用一个(半)正交矩阵,填充输入张量 tensor, gain = 1
sparse_ 用非零元素服从 N ( 0 , s t d 2 ) N(0, std^2) N(0,std2)的稀疏矩阵,填充二维输入张量 tensor, sparsity, std = 0.01

3. 模型的保存与加载

模型保存和加载使用的python内置的pickle模块。

3.1. 只保存参数

python 复制代码
import torch
import torch.nn as nn

# 创建模型实例
model1 = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)

# 保存和加载参数
torch.save(model1.state_dict(), '../model/model_params.pkl')
model1.load_state_dict(torch.load('../model/model_params.pkl'))

3.2. 保存模型和参数

python 复制代码
import torch
import torch.nn as nn

# 创建模型实例
model1 = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)

# 保存和加载模型和参数
torch.save(model1, '../model/model.pt')
model2 = torch.load('../model/model.pt')
print(model2)
相关推荐
工程师老罗5 小时前
基于Pytorch的YOLOv1 的网络结构代码
人工智能·pytorch·yolo
JarryStudy6 小时前
HCCL与PyTorch集成 hccl_comm.cpp DDP后端注册全流程
人工智能·pytorch·python·cann
Eloudy7 小时前
用 Python 直写 CUDA Kernel的技术,CuTile、TileLang、Triton 与 PyTorch 的深度融合实践
人工智能·pytorch
Rorsion9 小时前
PyTorch实现线性回归
人工智能·pytorch·线性回归
骇城迷影10 小时前
Makemore 核心面试题大汇总
人工智能·pytorch·python·深度学习·线性回归
mailangduoduo11 小时前
零基础教学连接远程服务器部署项目——VScode版本
服务器·pytorch·vscode·深度学习·ssh·gpu算力
多恩Stone12 小时前
【3D AICG 系列-6】OmniPart 训练流程梳理
人工智能·pytorch·算法·3d·aigc
前端摸鱼匠1 天前
YOLOv8 环境配置全攻略:Python、PyTorch 与 CUDA 的和谐共生
人工智能·pytorch·python·yolo·目标检测
纤纡.1 天前
PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战
人工智能·pytorch·python
子榆.1 天前
CANN 与主流 AI 框架集成:从 PyTorch/TensorFlow 到高效推理的无缝迁移指南
人工智能·pytorch·tensorflow