解决低版本pytorch和onnx组合时torch.atan2()不被onnx支持的问题

解决这个问题,最简单的当然是升级pytorch和onnx到比较高的版本,例如有人验证过的组合: pytorch=2.1.1+cu118, onnxruntime=1.16.3

但是因为你的模型或cuda环境等约束,不能安装这么高的版本的pytorch和onnx组合时(例如我的环境是pytorch1.12,onnxruntime=1.19.2,即使onnxruntime版本比较高但是Pytorch的版本底也照样报这个错: tan2 to ONNX opset version 16 is not supported),那就只能考虑自己基于torch.atan()实现torch.atan2()的功能了。

torch.atan()因为不能区分坐标落在哪个像限,如果直接用来计算物体的朝向角的话,因为可能这个缺陷导致计算出来的物体的朝向是完全相反的方向,或者朝向沿着x轴做对称翻转了,例如

torch.atan(-1/10)和torch.atan(1/(-10))是没有区别的,atan2()就是为了解决这种问题的,atan2()的实现原理大致如下图所示:

有人基于atan()对atan2()做了如下近似实现:

torch.atan((rot_sine / (rot_cosine + 1e-8)).sigmoid())                
+ ((1 - torch.sign(rot_cosine)) / 2) * torch.sign(rot_sine) * torch.pi

显然后半部分根据x和y的正负做加/减torch.pi是正确的,但是前半部分对(y/x)做sigmoid()运算把值一律转到不带符号的(0,1)之间在有的情况下是有一定误差的。

网上没找到其他更好的实现,于是我基于上面图中的计算规则做了如下实现:

            rot_sine = bboxes[..., 6:7]
            rot_cosine = bboxes[..., 7:8]
           
            idx = torch.where(rot_cosine == 0)
            rot_cosine[idx] = 1e-6

            rot = torch.atan(rot_sine / rot_cosine)

            mask_yp = (rot_sine >= 0) & (rot_cosine < 0)
            mask_yn = (rot_sine < 0) & (rot_cosine < 0)
            idx_yp = torch.where(mask_yp)
            idx_yn = torch.where(mask_yn)
            rot[idx_yp] += torch.pi
            rot[idx_yn] -= torch.pi

用数据测试发现上述计算步骤计算出的结果和torch.atan2()计算出来是一致的,仅当x==0(或者说上面的rot_cosine==0)时,用小量1e-6代替0,计算出的角度和正负torch.pi/2可能有点很细微差异而已,这基本不影响物体朝向的正确性。

将上述实现封装成函数替代调用处的torch.atan2(),导出onnx就可以顺利成功了。

相关推荐
z千鑫1 小时前
【人工智能】PyTorch、TensorFlow 和 Keras 全面解析与对比:深度学习框架的终极指南
人工智能·pytorch·深度学习·aigc·tensorflow·keras·codemoss
学不会lostfound3 小时前
三、计算机视觉_05MTCNN人脸检测
pytorch·深度学习·计算机视觉·mtcnn·p-net·r-net·o-net
Mr.谢尔比4 小时前
李宏毅机器学习课程知识点摘要(1-5集)
人工智能·pytorch·深度学习·神经网络·算法·机器学习·计算机视觉
做程序员的第一天5 小时前
在PyTorch中,钩子(hook)是什么?在神经网络中扮演什么角色?
pytorch·python·深度学习
Nerinic5 小时前
PyTorch基础2
pytorch·python
曼城周杰伦6 小时前
自然语言处理:第六十二章 KAG 超越GraphRAG的图谱框架
人工智能·pytorch·神经网络·自然语言处理·chatgpt·nlp·gpt-3
Joyner20186 小时前
pytorch训练的双卡,一个显卡占有20GB,另一个卡占有8GB,怎么均衡?
人工智能·pytorch·python
AI视觉网奇6 小时前
pytorch3d linux安装
linux·人工智能·pytorch
诚威_lol_中大努力中7 小时前
pytorch多个GPU并行使用示例:from zhouyifan
pytorch
Jamence1 天前
torch.utils.data.dataset 的数据组织形式——python list、dict、tuple内存消耗量
开发语言·人工智能·pytorch·python