使用 pytorch 运行预训练模型的框架

PyTorch 简介:

PyTorch 是一个 Python 程序库,我们可以使用 PyTorch 来构建深度学习项目。

PyTorch 的两个特点:

  1. PyTorch 的核心数据结构是张量,张量是一个多维数组,与 NumPy 数组有许多相似之处。
  2. PyTorch 提供了在专用硬件上执行加速数学操作的特性,这使得神经网络结构设计以及在单机或并行计算资源上训练它们变得很方便。

因此,我们可以将 PyTorch 描述为一个在 Python 中为科学计算提供优化支持的高性能库。

PyTorch 大部分是用 C++ 和 CUDA 编写的,CUDA 是一种来自英伟达的类 C++的语言,可以被编译并在 GPU 上以并行方式运行。

使用 pytorch 运行预训练模型的框架

复制代码
import torch
  1. 定义模型类 1.1 自定义模型类 1.2 从 torchvision 模块加载模型: from torchvision import models

  2. 实例化模型类

    resnet101 = models.resnet101()

  3. 给实例化的模型类加载预训练好的参数 3.1 实例化模型类和加载预训练好的权重同时进行(这种情况可以省略第 2 步)

    resnet101 = models.resnet101(pretrained=True) # pretrained=True 指示函数下载 resnet101 在 ImageNet数据集上训练好的权重

3.2 使用模型的 load_state_dict() 方法将预训练权重加载到 resnet101 中

复制代码
model_path = '......'
model_data = torch.load(model_path)
resnet101.load_state_dict(model_data)

3.3 使用 torch.hub 从 github 加载模型(这种情况可以省略第 1、2 步)

复制代码
from torch import hub
resnet101 = hub.load('pytorch/vision:main', 'resnet101', pretrained=True)  # 第一项是 GitHub 存储库的名称和分支,第二项是入口点函数的名称

以上代码将 pytorch/vision 主分支的快照及其权重默认下载到本地的 C:\Users\username.cache\torch\hub 目录下,然后运行 resnet101 入口点函数返回实例化的模型,参数 pretrained=true 会从 ImageNet 获得预训练权重,并加载到 resnet101 中。

  1. 使用 Python 图像操作模块 Pillow 从本地文件系统加载一幅图像

    from PIL import Image # PIL 指的是 pillow
    img = Image.open(".../xxx.jpg")

  2. 使用 TorchVision 模块提供的 transforms 定义一个对输入图像进行预处理的管道

    from torchvision import transforms
    preprocess = transforms.Compose([transforms.Resize(256), # 将输入图像缩放到 256× 256 个像素
    transforms.ToTensor(), # 转换为一个张量
    ])

  3. 使用预处理管道 preprocess 对图像 img 进行预处理

    img_t = preprocess(img)

  4. 给数据添加一个新的维度:批次维度

    batch_t = torch.unsqueeze(img_t, 0)

  5. 进行推理时,我们需要将神经网络置于 eval 模式

    resnet.eval()

  6. eval 模式设置好之后,进行推理

    out = resnet101(batch_t)
    out

......

本文由mdnice多平台发布

相关推荐
2401_895521347 小时前
SpringBoot Maven快速上手
spring boot·后端·maven
disgare7 小时前
关于 spring 工程中添加 traceID 实践
java·后端·spring
ictI CABL7 小时前
Spring Boot与MyBatis
spring boot·后端·mybatis
小江的记录本9 小时前
【Linux】《Linux常用命令汇总表》
linux·运维·服务器·前端·windows·后端·macos
yhole12 小时前
springboot三层架构详细讲解
spring boot·后端·架构
香香甜甜的辣椒炒肉12 小时前
Spring(1)基本概念+开发的基本步骤
java·后端·spring
白毛大侠13 小时前
Go Goroutine 与用户态是进程级
开发语言·后端·golang
ForteScarlet13 小时前
从 Kotlin 编译器 API 的变化开始: 2.3.20
android·开发语言·后端·ios·开源·kotlin
大阿明14 小时前
SpringBoot - Cookie & Session 用户登录及登录状态保持功能实现
java·spring boot·后端
Binary-Jeff14 小时前
Spring 创建 Bean 的关键流程
java·开发语言·前端·spring boot·后端·spring·学习方法