FS-SAM2微调和推理加速

模型架构与数据流

FS-SAM2 由 4 个子模型串联构成 few-shot 视频分割 pipeline:

plain 复制代码
image_encoder ──┬──> memory_encoder ──> memory_attention ──> image_decoder ──> pred_mask
                │         ↑ (support mask)         ↑ (memory)        ↑ (prompt)
                └─────────┴────────────────────────┘
子模型 功能 FP32 大小 输入 输出
image_encoder 图像特征提取 (Hiera backbone) 265 MB image [1,3,1024,1024] pix_feat, high_res_feats, vision_feats, vision_pos_embed
memory_encoder 编码 support memory 5.4 MB mask_for_mem [1,1,1024,1024], pix_feat [1,256,64,64] maskmem_features, maskmem_pos_enc, temporal_code
memory_attention 融合当前帧特征与 memory 31 MB current_vision_feat, pos_embed, memory_0/1, memory_pos_embed image_embed [1,256,64,64]
image_decoder 生成分割 mask 21 MB point_coords, point_labels, image_embed, high_res_feats obj_ptr, mask_for_mem, pred_mask

FS-SAM2 训练策略(LoRA 微调版)

背景与目标

  • 面向少样本 / 单类别分割,最小化训练开销同时提升泛化。
  • 仅训练少量 LoRA 参数,保持 SAM2 主干稳定,避免小数据过拟合。

