05_yolox_s的后处理截断并导出onnx

目的是得到以下模型:

1、

官方yolox_s的源码和yolox_s.pth获取

https://github.com/Megvii-BaseDetection/YOLOX

2、

修改yolo_head.py的forward,替换为以下

python 复制代码
    def forward(self, xin, labels=None, imgs=None):
        outputs = []

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            reg_feat = reg_conv(x)

            cls_output = self.cls_preds[k](cls_feat)  # [B, C, H, W]
            reg_output = self.reg_preds[k](reg_feat)  # [B, 4, H, W]
            obj_output = self.obj_preds[k](reg_feat)  # [B, 1, H, W]

            # 🚨 关键:不要 decode,不要 concat
            outputs.append(reg_output)
            outputs.append(obj_output)
            outputs.append(cls_output)

        return outputs

3、

修改export_onnx.py的main()为以下

python 复制代码
def main():
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    output_names = []

    output_names = [
        "reg1", "obj1", "cls1",
        "reg2", "obj2", "cls2",
        "reg3", "obj3", "cls3",
    ]
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=output_names,
        dynamic_axes={args.input: {0: 'batch'},
                      **{name: {0: 'batch'} for name in output_names}} if args.dynamic else None,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx
        from onnxsim import simplify

        # use onnx-simplifier to reduce reduent model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))

4、

导出指令

bash 复制代码
python tools/export_onnx.py  -f exps/default/yolox_s.py  -c yolox_s.pth  --output-name yolox_s.onnx  --opset 12 --output .

上述完成就可得到需要的onnx

相关推荐
戴西软件21 分钟前
戴西 DLM 许可授权管理系统:破解无网络环境下工业软件授权难题,助力制造企业降本增效
网络·人工智能·python·深度学习·程序人生·算法·制造
Black蜡笔小新33 分钟前
制造业AI质检工作站/企业AI算力工作站DLTM助力制造业质检智能化升级
人工智能·深度学习·机器学习
渡之3 小时前
GRiM-Net 深度解析 | 无人机 GNSS 拒止场景下两阶段跨视角视觉定位框架
深度学习·算法·动态规划·无人机
code_pgf4 小时前
mllm训练过程中有效地利用辅助监督信号来减少幻觉的方法
人工智能·深度学习·计算机视觉
装不满的克莱因瓶4 小时前
自然语言处理常见任务——从文本理解到生成式AI的完整任务体系
人工智能·pytorch·python·深度学习·ai·自然语言处理
炎武丶航4 小时前
LeNet-5深度学习详解:从手写数字识别到代码实战
人工智能·python·深度学习·机器学习·ai·cnn·lenet
湘美书院--湘美谈教育5 小时前
湘美谈教育AI赋能系列经验集锦:学好唐诗宋词的点滴心得体会
大数据·人工智能·深度学习·神经网络·机器学习
卷Java6 小时前
混合检索让RAG召回率从62%干到89%
深度学习
装不满的克莱因瓶6 小时前
掌握生成对抗网络(GAN)的优化目标与评估指标——从博弈函数到生成质量衡量体系
人工智能·python·深度学习·算法·机器学习