Pytorch导出FP16 ONNX模型

一般Pytorch导出ONNX时默认都是用的FP32,但有时需要导出FP16的ONNX模型,这样在部署时能够方便的将计算以及IO改成FP16,并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单,一般情况下只需要在导出FP32 ONNX的基础上调用下model.half()将模型相关权重转为FP16,然后输入的Tensor也改成FP16即可,具体操作可参考如下示例代码。这里需要注意下,当前Pytorch要导出FP16的ONNX必须将模型以及输入Tensor的device设置成GPU,否则会报很多算子不支持FP16计算的提示。

python 复制代码
import torch
from torchvision.models import resnet50


def main():
    export_fp16 = True
    export_onnx_path = f"resnet50_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = resnet50()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.inference_mode():
        dtype = torch.float16 if export_fp16 else torch.float32
        x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
        torch.onnx.export(model=model,
                          args=(x,),
                          f=export_onnx_path,
                          input_names=["image"],
                          output_names=["output"],
                          dynamic_axes={"image": {2: "width", 3: "height"}},
                          opset_version=17)


if __name__ == '__main__':
    main()

通过Netron可视化工具可以看到导出的FP16 ONNX的输入/输出的tensor类型都是float16

并且通过对比可以看到,FP16的ONNX模型比FP32的文件更小(48.6MB vs 97.3MB )。

大多数情况可以按照上述操作进行正常转换,但也有一些比较头大的场景,因为你永远无法知道拿到的模型会有多奇葩,例如下面示例:
错误导出FP16 ONNX示例

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)

        kernel = torch.tensor([[0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])
        x = F.conv2d(x, weight=kernel, bias=None, stride=1)

        return x


def main():
    export_fp16 = True
    export_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = MyModel()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.inference_mode():
        dtype = torch.float16 if export_fp16 else torch.float32
        x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
        model(x)
        torch.onnx.export(model=model,
                          args=(x,),
                          f=export_onnx_path,
                          input_names=["image"],
                          output_names=["output"],
                          dynamic_axes={"image": {2: "width", 3: "height"}},
                          opset_version=17)


if __name__ == '__main__':
    main()

执行以上代码后会报如下错误信息:

/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,
Traceback (most recent call last):
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 47, in <module>
    main()
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 36, in main
    model(x)
  File "/home/wz/miniconda3/envs/torch2.0.1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 17, in forward
    x = F.conv2d(x, weight=kernel, bias=None, stride=1)
  RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

简单来说就是在推理过程中遇到两种不同类型的数据要计算,torch.cuda.HalfTensor(FP16) 和torch.cuda.FloatTensor(FP32)。遇到这种情况一般常见有两种解法:

  • 一种是找到数据类型与我们预期不一致的地方,然后改成我们要想的dtype,例如上面示例是将kernel的dtype写死成了torch.float32,我们可以改成torch.float16或者写成x.dtype(这种会比较通用,会根据输入的Tensor类型自动切换)。这种方法有个弊端,如果代码里写死dtype的位置很多,改起来会比较头大。
  • 另一种是使用torch.autocast上下文管理器,该上下文管理器能够实现推理过程中自动进行混合精度计算,例如遇到能进行float16/bfloat16计算的场景会自动切换。具体使用方法可以查看官方文档。下面示例代码就是用torch.autocast上下文管理器来做自动转换。
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)

        kernel = torch.tensor([[0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])
        x = F.conv2d(x, weight=kernel, bias=None, stride=1)

        return x


def main():
    export_fp16 = True
    export_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = MyModel()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        with torch.inference_mode():
            dtype = torch.float16 if export_fp16 else torch.float32
            x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
            model(x)
            torch.onnx.export(model=model,
                              args=(x,),
                              f=export_onnx_path,
                              input_names=["image"],
                              output_names=["output"],
                              dynamic_axes={"image": {2: "width", 3: "height"}},
                              opset_version=17)


if __name__ == '__main__':
    main()

使用上述代码能够正常导出ONNX模型,并且使用Netron可视化后可以看到导出的FP16 ONNX模型是符合预期的。

相关推荐
sp_fyf_202417 分钟前
【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
CoderIsArt19 分钟前
基于 BP 神经网络整定的 PID 控制
人工智能·深度学习·神经网络
编程修仙28 分钟前
Collections工具类
linux·windows·python
开源社33 分钟前
一场开源视角的AI会议即将在南京举办
人工智能·开源
FreeIPCC33 分钟前
谈一下开源生态对 AI人工智能大模型的促进作用
大数据·人工智能·机器人·开源
芝麻团坚果44 分钟前
对subprocess启动的子进程使用VSCode python debugger
linux·ide·python·subprocess·vscode debugger
机器之心1 小时前
全球十亿级轨迹点驱动,首个轨迹基础大模型来了
人工智能·后端
z千鑫1 小时前
【人工智能】PyTorch、TensorFlow 和 Keras 全面解析与对比:深度学习框架的终极指南
人工智能·pytorch·深度学习·aigc·tensorflow·keras·codemoss
EterNity_TiMe_1 小时前
【论文复现】神经网络的公式推导与代码实现
人工智能·python·深度学习·神经网络·数据分析·特征分析
Stara05111 小时前
Git推送+拉去+uwsgi+Nginx服务器部署项目
git·python·mysql·nginx·gitee·github·uwsgi