【PyTorch】(三)模型的创建、参数初始化、保存和加载

文章目录

  • [1. 模型的创建](#1. 模型的创建)
    • [1.1. 模型组件](#1.1. 模型组件)
      • [1.1.1. 网络层](#1.1.1. 网络层)
      • [1.1.2. 激活函数](#1.1.2. 激活函数)
      • [1.1.3. 函数包](#1.1.3. 函数包)
      • [1.1.4. 容器](#1.1.4. 容器)
    • [1.2. 创建方法](#1.2. 创建方法)
      • [1.1.1. 通过使用模型组件](#1.1.1. 通过使用模型组件)
      • [1.1.2. 通过继承nn.Module类](#1.1.2. 通过继承nn.Module类)
    • [1.3. 将模型转移到GPU](#1.3. 将模型转移到GPU)
  • [2. 模型参数初始化](#2. 模型参数初始化)
  • [3. 模型的保存与加载](#3. 模型的保存与加载)
    • [3.1. 只保存参数](#3.1. 只保存参数)
    • [3.2. 保存模型和参数](#3.2. 保存模型和参数)

1. 模型的创建

1.1. 模型组件

1.1.1. 网络层

1.1.2. 激活函数

1.1.3. 函数包

1.1.4. 容器

1.2. 创建方法

1.1.1. 通过使用模型组件

可以直接使用模型组件快速创建模型。

python 复制代码
import torch.nn as nn

model =	nn.Linear(10, 10)
print(model)

输出结果:

bash 复制代码
Linear(in_features=10, out_features=10, bias=True)

1.1.2. 通过继承nn.Module类

在__init__方法中使用模型组件定义模型各层。必须重写forward方法实现前向传播。

python 复制代码
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
        self.layer3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

model = Model()
print(model)

输出结果:

bash 复制代码
Model(
  (layer1): Linear(in_features=10, out_features=10, bias=True)
  (layer2): Linear(in_features=10, out_features=10, bias=True)
  (layer3): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=10, bias=True)
  )
)

1.3. 将模型转移到GPU

方法与将数据转移到GPU类似,都有两种方法:

  1. model.to(device)
  2. mode.cuda()
python 复制代码
import torch
import torch.nn as nn

# 创建模型实例
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)

# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 也可以
model = model.cuda()

2. 模型参数初始化

torch.nn.init提供了许多初始化参数的函数:

函数名 作用 参数
uniform_ 从均匀分布 U ( a , b ) U(a,b) U(a,b)中生成值,填充输入的张量 tensor, a = 0, b = 1
normal_ 从正态分布 N ( m e a n , s t d 2 ) N(mean, std^2) N(mean,std2)中生成值,填充输入的张量 tensor, mean = 0, std = 1
constant_ 用常数 v a l val val,填充输入的张量 tensor, val
eye_ 用单位矩阵,填充二维输入张量 tensor(二维)
dirac_ 用狄拉克函数,填充{3, 4, 5}维输入张量 tensor({3, 4, 5}维), groups = 1
xavier_uniform_ 从xavier均匀分布中生成值,填充输入张量 tensor, gain = 1
xavier_normal_ 从xavier正态分布中生成值,填充输入张量 tensor, gain = 1
kaiming_uniform_ 从kaiming均匀分布中生成值,填充输入张量 tensor, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu'
kaiming_normal_ 从kaiming正态分布中生成值,填充输入张量 tensor, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu'
orthogonal_ 用一个(半)正交矩阵,填充输入张量 tensor, gain = 1
sparse_ 用非零元素服从 N ( 0 , s t d 2 ) N(0, std^2) N(0,std2)的稀疏矩阵,填充二维输入张量 tensor, sparsity, std = 0.01

3. 模型的保存与加载

模型保存和加载使用的python内置的pickle模块。

3.1. 只保存参数

python 复制代码
import torch
import torch.nn as nn

# 创建模型实例
model1 = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)

# 保存和加载参数
torch.save(model1.state_dict(), '../model/model_params.pkl')
model1.load_state_dict(torch.load('../model/model_params.pkl'))

3.2. 保存模型和参数

python 复制代码
import torch
import torch.nn as nn

# 创建模型实例
model1 = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)

# 保存和加载模型和参数
torch.save(model1, '../model/model.pt')
model2 = torch.load('../model/model.pt')
print(model2)
相关推荐
winner888110 小时前
PyTorch 与 TensorFlow 中基于自定义层的 DNN 实现对比
pytorch·tensorflow·dnn
试着11 小时前
【AI面试准备】TensorFlow与PyTorch构建缺陷预测模型
人工智能·pytorch·面试·tensorflow·测试
郜太素13 小时前
PyTorch 张量与自动微分操作
人工智能·pytorch·python·深度学习·学习方法·张量·自动微分
sheng_er_sheng13 小时前
【笔记】【B站课程 pytorch】梯度下降模型
人工智能·pytorch·笔记
DevangLic13 小时前
【CUDA pytorch】
人工智能·pytorch·python
QQ6765800813 小时前
PyTorch和torchvision为例,如何使用预训练的ResNet模型来训练水稻虫害分类数据集 14类 从数据准备到模型训练、评估全流程
人工智能·pytorch·分类
令狐少侠201115 小时前
PaddlePaddle 和PyTorch选择与对比互斥
人工智能·pytorch·paddlepaddle
AlexandrMisko18 小时前
从零实现基于Transformer的英译汉任务
人工智能·pytorch·python·深度学习·transformer
Tech Synapse19 小时前
基于Jetson Nano与PyTorch的无人机实时目标跟踪系统搭建指南
pytorch·目标跟踪·无人机
字节旅行20 小时前
PyTorch常用命令详解:助力深度学习开发
人工智能·pytorch·深度学习