Pytorch加载模型

一、使用torchvision加载模型

1.使用说明:

torchvision是PyTorch生态系统中的一个包,专门用于计算机视觉任务。它提供了一系列用于加载、处理和预处理图像和视频数据的工具,以及常用的计算机视觉模型。

torchvision.models模块包含许多常用的预训练计算机视觉模型,例如ResNet、AlexNet、VGG等分类、分割等模型。

2.使用方法: Models and pre-trained weights --- Torchvision 0.18 documentation

(1)查看torchvision.models提供的全部模型

python 复制代码
from torchvision.models import list_models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
返回值是torchvisoin.models中包含的全部模型名称,是一个字符串组成的列表

(2)torchvision.models中包含的模型如下:

a.分类(classification):

AlexNet,ConvNeXt,DenseNet,EfficientNet,EfficientNetV2,GoogLeNet,Inception-V3, MaxVit, MNASNet,MobileNet-v2,MobileNet-v3,RegNet,ResNet,ResNeXt,VGG,VisionTransformer

b.语义分割(Semantic Segmentation):

DeepLabV3,FCN,LRASPP

c.目标检测(Object Detection):

Faster R-CNN,FCOS,RetinaNet,SSD,SSDlite

(3)不带有预训练权重的模型加载:

在torchvision选择对应的模型,即可查看实例化该模型的方法,以ResNet为例:

|--------------------------------------|---------------------------------------------------------------|
| 模型实例方法 | 模型说明 |
| resnet18(*[, weights, progress]) | ResNet-18 from Deep Residual Learning for Image Recognition. |
| resnet34(*[, weights, progress]) | ResNet-34 from Deep Residual Learning for Image Recognition. |
| resnet50(*[, weights, progress]) | ResNet-50 from Deep Residual Learning for Image Recognition. |
| resnet101(*[, weights, progress]) | ResNet-101 from Deep Residual Learning for Image Recognition. |
| resnet152(*[, weights, progress]) | ResNet-152 from Deep Residual Learning for Image Recognition. |

python 复制代码
from torchvision.models import resnet50
model = resnet50(weights=None)

(4)带有预训练权重的模型加载:

在torchvision手册中可查看每个模型包含的预训练权重,以ResNet为例:

|---------------------------------|-------|--------|
| 权重全称 | 参数量 | GFLOPS |
| ResNet101_Weights.IMAGENET1K_V1 | 44.5M | 7.8 |
| ResNet101_Weights.IMAGENET1K_V2 | 44.5M | 7.8 |
| ResNet152_Weights.IMAGENET1K_V1 | 60.2M | 11.51 |
| ResNet152_Weights.IMAGENET1K_V2 | 60.2M | 11.51 |
| ResNet18_Weights.IMAGENET1K_V1 | 11.7M | 1.81 |
| ResNet34_Weights.IMAGENET1K_V1 | 21.8M | 3.66 |
| ResNet50_Weights.IMAGENET1K_V1 | 25.6M | 4.09 |
| ResNet50_Weights.IMAGENET1K_V2 | 25.6M | 4.09 |

具体使用:

python 复制代码
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
model1 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model2 = resnet50(weights="IMAGENET1K_V1")

二、使用torch.nn模块

1.torch.nn模块中的Containers

(1)torch.nn.Module

a.介绍:

pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的;在定义自已的网络的时候需要继承nn.Module类,并重新实现构造函数**__init__构造函数** 和forward这两个方法;

一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中;当然也可以把不具有可学习参数的层也放在里面,但是通常情况下在forward方法里面可以使用nn.functional来代替;

forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心;所有放在构造函数__init__里面的层的都是这个模型的固有属性

b.主要方法和成员变量:

