【TVM 教程】编译 PyTorch 目标检测模型

本文介绍如何用 Relay VM 部署 PyTorch 目标检测模型。

首先应安装 PyTorch。此外,还应安装 TorchVision,并将其作为模型合集(model zoo)。

可通过 pip 快速安装:

pip install torch
pip install torchvision

或参考官网:pytorch.org/get-started...

PyTorch 版本应该和 TorchVision 版本兼容。

目前 TVM 支持 PyTorch 1.7 和 1.4,其他版本可能不稳定。

python 复制代码
import tvm
from tvm import relay
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download_testdata

import numpy as np
import cv2

# PyTorch 导入
import torch
import torchvision

从 TorchVision 加载预训练的 MaskRCNN 并进行跟踪

scss 复制代码
in_size = 300
input_shape = (1, 3, in_size, in_size)

def do_trace(model, inp):
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace

def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]

class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])

model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))

model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))

with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)

输出结果:

css 复制代码
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /workspace/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

  0%|          | 0.00/170M [00:00<?, ?B/s]
  9%|9         | 15.3M/170M [00:00<00:01, 160MB/s]
 19%|#8        | 32.1M/170M [00:00<00:00, 170MB/s]
 29%|##9       | 49.7M/170M [00:00<00:00, 176MB/s]
 40%|####      | 68.8M/170M [00:00<00:00, 185MB/s]
 51%|#####     | 86.4M/170M [00:00<00:00, 175MB/s]
 61%|######1   | 104M/170M [00:00<00:00, 178MB/s]
 71%|#######1  | 121M/170M [00:00<00:00, 169MB/s]
 86%|########6 | 147M/170M [00:00<00:00, 199MB/s]
100%|##########| 170M/170M [00:00<00:00, 193MB/s]
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3878: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  for i in range(dim)
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:127: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  for g in grid_sizes
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:127: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  for g in grid_sizes
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/rpn.py:73: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  A = Ax4 // 4
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/rpn.py:74: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  C = AxC // A
/usr/local/lib/python3.7/dist-packages/torchvision/ops/boxes.py:156: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.7/dist-packages/torchvision/ops/boxes.py:158: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/transform.py:293: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  for s, s_orig in zip(new_size, original_size)
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/roi_heads.py:387: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)

下载测试图像并进行预处理

ini 复制代码
img_url = (
    "/img/docs/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
)
img_path = download_testdata(img_url, "test_street_small.jpg", module="data")

img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)

将计算图导入 Relay

ini 复制代码
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)

输出结果:

sql 复制代码
/workspace/python/tvm/relay/build_module.py:411: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
  DeprecationWarning,

使用 Relay VM 编译

注意:目前仅支持 CPU target。对于 x86 target,因为 TorchVision RCNN 模型中存在大型密集算子,为取得最佳性能,强烈推荐使用 Intel MKL 和 Intel OpenMP 来构建 TVM。

ini 复制代码
# 在 x86 target上添加"-libs=mkl"以获得最佳性能。
# 对于支持 AVX512 的 x86 机器,完整 target 是
# "llvm -mcpu=skylake-avx512 -libs=mkl"
target = "llvm"

with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
    vm_exec = relay.vm.compile(mod, target=target, params=params)

输出结果:

python 复制代码
/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "

使用 Relay VM 进行推理

ini 复制代码
dev = tvm.cpu()
vm = VirtualMachine(vm_exec, dev)
vm.set_input("main", **{input_name: img})
tvm_res = vm.run()

获取 score 大于 0.9 的 box

scss 复制代码
score_threshold = 0.9
boxes = tvm_res[0].numpy().tolist()
valid_boxes = []
for i, score in enumerate(tvm_res[1].numpy().tolist()):
    if score > score_threshold:
        valid_boxes.append(boxes[i])
    else:
        break

print("Get {} valid boxes".format(len(valid_boxes)))

输出结果:

sql 复制代码
Get 9 valid boxes

脚本总运行时长: (2 分 57.278 秒)

下载 Python 源代码:deploy_object_detection_pytorch.py

下载 Jupyter Notebook:deploy_object_detection_pytorch.ipynb

相关推荐
随心Coding9 分钟前
【MySQL】存储引擎有哪些?区别是什么?
数据库·mysql
hunter20620611 分钟前
用opencv生成视频流,然后用rtsp进行拉流显示
人工智能·python·opencv
Daphnis_z13 分钟前
大模型应用编排工具Dify之常用编排组件
人工智能·chatgpt·prompt
yuanbenshidiaos1 小时前
【大数据】机器学习----------强化学习机器学习阶段尾声
人工智能·机器学习
m0_748237051 小时前
sql实战解析-sum()over(partition by xx order by xx)
数据库·sql
dal118网工任子仪2 小时前
61,【1】BUUCTF WEB BUU XSS COURSE 11
前端·数据库·xss
萌小丹Fighting3 小时前
【Postgres_Python】使用python脚本批量创建和导入多个PG数据库
数据库
青灯文案13 小时前
Oracle 数据库常见字段类型大全及详细解析
数据库·oracle
羊小猪~~4 小时前
MYSQL学习笔记(四):多表关系、多表查询(交叉连接、内连接、外连接、自连接)、七种JSONS、集合
数据库·笔记·后端·sql·学习·mysql·考研
好评笔记6 小时前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc