Pytorch导出FP16 ONNX模型

一般Pytorch导出ONNX时默认都是用的FP32,但有时需要导出FP16的ONNX模型,这样在部署时能够方便的将计算以及IO改成FP16,并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单,一般情况下只需要在导出FP32 ONNX的基础上调用下model.half()将模型相关权重转为FP16,然后输入的Tensor也改成FP16即可,具体操作可参考如下示例代码。这里需要注意下,当前Pytorch要导出FP16的ONNX必须将模型以及输入Tensor的device设置成GPU,否则会报很多算子不支持FP16计算的提示。

python 复制代码
import torch
from torchvision.models import resnet50


def main():
    export_fp16 = True
    export_onnx_path = f"resnet50_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = resnet50()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.inference_mode():
        dtype = torch.float16 if export_fp16 else torch.float32
        x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
        torch.onnx.export(model=model,
                          args=(x,),
                          f=export_onnx_path,
                          input_names=["image"],
                          output_names=["output"],
                          dynamic_axes={"image": {2: "width", 3: "height"}},
                          opset_version=17)


if __name__ == '__main__':
    main()

通过Netron可视化工具可以看到导出的FP16 ONNX的输入/输出的tensor类型都是float16

并且通过对比可以看到,FP16的ONNX模型比FP32的文件更小(48.6MB vs 97.3MB )。

大多数情况可以按照上述操作进行正常转换,但也有一些比较头大的场景,因为你永远无法知道拿到的模型会有多奇葩,例如下面示例:
错误导出FP16 ONNX示例

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)

        kernel = torch.tensor([[0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])
        x = F.conv2d(x, weight=kernel, bias=None, stride=1)

        return x


def main():
    export_fp16 = True
    export_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = MyModel()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.inference_mode():
        dtype = torch.float16 if export_fp16 else torch.float32
        x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
        model(x)
        torch.onnx.export(model=model,
                          args=(x,),
                          f=export_onnx_path,
                          input_names=["image"],
                          output_names=["output"],
                          dynamic_axes={"image": {2: "width", 3: "height"}},
                          opset_version=17)


if __name__ == '__main__':
    main()

执行以上代码后会报如下错误信息:

/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,
Traceback (most recent call last):
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 47, in <module>
    main()
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 36, in main
    model(x)
  File "/home/wz/miniconda3/envs/torch2.0.1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 17, in forward
    x = F.conv2d(x, weight=kernel, bias=None, stride=1)
  RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

简单来说就是在推理过程中遇到两种不同类型的数据要计算,torch.cuda.HalfTensor(FP16) 和torch.cuda.FloatTensor(FP32)。遇到这种情况一般常见有两种解法:

  • 一种是找到数据类型与我们预期不一致的地方,然后改成我们要想的dtype,例如上面示例是将kernel的dtype写死成了torch.float32,我们可以改成torch.float16或者写成x.dtype(这种会比较通用,会根据输入的Tensor类型自动切换)。这种方法有个弊端,如果代码里写死dtype的位置很多,改起来会比较头大。
  • 另一种是使用torch.autocast上下文管理器,该上下文管理器能够实现推理过程中自动进行混合精度计算,例如遇到能进行float16/bfloat16计算的场景会自动切换。具体使用方法可以查看官方文档。下面示例代码就是用torch.autocast上下文管理器来做自动转换。
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)

        kernel = torch.tensor([[0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])
        x = F.conv2d(x, weight=kernel, bias=None, stride=1)

        return x


def main():
    export_fp16 = True
    export_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = MyModel()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        with torch.inference_mode():
            dtype = torch.float16 if export_fp16 else torch.float32
            x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
            model(x)
            torch.onnx.export(model=model,
                              args=(x,),
                              f=export_onnx_path,
                              input_names=["image"],
                              output_names=["output"],
                              dynamic_axes={"image": {2: "width", 3: "height"}},
                              opset_version=17)


if __name__ == '__main__':
    main()

使用上述代码能够正常导出ONNX模型,并且使用Netron可视化后可以看到导出的FP16 ONNX模型是符合预期的。

相关推荐
天天代码码天天6 分钟前
C# OpenCvSharp 部署表格检测
人工智能·目标检测·表格检测
姓学名生7 分钟前
李沐vscode配置+github管理+FFmpeg视频搬运+百度API添加翻译字幕
vscode·python·深度学习·ffmpeg·github·视频
斯多葛的信徒11 分钟前
看看你的电脑可以跑 AI 模型吗?
人工智能·语言模型·电脑·llama
正在走向自律11 分钟前
AI 写作(六):核心技术与多元应用(6/10)
人工智能·aigc·ai写作
AI科技大本营12 分钟前
Anthropic四大专家“会诊”:实现深度思考不一定需要多智能体,AI完美对齐比失控更可怕!...
人工智能·深度学习
Cc不爱吃洋葱12 分钟前
如何本地部署AI智能体平台,带你手搓一个AI Agent
人工智能·大语言模型·agent·ai大模型·ai agent·智能体·ai智能体
网安打工仔12 分钟前
斯坦福李飞飞最新巨著《AI Agent综述》
人工智能·自然语言处理·大模型·llm·agent·ai大模型·大模型入门
AGI学习社13 分钟前
2024中国排名前十AI大模型进展、应用案例与发展趋势
linux·服务器·人工智能·华为·llama
AI_Tool13 分钟前
纳米AI搜索官网 - 新一代智能答案引擎
人工智能·搜索引擎
Damon小智13 分钟前
合合信息DocFlow产品解析与体验:人人可搭建的AI自动化单据处理工作流
图像处理·人工智能·深度学习·机器学习·ai·自动化·docflow