Pytorch加载模型

一、使用torchvision加载模型

1.使用说明:

torchvision是PyTorch生态系统中的一个包,专门用于计算机视觉任务。它提供了一系列用于加载、处理和预处理图像和视频数据的工具,以及常用的计算机视觉模型。

torchvision.models模块包含许多常用的预训练计算机视觉模型,例如ResNet、AlexNet、VGG等分类、分割等模型。

2.使用方法: Models and pre-trained weights --- Torchvision 0.18 documentation

(1)查看torchvision.models提供的全部模型

python 复制代码
from torchvision.models import list_models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
返回值是torchvisoin.models中包含的全部模型名称,是一个字符串组成的列表

(2)torchvision.models中包含的模型如下:

a.分类(classification):

AlexNet,ConvNeXt,DenseNet,EfficientNet,EfficientNetV2,GoogLeNet,Inception-V3, MaxVit, MNASNet,MobileNet-v2,MobileNet-v3,RegNet,ResNet,ResNeXt,VGG,VisionTransformer

b.语义分割(Semantic Segmentation):

DeepLabV3,FCN,LRASPP

c.目标检测(Object Detection):

Faster R-CNN,FCOS,RetinaNet,SSD,SSDlite

(3)不带有预训练权重的模型加载:

在torchvision选择对应的模型,即可查看实例化该模型的方法,以ResNet为例:

|--------------------------------------|---------------------------------------------------------------|
| 模型实例方法 | 模型说明 |
| resnet18(*[, weights, progress]) | ResNet-18 from Deep Residual Learning for Image Recognition. |
| resnet34(*[, weights, progress]) | ResNet-34 from Deep Residual Learning for Image Recognition. |
| resnet50(*[, weights, progress]) | ResNet-50 from Deep Residual Learning for Image Recognition. |
| resnet101(*[, weights, progress]) | ResNet-101 from Deep Residual Learning for Image Recognition. |
| resnet152(*[, weights, progress]) | ResNet-152 from Deep Residual Learning for Image Recognition. |

python 复制代码
from torchvision.models import resnet50
model = resnet50(weights=None)

(4)带有预训练权重的模型加载:

在torchvision手册中可查看每个模型包含的预训练权重,以ResNet为例:

|---------------------------------|-------|--------|
| 权重全称 | 参数量 | GFLOPS |
| ResNet101_Weights.IMAGENET1K_V1 | 44.5M | 7.8 |
| ResNet101_Weights.IMAGENET1K_V2 | 44.5M | 7.8 |
| ResNet152_Weights.IMAGENET1K_V1 | 60.2M | 11.51 |
| ResNet152_Weights.IMAGENET1K_V2 | 60.2M | 11.51 |
| ResNet18_Weights.IMAGENET1K_V1 | 11.7M | 1.81 |
| ResNet34_Weights.IMAGENET1K_V1 | 21.8M | 3.66 |
| ResNet50_Weights.IMAGENET1K_V1 | 25.6M | 4.09 |
| ResNet50_Weights.IMAGENET1K_V2 | 25.6M | 4.09 |

具体使用:

python 复制代码
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
model1 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model2 = resnet50(weights="IMAGENET1K_V1")

二、使用torch.nn模块

1.torch.nn模块中的Containers

(1)torch.nn.Module

a.介绍:

pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的;在定义自已的网络的时候需要继承nn.Module类,并重新实现构造函数**__init__构造函数** 和forward这两个方法;

一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中;当然也可以把不具有可学习参数的层也放在里面,但是通常情况下在forward方法里面可以使用nn.functional来代替;

forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心;所有放在构造函数__init__里面的层的都是这个模型的固有属性

b.主要方法和成员变量:

