PyTorch常用工具(2)预训练模型

文章目录

  • 前言
  • [2 预训练模型](#2 预训练模型)

前言

在训练神经网络的过程中需要用到很多的工具,最重要的是数据处理、可视化和GPU加速。本章主要介绍PyTorch在这些方面常用的工具模块,合理使用这些工具可以极大地提高编程效率。

由于内容较多,本文分成了五篇文章(1)数据处理(2)预训练模型(3)TensorBoard(4)Visdom(5)CUDA与小结。

整体结构如下:

  • 1 数据处理
    • 1.1 Dataset
    • 1.2 DataLoader
  • 2 预训练模型
  • 3 可视化工具
  • 3.1 TensorBoard
  • 3.2 Visdom
  • 4 使用GPU加速:CUDA
  • 5 小结

全文链接:

  1. PyTorch中常用的工具(1)数据处理
  2. PyTorch常用工具(2)预训练模型
  3. PyTorch中常用的工具(3)TensorBoard
  4. PyTorch中常用的工具(4)Visdom
  5. PyTorch中常用的工具(5)使用GPU加速:CUDA

2 预训练模型

除了加载数据,并对数据进行预处理之外,torchvision还提供了深度学习中各种经典的网络结构以及预训练模型。这些模型封装在torchvision.models中,包括经典的分类模型:VGG、ResNet、DenseNet及MobileNet等,语义分割模型:FCN及DeepLabV3等,目标检测模型:Faster RCNN以及实例分割模型:Mask RCNN等。读者可以通过下述代码使用这些已经封装好的网络结构与模型,也可以在此基础上根据需求对网络结构进行修改:

python 复制代码
from torchvision import models
# 仅使用网络结构,参数权重随机初始化
mobilenet_v2 = models.mobilenet_v2()
# 加载预训练权重
deeplab = models.segmentation.deeplabv3_resnet50(pretrained=True)

下面使用torchvision中预训练好的实例分割模型Mask RCNN进行一次简单的实例分割:

python 复制代码
In: from torchvision import models
    from torchvision import transforms as T
    from torch import nn
    from PIL import Image
    import numpy as np
    import random
    import cv2

    # 加载预训练好的模型,不存在的话会自动下载
    # 预训练好的模型保存在 ~/.torch/models/下面
    detection = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    detection.eval()
    def predict(img_path, threshold):
        # 数据预处理,标准化至[-1, 1],规定均值和标准差
        img = Image.open(img_path)
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
        ])
        img = transform(img)
        # 对图像进行预测
        pred = detection([img])
        # 对预测结果进行后处理:得到mask与bbox
        score = list(pred[0]['scores'].detach().numpy())
        t = [score.index(x) for x in score if x > threshold][-1]
        mask = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
        pred_boxes = [[(i[0], i[1]), (i[2], i[3])] \
                      for i in list(pred[0]['boxes'].detach().numpy())]
        pred_masks = mask[:t+1]
        boxes = pred_boxes[:t+1]
        return pred_masks, boxes

Transforms中涵盖了大部分对Tensor和PIL Image的常用处理,这些已在上文提到,本节不再详细介绍。需要注意的是转换分为两步,第一步:构建转换操作,例如transf = transforms.Normalize(mean=x, std=y);第二步:执行转换操作,例如output = transf(input)。另外还可以将多个处理操作用Compose拼接起来,构成一个处理转换流程。

python 复制代码
In: # 随机颜色,以便可视化
    def color(image):
        colours = [[0, 255, 255], [0, 0, 255], [255, 0, 0]]
        R = np.zeros_like(image).astype(np.uint8)
        G = np.zeros_like(image).astype(np.uint8)
        B = np.zeros_like(image).astype(np.uint8)
        R[image==1], G[image==1], B[image==1] = colours[random.randrange(0,3)]
        color_mask = np.stack([R,G,B],axis=2)
        return color_mask
python 复制代码
In: # 对mask与bounding box进行可视化
    def result(img_path, threshold=0.9, rect_th=1, text_size=1, text_th=2):
        masks, boxes = predict(img_path, threshold)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        for i in range(len(masks)):
            color_mask = color(masks[i])
            img = cv2.addWeighted(img, 1, color_mask, 0.5, 0)
            cv2.rectangle(img, boxes[i][0], boxes[i][1], color=(255,0,0), thickness=rect_th)
        return img
python 复制代码
In: from matplotlib import pyplot as plt
    img=result('data/demo.jpg')
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    img_result = plt.imshow(img)

上述代码完成了一个简单的实例分割任务。如上图所示,Mask RCNN能够分割出该图像中的部分实例,读者可考虑对预训练模型进行微调,以适应不同场景下的不同任务。注意:上述代码均在CPU上进行,速度较慢,读者可以考虑将数据与模型转移至GPU上,具体操作可以参考第4节。

相关推荐
凤枭香4 分钟前
Python OpenCV 傅里叶变换
开发语言·图像处理·python·opencv
CSDN云计算5 分钟前
如何以开源加速AI企业落地,红帽带来新解法
人工智能·开源·openshift·红帽·instructlab
测试杂货铺12 分钟前
外包干了2年,快要废了。。
自动化测试·软件测试·python·功能测试·测试工具·面试·职场和发展
艾派森16 分钟前
大数据分析案例-基于随机森林算法的智能手机价格预测模型
人工智能·python·随机森林·机器学习·数据挖掘
hairenjing112318 分钟前
在 Android 手机上从SD 卡恢复数据的 6 个有效应用程序
android·人工智能·windows·macos·智能手机
小蜗子22 分钟前
Multi‐modal knowledge graph inference via media convergenceand logic rule
人工智能·知识图谱
SpikeKing35 分钟前
LLM - 使用 LLaMA-Factory 微调大模型 环境配置与训练推理 教程 (1)
人工智能·llm·大语言模型·llama·环境配置·llamafactory·训练框架
小码的头发丝、42 分钟前
Django中ListView 和 DetailView类的区别
数据库·python·django
黄焖鸡能干四碗1 小时前
信息化运维方案,实施方案,开发方案,信息中心安全运维资料(软件资料word)
大数据·人工智能·软件需求·设计规范·规格说明书
1 小时前
开源竞争-数据驱动成长-11/05-大专生的思考
人工智能·笔记·学习·算法·机器学习