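目的:从官方 YOLOX-S 导出一个“原始输出”版 ONNX 模型——检测头不做 decode、不做 concat,而是对每个尺度分别输出 reg、obj、cls 三个张量(3 个尺度共 9 个输出:reg1/obj1/cls1 … reg3/obj3/cls3)。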

1、
获取官方 YOLOX 源码及预训练权重 yolox_s.pth:
https://github.com/Megvii-BaseDetection/YOLOX

2、
将 yolox/models/yolo_head.py 中 YOLOXHead 的 forward 方法替换为以下 Python 代码(注意:修改后的 forward 忽略 labels/imgs,不再计算训练损失,仅适用于导出):
def forward(self, xin, labels=None, imgs=None):
    """Export-only head forward: return the raw per-level prediction maps.

    Box decoding and cross-level concatenation are deliberately skipped, so
    any post-processing (decode, sigmoid, NMS) is left to whoever consumes
    the exported model.

    Args:
        xin: Per-level feature maps. Iteration stops at the shortest of
            ``self.cls_convs``, ``self.reg_convs``, ``self.strides`` and
            ``xin``.
        labels: Accepted only for signature compatibility and ignored, so
            this forward cannot be used to compute training losses.
        imgs: Accepted only for signature compatibility and ignored.

    Returns:
        A flat list ``[reg_1, obj_1, cls_1, reg_2, obj_2, cls_2, ...]`` with
        three tensors per level, in the same level order as ``xin``. The
        obj/cls entries are the raw outputs of the prediction layers; no
        sigmoid is applied here.
    """
    outputs = []
    # The stride is only needed for decoding, which is skipped on purpose; it
    # stays in the zip so the number of levels iterated is unchanged.
    for k, (cls_conv, reg_conv, _stride, x) in enumerate(
        zip(self.cls_convs, self.reg_convs, self.strides, xin)
    ):
        x = self.stems[k](x)
        cls_feat = cls_conv(x)
        reg_feat = reg_conv(x)
        # Shape notes assume NCHW feature maps (B = batch, C = num classes).
        cls_output = self.cls_preds[k](cls_feat)  # [B, C, H, W]
        reg_output = self.reg_preds[k](reg_feat)  # [B, 4, H, W]
        obj_output = self.obj_preds[k](reg_feat)  # [B, 1, H, W]
        # Key point: no decode and no concat. Emit the three maps of this
        # level separately in reg/obj/cls order; the export script's
        # output_names must follow the same order.
        outputs.extend((reg_output, obj_output, cls_output))
    return outputs
3、
将 tools/export_onnx.py 中的 main() 替换为以下 Python 代码:
def main():
    """Export a YOLOX checkpoint to ONNX with raw, undecoded head outputs.

    Intended to be used together with the export-only head ``forward`` that
    returns a flat ``[reg, obj, cls]`` triple per detection level. The ONNX
    outputs are named ``reg1, obj1, cls1, reg2, ...`` to match that order.

    Command-line arguments read: ``exp_file``, ``name``, ``opts``,
    ``experiment_name``, ``ckpt``, ``batch_size``, ``input``, ``output_name``,
    ``opset``, ``dynamic`` and ``no_onnxsim``. ``--output`` is not used; the
    ONNX file path comes from ``--output-name``.

    Raises:
        RuntimeError: if onnx-simplifier reports that the simplified model
            could not be validated.
    """
    import inspect

    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        # No explicit checkpoint: fall back to the experiment's best checkpoint.
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    ckpt = torch.load(ckpt_file, map_location="cpu")
    model.eval()
    # Checkpoints may wrap the weights under "model"; bare state dicts do not.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    # The export-only head forward does not read this flag; setting it keeps
    # heads that still do from decoding inside the graph.
    model.head.decode_in_inference = False
    logger.info("loading checkpoint done.")

    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    # One (reg, obj, cls) triple per level, in the order the head emits them:
    # reg1, obj1, cls1, reg2, ... (3 strides -> 9 outputs for YOLOX-S).
    # Assumes one emitted level per entry of head.strides.
    num_levels = len(model.head.strides)
    output_names = [
        f"{kind}{level}"
        for level in range(1, num_levels + 1)
        for kind in ("reg", "obj", "cls")
    ]

    dynamic_axes = None
    if args.dynamic:
        # Only the batch dimension is dynamic, for the input and every output.
        dynamic_axes = {args.input: {0: "batch"}}
        dynamic_axes.update({name: {0: "batch"} for name in output_names})

    export_kwargs = {}
    # Newer PyTorch releases add a ``dynamo`` switch to torch.onnx.export and
    # may default it to the new exporter; pin the TorchScript-based exporter
    # that ``opset_version``/``dynamic_axes`` are written for.
    if "dynamo" in inspect.signature(torch.onnx.export).parameters:
        export_kwargs["dynamo"] = False

    # Public API: the private torch.onnx._export is not available in newer PyTorch.
    torch.onnx.export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
        **export_kwargs,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx
        from onnxsim import simplify

        # Use onnx-simplifier to simplify the exported graph in place.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        # Explicit raise instead of assert so the check survives `python -O`.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))
4、
在 YOLOX 仓库根目录下执行以下导出命令(bash):
python tools/export_onnx.py -f exps/default/yolox_s.py -c yolox_s.pth --output-name yolox_s.onnx --opset 12 --output .
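执行完成后,即可在当前目录得到所需的 yolox_s.onnx(若未加 --no-onnxsim,还会经过 onnx-simplifier 简化)。注意:修改后的 main() 并未读取 --output 参数(输出文件路径由 --output-name 决定),因此命令中的 --output . 对导出结果没有影响。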