05_yolox_s的后处理截断并导出onnx

目的是得到以下模型:

1、

官方yolox_s的源码和yolox_s.pth获取

https://github.com/Megvii-BaseDetection/YOLOX

2、

修改yolo_head.py的forward,替换为以下

python 复制代码
    def forward(self, xin, labels=None, imgs=None):
        outputs = []

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            reg_feat = reg_conv(x)

            cls_output = self.cls_preds[k](cls_feat)  # [B, C, H, W]
            reg_output = self.reg_preds[k](reg_feat)  # [B, 4, H, W]
            obj_output = self.obj_preds[k](reg_feat)  # [B, 1, H, W]

            # 🚨 关键:不要 decode,不要 concat
            outputs.append(reg_output)
            outputs.append(obj_output)
            outputs.append(cls_output)

        return outputs

3、

修改export_onnx.py的main()为以下

python 复制代码
def main():
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    output_names = []

    output_names = [
        "reg1", "obj1", "cls1",
        "reg2", "obj2", "cls2",
        "reg3", "obj3", "cls3",
    ]
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=output_names,
        dynamic_axes={args.input: {0: 'batch'},
                      **{name: {0: 'batch'} for name in output_names}} if args.dynamic else None,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx
        from onnxsim import simplify

        # use onnx-simplifier to reduce reduent model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))

4、

导出指令

bash 复制代码
python tools/export_onnx.py  -f exps/default/yolox_s.py  -c yolox_s.pth  --output-name yolox_s.onnx  --opset 12 --output .

上述完成就可得到需要的onnx

相关推荐
放下华子我只抽RuiKe59 分钟前
React 从入门到生产(四):自定义 Hook
前端·javascript·人工智能·深度学习·react.js·自然语言处理·前端框架
涛声依旧-底层原理研究所11 分钟前
残差连接与层归一化通俗易懂的详解
人工智能·python·神经网络·transformer
AI算法沐枫1 小时前
深度学习python代码处理科研测序数据
数据结构·人工智能·python·深度学习·决策树·机器学习·线性回归
初心未改HD2 小时前
深度学习之Attention注意力机制详解
人工智能·深度学习
code_pgf2 小时前
模态生成器:原理详解与推荐开源项目
人工智能·深度学习·开源
文歌子3 小时前
DeepEarth 深度解析:AI 如何理解地球的时空规律
深度学习
初心未改HD3 小时前
深度学习之Transformer架构详解
人工智能·深度学习·transformer
malog_3 小时前
大语言模型后训练全解析
人工智能·深度学习·机器学习·ai·语言模型
初心未改HD4 小时前
深度学习之LSTM与GRU门控循环单元详解
深度学习·gru·lstm