1. 安装
坑非常多,清华源阿里源都不行。使用官网源下载,这里的121可以改成你需要的东西:
python -m pip install torch torch-tensorrt tensorrt --extra-index-url https://download.pytorch.org/whl/cu121
2. 原理
我们来看一个实例:这是一个用于支持 torchscript 到 TensorRT 转换的项目。上面的代码用于将 addmm 运算展开成数个算子,方便后续映射 TensorRT 算子。
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
// TensorRT implicitly adds a flatten layer in front of FC layers if necessary
// 用于匹配的模式
std::string addmm_pattern = R"IR(
graph(%b, %x, %w, %beta, %alpha):
%out: Tensor = aten::addmm(%b, %x, %w, %beta, %alpha)
return (%out))IR";
// 用于替换的模式
std::string mm_add_pattern = R"IR(
graph(%b, %x, %w, %beta, %alpha):
%mm: Tensor = aten::matmul(%x, %w)
%bias: Tensor = aten::mul(%b, %beta)
%out: Tensor = aten::add(%bias, %mm, %alpha)
return (%out))IR";
// 创建子图重写器并注册匹配模式和替换模式
torch::jit::SubgraphRewriter unpack_addmm;
unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
// 遍历graph,完成重写
unpack_addmm.runOnGraph(graph);
LOG_GRAPH("Post unpack addmm: " << *graph);
}
3. 简单例子
import torch
def origin_func(x):
x = x**2
x = x**3
return x
x = torch.rand(1, 2, 3, 4)
jit_model = torch.jit.trace(origin_func, x)
print(jit_model.graph)
# 匹配用的子图定义,注意常量必须为[value=2]属性
pattern = """
graph(%x):
%const_2 = prim::Constant[value=2]()
%out = aten::pow(%x, %const_2)
return (%out)
"""
# 替换用的子图定义
replacement = """
graph(%x):
%out = aten::mul(%x, %x)
return (%out)
"""
torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement,jit_model.graph)
print(jit_model.graph)