LoRA 设计与可训练参数

  • 作用模块
    • model.image_encoderqkv, proj
    • model.memory_attentionq_proj, v_proj, k_proj, out_proj
    • model.memory_encoderout_proj
  • 配置(见 train.py
    • image encoder:r=4, alpha=16, dropout=0.1, bias=none
    • memory attention / encoder:r=32, alpha=16, dropout=0.1, bias=none
  • 冻结策略
    • 除上述 LoRA 注入的权重外,其余全部 requires_grad=False
  • 规模
    • 可训练参数约 75 万,其余主干保持冻结。

数据与 Episode 采样

  • 数据格式
    • COCO 风格,train2014 / val2014,mask 为 PNG。
    • 单类时:mask 非零视为前景(兼容 0/255)。
  • 类与 fold
    • 保留 4-fold 结构以兼容 COCO-20i;单类数据各 fold 列表相同。
    • 单类时自动 nclass=1,类 id 固定为 0。
  • Episode 流程
    • 采样 1 个类 → 1 张 query + k 张 support(默认 kshot=1)。
    • 数据增强:训练使用随机旋转、高斯模糊、翻转等(见 data/dataset.py)。
  • 尺寸
    • 输入原始 1280x720,训练统一 resize 到 img_size=1024(可调)。

前向与损失

  1. 支持集前向:逐张 support (图+mask) 送入,累积 memory features。
  2. 查询前向:复用 memory,预测 query mask。
  3. 损失函数:BCEWithLogits + Dice(对 logits 与 GT mask)。
  4. 自动混合精度:torch.autocast(device_type='cuda', dtype=torch.bfloat16)

训练配置(默认)

  • 优化器:AdamW
    • 两组参数:多维权重带 weight decay,1 维/bias 不 decay。
    • lr=1e-4, weight_decay=1e-6, betas=(0.9, 0.999).
  • 学习率调度:CosineAnnealingLR(T_max = epochs * steps_per_epoch).
  • 批大小:bsz=4(单卡,可按显存调整)。
  • 轮数:epochs=50
  • 精度:bfloat16 混合精度。
  • 设备:单 GPU。若使用 DDP,自行设置 torch.distributed.run,并保证 fold 与 sampler 对齐。

日志与模型保存

  • 记录指标:Avg Loss、mIoU(前景)、FB-IoU(前景/背景 IoU 均值)、mF1。
  • 按验证集 mIoU 最高保存 best_model.pt
  • TensorBoard:loss / mIoU / FB-IoU 曲线。

单类数据集的特别处理

  • mask 二值化:非零即前景(0/255 → 0/1)。
  • 类集合:检测到仅 1 类时,class_ids=[0],避免 80 类抽样导致空类或 KeyError。
  • 数据列表:data_list_* / sub_class_file_list_* 需与文件名一致,可用 generate_data_list.py 自动生成。

运行示例

bash 复制代码
python train.py \
  --datapath ./dataset_blood \
  --benchmark coco \
  --fold 0 \
  --bsz 4 \
  --epochs 50

单类场景下 fold 取 0/1/2/3 结果一致。

关键收益

  • 训练开销小:仅 ~75 万可训练参数。
  • 稳定性高:主干冻结,降低过拟合风险。
  • Few-shot 泛化:episode + memory 机制,mIoU/FB-IoU 直接反映跨图像的类泛化能力。

onnx拆分导出方法

将 SAM2 导出为 ONNX 的方法,供本项目参考。


一、整体架构

将 SAM2 视频分割模型拆分为 4 个 ONNX 子模块

序号 模块名 ONNX 文件 功能
1 ImageEncoder image_encoder.onnx 图像编码,输出 backbone 特征与位置编码
2 ImageDecoder image_decoder.onnx 点 prompt + mask decoder 一体化,输出 mask、obj_ptr 等
3 MemAttention memory_attention.onnx 记忆注意力,融合当前帧与历史记忆
4 MemEncoder memory_encoder.onnx 记忆编码,将 mask 编码为 memory 特征
1.1 为什么要拆成这 4 个 ONNX?

从 SAM2 的原始架构来看,视频分割主干可以自然地拆成 4 个阶段:

  1. ImageEncoder(图像编码)
    • 对每一帧做 backbone + FPN 编码,得到多尺度特征(pix_feat、high_res_feats_0/1、vision_feats、vision_pos_embed)。
    • 计算量大、结构规则,最适合作为 单独的 TRT INT8/FP16 引擎
  2. MemEncoder(记忆编码)
    • 把 support/历史帧上的 mask + pix_feat 编码为 memory_0/memory_1(maskmem_features 等)。
    • 输入输出固定,接口简单,适合作为 独立的小引擎
  3. MemAttention(记忆注意力)
    • 将当前帧特征 + 历史 memory 做 cross-attention,输出 image_embed。
    • 本质是一个带 RoPE 的 cross-attention block,算子相对集中,但有 RoPE 复杂度(需额外补丁,见后文)。
  4. ImageDecoder(解码 + prompt)
    • 内部打包了 prompt encoder(点/框)+ mask decoder,一次性输出 mask_for_mem + pred_mask + obj_ptr。
    • 这样推理 pipeline 只需传入点坐标/标签和 encoder/memory 的输出,而不用再单独处理 prompt_encoder。

拆成这 4 个子图的好处:

  • 完全覆盖 SAM2 推理主路径image_encoder → mem_encoder → mem_attention → image_decoder 就是视频分割 pipeline 的骨架。
  • 易于部署与量化:每个子模型的输入输出都清晰、规模适中,便于做 INT8/FP16 量化和独立 TRT 编译。
  • 避免"超大单图"问题:若把所有模块导成一个巨型 ONNX,不利于调试、量化、TRT tactic 搜索,也不方便后续做分模块优化。

二、核心思路

2.1 模块封装(src/Module.py)

通过 Wrapper 类 将 SAM2 原始子模块封装为可单独导出的 nn.Module,每个 Wrapper 只暴露必要的输入输出,便于 ONNX 导出。

2.2 模型加载
python 复制代码
from sam2.build_sam import build_sam2

sam2_model = build_sam2(args.config, args.checkpoint, device="cpu")
  • 使用 Hydra 配置 + 实例化
  • 支持 tiny / small / large / base+ 四种配置
2.3 导出顺序
plain 复制代码
1. export_image_encoder(image_encoder, outdir)
2. export_image_decoder(image_decoder, outdir)
3. export_memory_attention(mem_attention, outdir)
4. export_memory_encoder(mem_encoder, outdir)

三、各模块导出详情

3.1 ImageEncoder

输入:

  • image: [1, 3, 1024, 1024] float

输出:

  • pix_feat: [1, 256, 64, 64] 像素特征
  • high_res_feat0: [1, 32, 256, 256] 高分辨率特征 0
  • high_res_feat1: [1, 64, 128, 128] 高分辨率特征 1
  • vision_feats: 视觉特征(与 pix_feat 相关)
  • vision_pos_embed: [4096, 1, 256] 位置编码

实现要点:

  • ImageEncoder 内部:image_encoder_prepare_backbone_features → conv_s0/conv_s1 处理
  • 使用 do_constant_folding=Trueopset_version=17
  • 使用 onnxsim(注释说明 simplify 会简化掉部分输出)

3.2 ImageDecoder

输入:

  • point_coords: [num_labels, num_points, 2] 点坐标
  • point_labels: [num_labels, num_points] 点标签
  • image_embed: [1, 256, 64, 64] 来自 memory_attention 的输出
  • high_res_feats_0: [1, 32, 256, 256]
  • high_res_feats_1: [1, 64, 128, 128]

输出:

  • obj_ptr: [1, 256] 对象指针 token
  • mask_for_mem: [1, 1, 1024, 1024] 供 memory encoder 使用的 mask
  • pred_mask: [1, 1, 1024, 1024] 最终预测 mask

实现要点:

  • 内部调用 _forward_sam_heads,包含 prompt encoder + mask decoder
  • high_res_masks 做 sigmoid + scale + bias,再 fill_holes_in_mask_scores
  • 使用 dynamic_axes 支持 point_coords / point_labels 的 batch 维度
  • 使用 onnxsim 简化

3.3 MemAttention

输入:

  • current_vision_feat: [1, 256, 64, 64] 当前帧视觉特征
  • current_vision_pos_embed: [4096, 1, 256] 当前帧位置编码
  • memory_0: [num_obj_ptr, 256] 对象指针 token
  • memory_1: [buff_size, 64, 64, 64] 历史帧 mask 特征
  • memory_pos_embed: [dynamic, 1, 64] 记忆位置编码

输出:

  • image_embed: [1, 256, 64, 64] 融合记忆后的图像 embedding

实现要点:

  • 内部 reshape:memory_0[4*num_obj_ptr, 1, 64]memory_1[buff_size*4096, 1, 64]
  • memory 与 curr 做 cross-attention,输出 reshape 为 [1, 256, 64, 64]
  • dynamic_axes: memory_0(num), memory_1(buff_size), memory_pos_embed(dynamic)
  • 使用 onnxsim 简化

3.4 MemEncoder

输入:

  • mask_for_mem: [1, 1, 1024, 1024] 高分辨率 mask
  • pix_feat: [1, 256, 64, 64] 像素特征(来自 image_encoder)

输出:

  • maskmem_features: 记忆特征
  • maskmem_pos_enc: 记忆位置编码
  • temporal_code: 时间编码

实现要点:

  • 内部调用 _encode_new_memory,传入 pred_masks_high_rescurrent_vision_feats
  • 使用 opset_version=17do_constant_folding=True
  • 不启用 onnxsim

五、目录结构

plain 复制代码
ONNXExport/
├── export_onnx.py          # 主导出脚本
├── src/
│   └── Module.py           # ImageEncoder, MemAttention, MemEncoder, ImageDecoder
├── sam2/
│   ├── build_sam.py        # 模型构建
│   ├── sam2_video_predictor.py
│   ├── sam2_image_predictor.py
│   └── modeling/
│       ├── sam2_base.py
│       ├── memory_attention.py
│       ├── memory_encoder.py
│       └── ...
├── sam2_configs/
│   ├── sam2_hiera_tiny.yaml
│   ├── sam2_hiera_small.yaml
│   ├── sam2_hiera_large.yaml
│   └── sam2_hiera_base+.yaml
└── checkpoints/
    └── download_ckpts.sh

六、使用方式

bash 复制代码
## 1. 下载 checkpoint
cd checkpoints && ./download_ckpts.sh
mkdir base+ large small tiny

## 2. 导出
python export_onnx.py --outdir checkpoints/base+/ --config sam2_configs/sam2_hiera_base+.yaml --checkpoint checkpoints/sam2_hiera_base_plus.pt

七、关键要点

  1. Wrapper 封装:每个导出模块需独立封装,forward 只保留必要输入输出,便于 ONNX 追踪。
  2. onnxsim 使用:image_encoder 与 memory_encoder 不 simplify,避免输出被简化;image_decoder 与 memory_attention 使用 simplify。
  3. dynamic_axes:memory_attention 支持可变 num_obj_ptr、buff_size;image_decoder 支持可变 point 数量。
  4. dummy 输入 :使用 torch.randn 构造 dummy 输入,保证 shape 正确,便于 ONNX 导出时正确推断。

八、LoRA 导出路径:model_loader_lora.py 的作用与原理

当使用 --lora 参数时,export_onnx.py 会走 FS-SAM2(LoRA 微调) 的导出路径:

bash 复制代码
python ONNXExport/export_onnx.py \
  --lora logs/coco/0000/fold0.log/best_model.pt \
  --base-ckpt ./sam2.1_hiera_base_plus.pt \
  --outdir ./onnx_export/fs_sam2_output

底层逻辑由 model_loader_lora.py 提供:

  1. 加载 base SAM2.1
python 复制代码
model = build_sam2(config_file, base_checkpoint, device=device)
  1. 在关键子模块上挂 LoRA
    • image_encoder:对 self-attention 的 qkvproj 加 LoRA
    • memory_attention:对 cross-attention 的 q_proj/k_proj/v_proj/out_proj 加 LoRA
    • memory_encoder:对 out_proj 加 LoRA
      → 这样 LoRA 只作用于注意力层,既能调优性能,又不破坏总体结构。
  2. 加载 LoRA checkpoint
    • 将训练好的 LoRA 权重(如 best_model.pt)按 key 重映射到上述 LoRA 模块上。
    • 处理 state_dict 中的 model. 前缀,忽略与当前结构无关的多余 key。
  3. merge_and_unload(推荐)
    • 默认 merge_lora=True,会执行:
python 复制代码
model.image_encoder = model.image_encoder.merge_and_unload()
model.memory_attention = model.memory_attention.merge_and_unload()
model.memory_encoder = model.memory_encoder.merge_and_unload()
复制代码
- 数学上即 (W' = W + \Delta W),将 LoRA 增量权重"烤进" base 参数里,得到一个**纯静态的 FS-SAM2 模型**。

这么做的原因:

  • ONNX / TensorRT 对动态 LoRA 分支(A/B 矩阵、Adapter 等)不友好,kernel 融合困难,部署复杂。
  • merge 后的模型在结构上与原生 SAM2 一致,只是权重不同,最适合做 ONNX 导出和 TRT 量化/编译。

总结model_loader_lora.py 负责把"base SAM2.1 + LoRA checkpoint"变成一个 已合并 LoRA 的 FS-SAM2 模型 ,然后再交给 ImageEncoder / MemAttention / MemEncoder / ImageDecoder 这 4 个 Wrapper 导出为 ONNX。


九、RoPE 补丁:rope_onnx_patch.py 的原因与原理

9.1 为什么需要 RoPE 补丁?

SAM2 的 MemAttention 内部使用了 RoPE(Rotary Positional Embedding),原实现里 RoPE 使用 ComplexFloat(复数) 存储频率表 freqs_cis

但是:

  • ONNX 标准对 complex 类型支持非常有限(很多后端完全不支持);
  • TensorRT 目前不支持 ComplexFloat,含 complex 的 ONNX 图在 TRT 中无法正常解析。

因此,在导出 memory_attention.onnx 前,必须把 complex RoPE 表达式改写为纯实数形式(cos/sin buffer)

9.2 patch_rope_for_onnx

export_onnx.py 中的关键几行:

python 复制代码
from rope_onnx_patch import patch_rope_for_onnx

patch_rope_for_onnx(model.memory_attention)

rope_onnx_patch.py 的主要步骤:

  1. 预计算 cos/sin buffer
    • 对每个 RoPEAttention 实例,读取原有 freqs_cis(complex),分离出实部/虚部:
python 复制代码
cos = fc.real.clone().detach().float()
sin = fc.imag.clone().detach().float()
mod.register_buffer("cos_cis", cos)
mod.register_buffer("sin_cis", sin)
复制代码
- 删除 complex 版本的 `freqs_cis`,避免进入 ONNX 图。
  1. 定义纯实数版 RoPE 应用函数
    • _rope_single_real:把最后一维拆成 2 通道(x0, x1),按 cos/sin 做二维旋转,还原到原 shape;
    • _apply_rotary_enc_real:对 q/k 应用上一步逻辑,并处理 rope_k_repeat、dynamic seq 等情况。
  2. 替换 RoPEAttention.forward 为 ONNX 友好版
    • _rope_attn_forward_onnx 内部:
      1. 做 q/k/v projection 和多头拆分;
      2. _apply_rotary_enc_real 对 q/k 应用 cos/sin 版 RoPE;
      3. 使用 torch.nn.functional.scaled_dot_product_attention 做注意力;
      4. 还原 heads、投影输出。
    • 最后通过 partial 把这个 forward 绑定到每个 RoPEAttention 上。
  3. 提供动态 shape 下的 cos/sin 计算
    • 当序列长度变化时,会用 compute_axial_cis_real 动态计算 cos/sin,以保持与原实现数学等价。

结果:memory_attention.onnx 中不再出现 complex 类型,只包含 FP16/FP32 的 cos/sin buffer 与常规 attention 算子,完全兼容 ONNX Runtime 和 TensorRT


FS-SAM2 性能优化:从 ONNX 导出到 TensorRT 部署

1. 概述

本文档记录了 FS-SAM2(Few-Shot SAM2)模型从 PyTorch 训练产物到 TensorRT 高性能部署的完整优化过程,包括 ONNX 导出、INT8 量化、TRT 编译及推理,以及过程中遇到的关键问题与解决方案。

测试环境:NVIDIA H100 GPU,视频 kidney_1_001347_001405.mp4(540 帧,1920x1080)

最终量化方案
子模型 精度 TRT 编译方式 说明
image_encoder INT8 --stronglyTyped 质量与 FP32 一致
memory_encoder INT8 --stronglyTyped 质量与 FP32 一致
memory_attention FP16 --fp16 (从 FP32 ONNX) INT8 导致 mask 空洞 + 速度退化
image_decoder FP16 --fp16 (从 FP32 ONNX) INT8 导致 mask 完全丢失

2. 模型架构与数据流

FS-SAM2 由 4 个子模型串联构成 few-shot 视频分割 pipeline:

plain 复制代码
image_encoder ──┬──> memory_encoder ──> memory_attention ──> image_decoder ──> pred_mask
                │         ↑ (support mask)         ↑ (memory)        ↑ (prompt)
                └─────────┴────────────────────────┘
子模型 功能 FP32 大小 输入 输出
image_encoder 图像特征提取 (Hiera backbone) 265 MB image [1,3,1024,1024] pix_feat, high_res_feats, vision_feats, vision_pos_embed
memory_encoder 编码 support memory 5.4 MB mask_for_mem [1,1,1024,1024], pix_feat [1,256,64,64] maskmem_features, maskmem_pos_enc, temporal_code
memory_attention 融合当前帧特征与 memory 31 MB current_vision_feat, pos_embed, memory_0/1, memory_pos_embed image_embed [1,256,64,64]
image_decoder 生成分割 mask 21 MB point_coords, point_labels, image_embed, high_res_feats obj_ptr, mask_for_mem, pred_mask

3. Step 1 --- ONNX 导出

3.1 导出流程

执行导出,得到 4 个 ONNX:

bash 复制代码
python export_onnx.py \
  --outdir /path/to/output \
  --config sam2_configs/sam2_hiera_base+.yaml \
  --checkpoint checkpoints/sam2_hiera_base_plus.pt

导出顺序:image_encoder → image_decoder → memory_attention → memory_encoder。

导出完成后,将生成的 4 个 .onnx 放到本项目的 onnx_export/fs_sam2_output/(或任意目录),供后续校准、量化与推理使用。

3.2 导出产物(4 个 ONNX)

在本项目中,这 4 个文件通常放在:

plain 复制代码
onnx_export/fs_sam2_output/
├── image_encoder.onnx     (265 MB)
├── memory_encoder.onnx    (5.4 MB)
├── memory_attention.onnx  (31 MB)
└── image_decoder.onnx     (21 MB)
3.3 导出中的关键决策
  • 4 模块划分:image_encoder、image_decoder、memory_attention、memory_encoder
  • memory_attention 导出为 ONNX:便于与其余 3 个 ONNX 一起做 TRT 部署,无需单独加载 pth。

4. Step 2 --- INT8 量化

4.1 工具链
  • nvidia-modelopt 0.37.0:NVIDIA 官方 ONNX PTQ 工具包
  • ONNX Runtime 1.24.2:校准数据生成时的推理后端
  • 校准方法:entropy(128-bin 直方图 KL 散度)
4.2 校准数据生成(Step 1)

执行的脚本onnx_quant/prepare_calibration_data.py
作用 :跑 FP32 的 4 子模型 pipeline,把每个子模型真实看到的输入缓存成 4 份 .npz,供后续量化使用。
建议 :在 modelopt 环境下运行;非交互式 shell 需先 export PS1=xsource ~/.bashrc 以加载 CUDA 等环境;--datapath/--onnx_dir 建议用绝对路径,避免脚本内相对路径解析到错误目录。

完整命令示例 (工作目录为项目根目录时,先 cd onnx_quant):

bash 复制代码
python prepare_calibration_data.py \
  --datapath /path/to/aie-kidney-all-blood/dataset_blood \
  --onnx_dir /path/to/aie-kidney-all-blood/onnx_export/fs_sam2_output \
  --output_dir build \
  --calib_size 128 \
  --device cuda

核心原则:1 套源校准集 + 4 份派生校准输入缓存

不是对每个子模型用随机数据校准,而是用 FP32 pipeline 跑完整推理流程,把每个子模型实际看到的输入 缓存为 .npz 文件:

plain 复制代码
build/
├── calib_image_encoder.npz     (304 MB) --- 128 张 [3,1024,1024] 图像
├── calib_memory_encoder.npz    (474 MB) --- 128 对 (mask, pix_feat)
├── calib_memory_attention.npz  (788 MB) --- 128 组 5 路 memory 输入
└── calib_image_decoder.npz     (1.9 GB) --- 128 组 decoder 输入
4.3 量化执行(Step 2)

执行的脚本onnx_quant/quantize_onnx_int8.py
作用 :用 modelopt 对 4 个 FP32 ONNX 做 INT8 PTQ,读取上一步生成的 build/calib_*.npz,输出到 build/quantized_int8/。image_encoder 最慢(约 36 分钟),4 个模型合计约 40 分钟。

完整命令示例 (工作目录为 onnx_quant):

bash 复制代码
python quantize_onnx_int8.py \
  --onnx_dir /path/to/aie-kidney-all-blood/onnx_export/fs_sam2_output \
  --calib_dir build \
  --output_dir build/quantized_int8

量化结果:

子模型 总节点数 INT8 量化节点 量化比例 dq_only
image_encoder 5603 266 4.7% False (QDQ)
memory_encoder 300 16 5.3% True (DQ-only)
memory_attention 908 62 6.8% True (DQ-only)
image_decoder 898 88 9.8% True (DQ-only)
4.4 量化过程中遇到的问题与解决方案
问题 1:image_encoder Resize 节点不兼容 DQ-only 模式
plain 复制代码
[ERROR] Unsupported op_type for real weight quantization: Resize

原因 :modelopt 的 qdq_to_dq 转换步骤在处理 image_encoder 的 pos_embed Resize 节点时失败,无法将 QDQ 格式转换为 DQ-only 格式。

解决方案 :对 image_encoder 保持 QDQ 格式 (dq_only=False),其余模型正常使用 DQ-only:

python 复制代码
use_dq_only = False if model_name == "image_encoder" else True
问题 2:memory_encoder pix_feat 维度不匹配
plain 复制代码
Invalid rank for input: pix_feat Got: 5 Expected: 4

原因 :FP32 pipeline 中 pix_feat 在某些情况下为 5D 张量 [1,1,256,64,64],但 ONNX 模型期望 4D [1,256,64,64]

解决方案 :在 prepare_calibration_data.py 中添加 _ensure_pix_feat_4d() 辅助函数:

python 复制代码
def _ensure_pix_feat_4d(arr: np.ndarray) -> np.ndarray:
    if arr.ndim == 4:
        return arr
    # squeeze 到 4D
    while arr.ndim > 4 and arr.shape[0] == 1:
        arr = arr.squeeze(0)
    return arr
问题 3:modelopt symbolic shape inference 失败

原因:动态 shape 模型(memory_attention、image_decoder)的 symbolic shape inference 无法推断所有中间张量形状,导致 GEMV 检测失败。

解决方案 :Monkey-patch modelopt 的 find_nodes_from_matmul_to_exclude 函数,在 symbolic inference 失败时自动回退到 inference-based GEMV 检测:

python 复制代码
def _patched_find(...):
    try:
        nodes_to_exclude = _graph_utils._exclude_matmuls_by_symbolic_inference(...)
    except (AssertionError, Exception) as e:
        logger.warning("Symbolic shape inference failed, falling back to inference-based.")
        nodes_to_exclude = _graph_utils._exclude_matmuls_by_inference(...)
    return nodes_to_exclude
问题 4:TensorRT stronglyTyped output dtype 不一致

原因 :量化后 graph output 声明的 dtype 可能与 inferred dtype 不一致,导致 trtexec --stronglyTyped 报错。

解决方案 :添加 sanitize_onnx_for_trt() 后处理,强制所有 graph output 声明为 FP32:

python 复制代码
def sanitize_onnx_for_trt(onnx_path, force_output_fp32=True):
    # 对每个 graph.output,将 elem_type 统一设置为 FLOAT
问题 5:onnx.helper.float32_to_bfloat16 缺失

原因:新版 onnx 库移除了该函数,但 modelopt 内部仍在引用。

解决方案:启动时注入兼容实现:

python 复制代码
if not hasattr(onnx.helper, "float32_to_bfloat16"):
    def float32_to_bfloat16(fval):
        ival = struct.unpack("=I", struct.pack("=f", float(fval)))[0]
        return ival >> 16
    onnx.helper.float32_to_bfloat16 = float32_to_bfloat16

5. Step 3 --- 量化精度诊断

5.1 问题发现

全部 4 个子模型 INT8 量化后推理,输出视频完全没有出血 mask

5.2 逐模型排查方法

构建 5 个诊断目录,每个目录中只有 1 个子模型使用 INT8 量化版本,其余 3 个使用 FP32 原始版本(通过符号链接重命名实现):

bash 复制代码
diag_int8_tests/
├── test_all_fp32/         # 全 FP32 baseline
├── test_image_encoder/    # 仅 image_encoder 为 INT8
├── test_memory_encoder/   # 仅 memory_encoder 为 INT8
├── test_memory_attention/ # 仅 memory_attention 为 INT8
└── test_image_decoder/    # 仅 image_decoder 为 INT8

分别运行推理并对比输出视频中的 mask 质量。

5.3 诊断结果
测试配置 mask 质量 速度 (FPS) 结论
all_fp32 (baseline) 正常 19.3 基准线
image_encoder=INT8 与 FP32 一致 15.3 可量化
memory_encoder=INT8 与 FP32 一致 20.4 可量化
memory_attention=INT8 mask 有空洞 4.3 (慢 18x) 不可量化
image_decoder=INT8 无 mask 20.9 不可量化
5.4 分析
  • image_decoder INT8 致命:decoder 是最终输出 mask 的模型,INT8 精度损失直接导致 sigmoid 后 mask 概率全部低于阈值,mask 完全消失。这类模型对精度极敏感。
  • memory_attention INT8 有害:不仅质量下降(空洞),ORT 对 DQ-only INT8 节点的执行效率极差,mem_attn 推理时间从 7ms 飙升至 144ms。可能是因为 ORT 没有为该子图的 INT8 模式做 kernel 融合。
  • image_encoder / memory_encoder 安全:这两个模型以特征提取为主,对 INT8 量化容忍度高。
5.5 确定最终方案

只量化 image_encoder + memory_encoder,memory_attention + image_decoder 保持 FP32/FP16。

6. Step 4 --- TensorRT 引擎编译

6.1 混合编译脚本 trt.sh
bash 复制代码
## INT8 量化模型(QDQ 图):
## 推荐用 --fp16 --int8(让 TRT 按 layer 自动选择最优精度),避免部分 GPU 上
## 因 kernel 覆盖不足导致插入大量 reformat 节点而变慢。
trtexec --onnx=quantized_int8/image_encoder.int8.quant.onnx \
        --saveEngine=trt_engines/image_encoder.int8.quant.engine \
        --fp16 --int8

trtexec --onnx=quantized_int8/memory_encoder.int8.quant.onnx \
        --saveEngine=trt_engines/memory_encoder.int8.quant.engine \
        --fp16 --int8

## FP32 原始模型 -> FP16(TRT 自动精度优化)
trtexec --onnx=fs_sam2_output/memory_attention.onnx \
        --saveEngine=trt_engines/memory_attention.int8.quant.engine \
        --fp16 \
        --minShapes=... --optShapes=... --maxShapes=...

trtexec --onnx=fs_sam2_output/image_decoder.onnx \
        --saveEngine=trt_engines/image_decoder.int8.quant.engine \
        --fp16 \
        --minShapes=... --optShapes=... --maxShapes=...
6.1.1 常见坑:A100 上编译+推理只有 ~10 FPS(而 H100 有 ~23 FPS)

现象:同一套 ONNX/QDQ 量化图,在 H100 上 TRT 推理可达 ~23 FPS,但在 A100 上只有 ~10 FPS。

根因(高概率):**对 QDQ 量化图使用 **--stronglyTyped 会强制 TRT 严格按 Q/DQ 标记执行 INT8;当某些 layer 在 A100 上缺少合适的 INT8 tactic / kernel 时,TRT 会插入大量 reformat / type-cast(INT8↔FP16/FP32)来满足强类型约束,导致整体吞吐显著下降。H100(更新架构)对 INT8 kernel 覆盖更完整,因此更不容易触发该退化。

修复:对 QDQ INT8 模型优先使用 --fp16 --int8(而不是 --stronglyTyped),让 TRT 允许对没有高效 INT8 kernel 的 layer 自动回退到 FP16,从而避免 reformat 风暴。

验证建议:重新在 A100 上运行 trt.sh 生成 engine,并用 infer_video_trt.py 复测 FPS;如果 FPS 显著回升(通常接近 FP32/FP16 TRT 的水平),说明退化来自强类型 INT8 的 tactic 覆盖问题。

6.2 动态 shape 配置

memory_attention 和 image_decoder 有动态维度,需指定 min/opt/max shapes:

memory_attention(memory 数量可变):

  • memory_0: [1~16, 256]
  • memory_1: [1~16, 64, 64, 64]
  • memory_pos_embed: [1~65600, 1, 64]

image_decoder(prompt 点数可变):

  • point_coords: [1, 2~10, 2]
  • point_labels: [1, 2~10]
6.3 编译产物
plain 复制代码
trt_engines/
├── image_encoder.int8.quant.engine     (73 MB)  --- INT8
├── memory_encoder.int8.quant.engine    (3.7 MB) --- INT8
├── memory_attention.int8.quant.engine  (29 MB)  --- FP16
└── image_decoder.int8.quant.engine     (16 MB)  --- FP16

总计 ~122 MB,相比 FP32 ONNX 的 322 MB 减少了 62%。

8. 文件结构总览

plain 复制代码
aie-kidney-all-blood/
├── onnx_export/                          # 本仓库:存放 ONNX 产物与 FS-SAM2 导出脚本
│   ├── fs_sam2_output/                   # 4 个 FP32 ONNX
│   │   ├── image_encoder.onnx   (265 MB)
│   │   ├── memory_encoder.onnx  (5.4 MB)
│   │   ├── memory_attention.onnx (31 MB)
│   │   └── image_decoder.onnx   (21 MB)
│   └── README.md                          # 说明 onnx_export 的 FS-SAM2/LoRA 导出方式(与当前 4 模块 pipeline 不同)
│
├── onnx_quant/                           # INT8 量化
│   ├── prepare_calibration_data.py       # 校准数据生成(FP32 pipeline 派生)
│   ├── quantize_onnx_int8.py            # 量化主脚本(modelopt)
│   ├── evaluate_onnx.py                  # 量化评估(精度+延迟对比)
│   ├── run_all.sh                        # 全流程编排
│   └── build/
│       ├── calib_*.npz                   # 校准数据缓存
│       ├── quantized_int8/               # 全量 INT8 模型(诊断用)
│       └── quantized_mixed/              # 混合方案(IE+ME=INT8, MA+ID=FP32 symlink)
│
├── trt.sh                                # TRT 编译脚本(混合 INT8/FP16)
├── trt_engines/                          # 编译产物
│   ├── image_encoder.int8.quant.engine   (73 MB, INT8)
│   ├── memory_encoder.int8.quant.engine  (3.7 MB, INT8)
│   ├── memory_attention.int8.quant.engine (29 MB, FP16)
│   └── image_decoder.int8.quant.engine   (16 MB, FP16)
│
├── infer_video_sam2export.py             # ONNX 推理脚本(FP32 / 量化混合)
└── infer_video_trt.py                    # TRT 推理脚本

9. 复现指南

9.1 环境
plain 复制代码
conda env: modelopt
Python 3.12
nvidia-modelopt 0.37.0
onnxruntime-gpu 1.24.2
TensorRT 10.14.1.48
CUDA 12.x + cuDNN 9
9.2 完整流程
bash 复制代码
## 0. 进入环境
source ~/.bashrc
conda activate modelopt

## 1. 导出 ONNX
##   python export_onnx.py --outdir <outdir> --config sam2_configs/sam2_hiera_base+.yaml --checkpoint <ckpt>
## 再将生成的 4 个 .onnx 拷贝到本项目 onnx_export/fs_sam2_output/

## 2. 生成校准数据(~5 分钟)
cd onnx_quant
python prepare_calibration_data.py \
  --datapath /path/to/dataset_blood \
  --onnx_dir /path/to/onnx_export/fs_sam2_output \
  --output_dir build --calib_size 128 --device cuda

## 3. INT8 量化(~43 分钟,image_encoder 最慢)
python quantize_onnx_int8.py \
  --onnx_dir /path/to/onnx_export/fs_sam2_output \
  --calib_dir build --output_dir build/quantized_int8

## 4. 编译 TRT 引擎(~7 分钟)
cd ..
bash trt.sh

## 5. TRT 推理
python infer_video_trt.py \
  --video kidney_1_001347_001405.mp4 \
  --engine_dir trt_engines \
  --output output_trt.mp4 \
  --support_dir support_images \
  --support_mask_dir support_masks

相关推荐
lifallen2 小时前
Flink Agent:ActionTask 与可续跑状态机 (Coroutine/Continuation)
java·大数据·人工智能·语言模型·flink
数据分析能量站2 小时前
Harnessing Claude 打造高效、低成本、可进化的 AI 应用
人工智能
枫叶林FYL2 小时前
【Python高级工程与架构实战】项目三:实时数据管道(Kafka + Polars + Delta Lake)(一)
人工智能·机器学习
q_35488851532 小时前
计算机毕业设计:Python居民出行规律可视化分析系统 Django框架 可视化 数据分析 PyEcharts 交通 深度学习(建议收藏)✅
人工智能·python·数据分析·车载系统·django·汽车·课程设计
ai生成式引擎优化技术2 小时前
TSPR-WEB-LLM-HIC (TWLH四元结构)AI 生成式引擎(GEO)
人工智能
云上码厂2 小时前
大模型数学库DeepSeek-Math-V2
人工智能
wxl7812272 小时前
驾驭工程 (Harness Engineering):AI Agent 时代的软件工程新范式
人工智能·软件工程
程序员Shawn2 小时前
【深度学习 | 第四篇】- 循环神经网络
人工智能·rnn·深度学习
33三 三like2 小时前
BERT-BiLSTM-CRF 养老需求实体抽取模型解析与实践:从口语文本到结构化知识
人工智能·深度学习·bert