python 复制代码
一、成员变量:
1.training(bool):判断当前的模型处于训练(true)还是推断(false)模式
二、成员方法:
1.add_module(name, module):
(1)用法:
将一个子module添加到当前的module中,并赋予相应的名字name,该子module便可以像调用属性一样被调用
(2)参数:
a.name(str):子模型被赋予的名称
b.module(Module):被添加的子模型
2.apply(fn)
(1)用法:
将函数fn作用在该模型的每一个子模型和该模型自己上,常用于参数的初始化
(2)参数:
a.fn(Module -> None):作用于每一个模块的函数名;
用于初始化参数;接受参数类型是一个Module类型的参数;需要提前定义
3.children()
(1)用法:
返回当前Module的子module组成的列表
4.cpu()
(1)用法:
将当前Module和其子Module的参数迁移到CPU上
5.cuda(device=None)
(1)用法:
将当前Module和其子Module的参数迁移到指定的GPU设备上
(2)参数:
a.device(int, optional):指定需要迁移的GPU设备编号
6.eval():
(1)用法:
将模型设置为eval状态;只对某些特定的模块有实际效果;效果等同于self.train(False)
7.forward(*input)
(1)用法:
定义了每次调用时指定的计算方法;应该被每个子类所重写,以代表不同的操作;
调用时直接通过实例名称调用即可,例如:output = model1(input);而不需要通过.forward的方法调用
(2)参数:
a.input(tensor):需要进行操作的张量
8.get_parameter(target)
(1)用法:
获取指定参数target所确定的参数,返回类型为torch.nn.Parameter
(2)参数:
a.target(str):所指定参数的名称,字符串类型
9.get_submodule(target)
(1)用法:
获取指定参数target所确定的子模块,返回类型为torch.nn.Module
(2)参数:
a.target(str):所指定子模块的名称,是字符串类型;
10.load_state_dict(state_dict, strict=True, assign=False)
(1)用法:
将参数state_dict所包含的参数值加载到该模型中;
返回值是可能缺失的参数名称或模型不包含的参数名称组成的列表
(2)参数:
a.state_dict(dict):包含参数的名称和对应状态(值、类型、所属设备)的字典;
b.strict(bool, optional):参数state_dict所包含的键的名称是否需要严格遵守该模型state_dict()返回的值,默认值为True;
11.modules()
(1)用法:
获取该网络模型中的全部module组成的迭代器
12.parameters(recurse=True)
(1)用法:
获取该网络模型中的全部参数组成的迭代器;通常传递给optimizer
(2)参数:
a.recurse(bool):为True时递归的返回该网络模型中的全部参数,包括子模块和本模块;
否则只返回该网络模型中的直接参数而不返回子参数;
13.requires_grad_(requires_grad=True)
(1)用法:
用于修改autograd操作是否记录该module中参数上的梯度
该方法可以用于设置参数的requires_grad属性
可以用于指定梯度更新时变化的参数范围
(2)参数:
a.requires_grad(bool):autograd操作是否记录该module中的参数上的梯度值;默认值为true
14.state_dict(*, prefix: str = '', keep_vars: bool = False)
(1)用法:
返回一个包含该module全部状态的字典;包括该module的参数;其中key为参数的名称;
返回值类型为Dict[str, Any]
(2)主要参数:
a.prefix(str, optional):添加到输出字典的key的前缀字符串,默认值为空字符串. 
15.to(device, dtype, non_blocking) 
(1)用法:
将该module的参数转换为指定的类型并迁移到指定的设备上
(2)主要参数:
a.device(torch.device):需要迁移的设备名称
b.dtype (torch.dtype):需要转换的数据类型;只能接受浮点或复数类型;
16.train(mode=True)
(1)用法:
将模型设置为train状态;只对某些特定的模块有实际效果;
(2)主要参数:
a.mode(bool):是否将模型设置为训练模式

c.使用实例:

python 复制代码
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5) # submodule: Conv2d
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))
相关推荐
FreakStudio1 小时前
全网最适合入门的面向对象编程教程:56 Python字符串与序列化-正则表达式和re模块应用
python·单片机·嵌入式·面向对象·电子diy
丶21361 小时前
【CUDA】【PyTorch】安装 PyTorch 与 CUDA 11.7 的详细步骤
人工智能·pytorch·python
_.Switch2 小时前
Python Web 应用中的 API 网关集成与优化
开发语言·前端·后端·python·架构·log4j
一个闪现必杀技2 小时前
Python入门--函数
开发语言·python·青少年编程·pycharm
小鹿( ﹡ˆoˆ﹡ )3 小时前
探索IP协议的神秘面纱:Python中的网络通信
python·tcp/ip·php
卷心菜小温3 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
陈苏同学3 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
唐家小妹3 小时前
介绍一款开源的 Modern GUI PySide6 / PyQt6的使用
python·pyqt
羊小猪~~4 小时前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
Marst Code4 小时前
(Django)初步使用
后端·python·django