深度学习系列72:torch-tensorrt入门

1. 安装

坑非常多,清华源阿里源都不行。使用官网源下载,这里的121可以改成你需要的东西:

python -m pip install torch torch-tensorrt tensorrt --extra-index-url https://download.pytorch.org/whl/cu121

2. 原理

我们来看一个实例:这是一个用于支持 torchscript 到 TensorRT 转换的项目。上面的代码用于将 addmm 运算展开成数个算子,方便后续映射 TensorRT 算子。

复制代码
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
  // TensorRT implicitly adds a flatten layer in front of FC layers if necessary
  // 用于匹配的模式
  std::string addmm_pattern = R"IR(
    graph(%b, %x, %w, %beta, %alpha):
      %out: Tensor = aten::addmm(%b, %x, %w, %beta, %alpha)
      return (%out))IR";
  // 用于替换的模式
  std::string mm_add_pattern = R"IR(
    graph(%b, %x, %w, %beta, %alpha):
      %mm: Tensor = aten::matmul(%x, %w)
      %bias: Tensor = aten::mul(%b, %beta)
      %out: Tensor = aten::add(%bias, %mm, %alpha)
      return (%out))IR";

  // 创建子图重写器并注册匹配模式和替换模式
  torch::jit::SubgraphRewriter unpack_addmm;
  unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
  // 遍历graph,完成重写
  unpack_addmm.runOnGraph(graph);
  LOG_GRAPH("Post unpack addmm: " << *graph);
}

3. 简单例子

复制代码
import torch
def origin_func(x):
    x = x**2
    x = x**3
    return x

x = torch.rand(1, 2, 3, 4)
jit_model = torch.jit.trace(origin_func, x)
print(jit_model.graph)

# 匹配用的子图定义,注意常量必须为[value=2]属性
pattern = """
    graph(%x):
        %const_2 = prim::Constant[value=2]()
        %out = aten::pow(%x, %const_2)
        return (%out)
"""
# 替换用的子图定义
replacement = """
    graph(%x):
        %out = aten::mul(%x, %x)
        return (%out)
"""
torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement,jit_model.graph)
print(jit_model.graph)
相关推荐
大数据张老师10 分钟前
用 AI 做数据分析:从“数字”里挖“规律”
大数据·人工智能
音视频牛哥38 分钟前
如何打造毫秒级响应的RTSP播放器:架构拆解与实战优化指南
人工智能·机器人·音视频开发
张较瘦_1 小时前
[论文阅读] 人工智能 + 软件工程 | NoCode-bench:评估LLM无代码功能添加能力的新基准
论文阅读·人工智能·软件工程
go54631584651 小时前
Python点阵字生成与优化:从基础实现到高级渲染技术
开发语言·人工智能·python·深度学习·分类·数据挖掘
Coovally AI模型快速验证1 小时前
避开算力坑!无人机桥梁检测场景下YOLO模型选型指南
人工智能·深度学习·yolo·计算机视觉·目标跟踪·无人机
巫婆理发2221 小时前
神经网络(第二课第一周)
人工智能·深度学习·神经网络
欧阳小猜2 小时前
OpenCV-图像预处理➁【图像插值方法、边缘填充策略、图像矫正、掩膜应用、水印添加,图像的噪点消除】
人工智能·opencv·计算机视觉
旭日东升的xu.2 小时前
OpenCV(04)梯度处理,边缘检测,绘制轮廓,凸包特征检测,轮廓特征查找
人工智能·opencv·计算机视觉
liliangcsdn2 小时前
mac测试ollama llamaindex
数据仓库·人工智能·prompt·llama
qyhua2 小时前
Windows 平台源码部署 Dify教程(不依赖 Docker)
人工智能·windows·python