05_yolox_s的后处理截断并导出onnx

目的是得到以下模型:

1、

官方yolox_s的源码和yolox_s.pth获取

https://github.com/Megvii-BaseDetection/YOLOX

2、

修改yolo_head.py的forward,替换为以下

python 复制代码
    def forward(self, xin, labels=None, imgs=None):
        outputs = []

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            reg_feat = reg_conv(x)

            cls_output = self.cls_preds[k](cls_feat)  # [B, C, H, W]
            reg_output = self.reg_preds[k](reg_feat)  # [B, 4, H, W]
            obj_output = self.obj_preds[k](reg_feat)  # [B, 1, H, W]

            # 🚨 关键:不要 decode,不要 concat
            outputs.append(reg_output)
            outputs.append(obj_output)
            outputs.append(cls_output)

        return outputs

3、

修改export_onnx.py的main()为以下

python 复制代码
def main():
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    output_names = []

    output_names = [
        "reg1", "obj1", "cls1",
        "reg2", "obj2", "cls2",
        "reg3", "obj3", "cls3",
    ]
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=output_names,
        dynamic_axes={args.input: {0: 'batch'},
                      **{name: {0: 'batch'} for name in output_names}} if args.dynamic else None,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx
        from onnxsim import simplify

        # use onnx-simplifier to reduce reduent model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))

4、

导出指令

bash 复制代码
python tools/export_onnx.py  -f exps/default/yolox_s.py  -c yolox_s.pth  --output-name yolox_s.onnx  --opset 12 --output .

上述完成就可得到需要的onnx

相关推荐
Lihua奏4 天前
从单核到多核:CPU为什么不能再只靠提频变快
深度学习
拾年2754 天前
大模型的"聪明"从哪来?聊聊 AI 数据集的那些事儿
人工智能·深度学习·机器学习
hboot4 天前
AI工程师第四课 - 深度学习入门
pytorch·python·神经网络
饼干哥哥8 天前
开源Skills|搭建亚马逊动态关键词库系统,每天抓SSS级机会词
人工智能·深度学习·数据分析
武子康10 天前
调查研究-191 SenseVoice 不只是 ASR:把语音从“转文字“升级成“理解状态“
人工智能·深度学习·openai
武子康11 天前
调查研究-189 Kronos 调研:金融 K 线基础模型,是真突破,还是量化圈的新玩具?
人工智能·深度学习·openai
xiao5kou4chang6kai417 天前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
renhongxia117 天前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
计算机科研狗@OUC17 天前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
β添砖java17 天前
深度学习(22)网络中的网络NiN
人工智能·深度学习