python 复制代码
一、成员变量:
1.training(bool):判断当前的模型处于训练(true)还是推断(false)模式
二、成员方法:
1.add_module(name, module):
(1)用法:
将一个子module添加到当前的module中,并赋予相应的名字name,该子module便可以像调用属性一样被调用
(2)参数:
a.name(str):子模型被赋予的名称
b.module(Module):被添加的子模型
2.apply(fn)
(1)用法:
将函数fn作用在该模型的每一个子模型和该模型自己上,常用于参数的初始化
(2)参数:
a.fn(Module -> None):作用于每一个模块的函数名;
用于初始化参数;接受参数类型是一个Module类型的参数;需要提前定义
3.children()
(1)用法:
返回当前Module的子module组成的列表
4.cpu()
(1)用法:
将当前Module和其子Module的参数迁移到CPU上
5.cuda(device=None)
(1)用法:
将当前Module和其子Module的参数迁移到指定的GPU设备上
(2)参数:
a.device(int, optional):指定需要迁移的GPU设备编号
6.eval():
(1)用法:
将模型设置为eval状态;只对某些特定的模块有实际效果;效果等同于self.train(False)
7.forward(*input)
(1)用法:
定义了每次调用时指定的计算方法;应该被每个子类所重写,以代表不同的操作;
调用时直接通过实例名称调用即可,例如:output = model1(input);而不需要通过.forward的方法调用
(2)参数:
a.input(tensor):需要进行操作的张量
8.get_parameter(target)
(1)用法:
获取指定参数target所确定的参数,返回类型为torch.nn.Parameter
(2)参数:
a.target(str):所指定参数的名称,字符串类型
9.get_submodule(target)
(1)用法:
获取指定参数target所确定的子模块,返回类型为torch.nn.Module
(2)参数:
a.target(str):所指定子模块的名称,是字符串类型;
10.load_state_dict(state_dict, strict=True, assign=False)
(1)用法:
将参数state_dict所包含的参数值加载到该模型中;
返回值是可能缺失的参数名称或模型不包含的参数名称组成的列表
(2)参数:
a.state_dict(dict):包含参数的名称和对应状态(值、类型、所属设备)的字典;
b.strict(bool, optional):参数state_dict所包含的键的名称是否需要严格遵守该模型state_dict()返回的值,默认值为True;
11.modules()
(1)用法:
获取该网络模型中的全部module组成的迭代器
12.parameters(recurse=True)
(1)用法:
获取该网络模型中的全部参数组成的迭代器;通常传递给optimizer
(2)参数:
a.recurse(bool):为True时递归的返回该网络模型中的全部参数,包括子模块和本模块;
否则只返回该网络模型中的直接参数而不返回子参数;
13.requires_grad_(requires_grad=True)
(1)用法:
用于修改autograd操作是否记录该module中参数上的梯度
该方法可以用于设置参数的requires_grad属性
可以用于指定梯度更新时变化的参数范围
(2)参数:
a.requires_grad(bool):autograd操作是否记录该module中的参数上的梯度值;默认值为true
14.state_dict(*, prefix: str = '', keep_vars: bool = False)
(1)用法:
返回一个包含该module全部状态的字典;包括该module的参数;其中key为参数的名称;
返回值类型为Dict[str, Any]
(2)主要参数:
a.prefix(str, optional):添加到输出字典的key的前缀字符串,默认值为空字符串. 
15.to(device, dtype, non_blocking) 
(1)用法:
将该module的参数转换为指定的类型并迁移到指定的设备上
(2)主要参数:
a.device(torch.device):需要迁移的设备名称
b.dtype (torch.dtype):需要转换的数据类型;只能接受浮点或复数类型;
16.train(mode=True)
(1)用法:
将模型设置为train状态;只对某些特定的模块有实际效果;
(2)主要参数:
a.mode(bool):是否将模型设置为训练模式

c.使用实例:

python 复制代码
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5) # submodule: Conv2d
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))
相关推荐
阿斯卡码1 小时前
jupyter添加、删除、查看内核
ide·python·jupyter
埃菲尔铁塔_CV算法4 小时前
图像算法之 OCR 识别算法:原理与应用场景
图像处理·python·计算机视觉
封步宇AIGC4 小时前
量化交易系统开发-实时行情自动化交易-3.4.2.Okex行情交易数据
人工智能·python·机器学习·数据挖掘
封步宇AIGC4 小时前
量化交易系统开发-实时行情自动化交易-2.技术栈
人工智能·python·机器学习·数据挖掘
love_and_hope5 小时前
Pytorch学习--神经网络--完整的模型训练套路
人工智能·pytorch·python·深度学习·神经网络·学习
在人间负债^6 小时前
基于标签相关性的多标签学习
人工智能·python·chatgpt·大模型·图像类型
python1567 小时前
使用YOLOv9进行图像与视频检测
开发语言·python·音视频
狂奔solar7 小时前
DQN强化训练agent玩是男人就下xx层小游戏
python·pygame·dqn 强化
互联网杂货铺7 小时前
软件测试之白盒测试(超详细总结)
自动化测试·软件测试·python·测试工具·职场和发展·单元测试·测试用例
Z pz7 小时前
网络编程——Python简单TCP通信功能代码实践
网络·python·tcp/ip