使用 pytorch 运行预训练模型的框架

PyTorch 简介:

PyTorch 是一个 Python 程序库,我们可以使用 PyTorch 来构建深度学习项目。

PyTorch 的两个特点:

  1. PyTorch 的核心数据结构是张量,张量是一个多维数组,与 NumPy 数组有许多相似之处。
  2. PyTorch 提供了在专用硬件上执行加速数学操作的特性,这使得神经网络结构设计以及在单机或并行计算资源上训练它们变得很方便。

因此,我们可以将 PyTorch 描述为一个在 Python 中为科学计算提供优化支持的高性能库。

PyTorch 大部分是用 C++ 和 CUDA 编写的,CUDA 是一种来自英伟达的类 C++的语言,可以被编译并在 GPU 上以并行方式运行。

使用 pytorch 运行预训练模型的框架

复制代码
import torch
  1. 定义模型类 1.1 自定义模型类 1.2 从 torchvision 模块加载模型: from torchvision import models

  2. 实例化模型类

    resnet101 = models.resnet101()

  3. 给实例化的模型类加载预训练好的参数 3.1 实例化模型类和加载预训练好的权重同时进行(这种情况可以省略第 2 步)

    resnet101 = models.resnet101(pretrained=True) # pretrained=True 指示函数下载 resnet101 在 ImageNet数据集上训练好的权重

3.2 使用模型的 load_state_dict() 方法将预训练权重加载到 resnet101 中

复制代码
model_path = '......'
model_data = torch.load(model_path)
resnet101.load_state_dict(model_data)

3.3 使用 torch.hub 从 github 加载模型(这种情况可以省略第 1、2 步)

复制代码
from torch import hub
resnet101 = hub.load('pytorch/vision:main', 'resnet101', pretrained=True)  # 第一项是 GitHub 存储库的名称和分支,第二项是入口点函数的名称

以上代码将 pytorch/vision 主分支的快照及其权重默认下载到本地的 C:\Users\username.cache\torch\hub 目录下,然后运行 resnet101 入口点函数返回实例化的模型,参数 pretrained=true 会从 ImageNet 获得预训练权重,并加载到 resnet101 中。

  1. 使用 Python 图像操作模块 Pillow 从本地文件系统加载一幅图像

    from PIL import Image # PIL 指的是 pillow
    img = Image.open(".../xxx.jpg")

  2. 使用 TorchVision 模块提供的 transforms 定义一个对输入图像进行预处理的管道

    from torchvision import transforms
    preprocess = transforms.Compose([transforms.Resize(256), # 将输入图像缩放到 256× 256 个像素
    transforms.ToTensor(), # 转换为一个张量
    ])

  3. 使用预处理管道 preprocess 对图像 img 进行预处理

    img_t = preprocess(img)

  4. 给数据添加一个新的维度:批次维度

    batch_t = torch.unsqueeze(img_t, 0)

  5. 进行推理时,我们需要将神经网络置于 eval 模式

    resnet.eval()

  6. eval 模式设置好之后,进行推理

    out = resnet101(batch_t)
    out

......

本文由mdnice多平台发布

相关推荐
橙子家7 小时前
WebAPI 项目通过 CI/CD 自动化部署到 Linux 服务器(docker-compose)
后端
钟离墨笺8 小时前
Go语言--2go基础-->基本数据类型
开发语言·前端·后端·golang
飞Link10 小时前
【Django】Django的静态文件相关配置与操作
后端·python·django
钟离墨笺10 小时前
Go语言--2go基础-->map
开发语言·后端·golang
Tony Bai11 小时前
Go 语言的“魔法”时刻:如何用 -toolexec 实现零侵入式自动插桩?
开发语言·后端·golang
qq_124987075313 小时前
基于小程序中医食谱推荐系统的设计(源码+论文+部署+安装)
java·spring boot·后端·微信小程序·小程序·毕业设计·计算机毕业设计
Marktowin14 小时前
SpringBoot项目的国际化流程
java·后端·springboot
程序员泠零澪回家种桔子14 小时前
RAG中的Embedding技术
人工智能·后端·ai·embedding
汤姆yu14 小时前
基于springboot的直播管理系统
java·spring boot·后端
a努力。14 小时前
虾皮Java面试被问:分布式Top K问题的解决方案
java·后端·云原生·面试·rpc·架构