05_yolox_s的后处理截断并导出onnx

目的是得到以下模型:

1、

官方yolox_s的源码和yolox_s.pth获取

https://github.com/Megvii-BaseDetection/YOLOX

2、

修改yolo_head.py的forward,替换为以下

python 复制代码
    def forward(self, xin, labels=None, imgs=None):
        outputs = []

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            reg_feat = reg_conv(x)

            cls_output = self.cls_preds[k](cls_feat)  # [B, C, H, W]
            reg_output = self.reg_preds[k](reg_feat)  # [B, 4, H, W]
            obj_output = self.obj_preds[k](reg_feat)  # [B, 1, H, W]

            # 🚨 关键:不要 decode,不要 concat
            outputs.append(reg_output)
            outputs.append(obj_output)
            outputs.append(cls_output)

        return outputs

3、

修改export_onnx.py的main()为以下

python 复制代码
def main():
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    output_names = []

    output_names = [
        "reg1", "obj1", "cls1",
        "reg2", "obj2", "cls2",
        "reg3", "obj3", "cls3",
    ]
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=output_names,
        dynamic_axes={args.input: {0: 'batch'},
                      **{name: {0: 'batch'} for name in output_names}} if args.dynamic else None,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx
        from onnxsim import simplify

        # use onnx-simplifier to reduce reduent model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))

4、

导出指令

bash 复制代码
python tools/export_onnx.py  -f exps/default/yolox_s.py  -c yolox_s.pth  --output-name yolox_s.onnx  --opset 12 --output .

上述完成就可得到需要的onnx

相关推荐
云上码厂1 小时前
2023年之前物理信息神经网络PINN papers
人工智能·深度学习·神经网络
A尘埃2 小时前
深度学习之神经网络简介(FNN+CNN+RNN+LSTM+GRU+GAN+GNN+Transformer)
深度学习·神经网络
纪伊路上盛名在2 小时前
Accurate structure prediction of biomolecular interactions with AlphaFold 3
深度学习·阅读·文献·结构·蛋白质
β添砖java4 小时前
深度学习(11)数值稳定+模型初始化、激活函数
人工智能·深度学习
九成宫4 小时前
动手学深度学习PyTorch版初步安装过程
人工智能·pytorch·深度学习
lwf0061645 小时前
DeepFM 学习日记
深度学习·机器学习
Narrastory5 小时前
Note:强化学习(六)
人工智能·深度学习·强化学习
Luca_kill5 小时前
GPT Image 2 深度评测:当 AI 图像生成跨越“图灵测试”,它如何重塑开发者工作流?
人工智能·深度学习·openai·ai图像生成·gpt image 2
小糖学代码6 小时前
LLM系列:1.python入门:16.正则表达式与文本处理 (re)
人工智能·pytorch·python·深度学习·神经网络·正则表达式