模型架构与数据流
FS-SAM2 由 4 个子模型串联构成 few-shot 视频分割 pipeline:
plain
image_encoder ──┬──> memory_encoder ──> memory_attention ──> image_decoder ──> pred_mask
│ ↑ (support mask) ↑ (memory) ↑ (prompt)
└─────────┴────────────────────────┘
| 子模型 | 功能 | FP32 大小 | 输入 | 输出 |
|---|---|---|---|---|
| image_encoder | 图像特征提取 (Hiera backbone) | 265 MB | image [1,3,1024,1024] | pix_feat, high_res_feats, vision_feats, vision_pos_embed |
| memory_encoder | 编码 support memory | 5.4 MB | mask_for_mem [1,1,1024,1024], pix_feat [1,256,64,64] | maskmem_features, maskmem_pos_enc, temporal_code |
| memory_attention | 融合当前帧特征与 memory | 31 MB | current_vision_feat, pos_embed, memory_0/1, memory_pos_embed | image_embed [1,256,64,64] |
| image_decoder | 生成分割 mask | 21 MB | point_coords, point_labels, image_embed, high_res_feats | obj_ptr, mask_for_mem, pred_mask |
FS-SAM2 训练策略(LoRA 微调版)
背景与目标
- 面向少样本 / 单类别分割,最小化训练开销同时提升泛化。
- 仅训练少量 LoRA 参数,保持 SAM2 主干稳定,避免小数据过拟合。
LoRA 设计与可训练参数
- 作用模块
model.image_encoder:qkv,projmodel.memory_attention:q_proj,v_proj,k_proj,out_projmodel.memory_encoder:out_proj
- 配置(见
train.py)- image encoder:
r=4, alpha=16, dropout=0.1, bias=none - memory attention / encoder:
r=32, alpha=16, dropout=0.1, bias=none
- image encoder:
- 冻结策略
- 除上述 LoRA 注入的权重外,其余全部
requires_grad=False。
- 除上述 LoRA 注入的权重外,其余全部
- 规模
- 可训练参数约 75 万,其余主干保持冻结。
数据与 Episode 采样
- 数据格式
- COCO 风格,
train2014/val2014,mask 为 PNG。 - 单类时:mask 非零视为前景(兼容 0/255)。
- COCO 风格,
- 类与 fold
- 保留 4-fold 结构以兼容 COCO-20i;单类数据各 fold 列表相同。
- 单类时自动
nclass=1,类 id 固定为 0。
- Episode 流程
- 采样 1 个类 → 1 张 query +
k张 support(默认kshot=1)。 - 数据增强:训练使用随机旋转、高斯模糊、翻转等(见
data/dataset.py)。
- 采样 1 个类 → 1 张 query +
- 尺寸
- 输入原始 1280x720,训练统一 resize 到
img_size=1024(可调)。
- 输入原始 1280x720,训练统一 resize 到
前向与损失
- 支持集前向:逐张 support (图+mask) 送入,累积 memory features。
- 查询前向:复用 memory,预测 query mask。
- 损失函数:
BCEWithLogits + Dice(对 logits 与 GT mask)。 - 自动混合精度:
torch.autocast(device_type='cuda', dtype=torch.bfloat16)。
训练配置(默认)
- 优化器:
AdamW- 两组参数:多维权重带 weight decay,1 维/bias 不 decay。
lr=1e-4,weight_decay=1e-6,betas=(0.9, 0.999).
- 学习率调度:
CosineAnnealingLR(T_max = epochs * steps_per_epoch). - 批大小:
bsz=4(单卡,可按显存调整)。 - 轮数:
epochs=50。 - 精度:
bfloat16混合精度。 - 设备:单 GPU。若使用 DDP,自行设置
torch.distributed.run,并保证 fold 与 sampler 对齐。
日志与模型保存
- 记录指标:Avg Loss、mIoU(前景)、FB-IoU(前景/背景 IoU 均值)、mF1。
- 按验证集 mIoU 最高保存
best_model.pt。 - TensorBoard:loss / mIoU / FB-IoU 曲线。
单类数据集的特别处理
- mask 二值化:非零即前景(0/255 → 0/1)。
- 类集合:检测到仅 1 类时,
class_ids=[0],避免 80 类抽样导致空类或 KeyError。 - 数据列表:
data_list_*/sub_class_file_list_*需与文件名一致,可用generate_data_list.py自动生成。
运行示例
bash
python train.py \
--datapath ./dataset_blood \
--benchmark coco \
--fold 0 \
--bsz 4 \
--epochs 50
单类场景下 fold 取 0/1/2/3 结果一致。
关键收益
- 训练开销小:仅 ~75 万可训练参数。
- 稳定性高:主干冻结,降低过拟合风险。
- Few-shot 泛化:episode + memory 机制,mIoU/FB-IoU 直接反映跨图像的类泛化能力。
onnx拆分导出方法
将 SAM2 导出为 ONNX 的方法,供本项目参考。
一、整体架构
将 SAM2 视频分割模型拆分为 4 个 ONNX 子模块:
| 序号 | 模块名 | ONNX 文件 | 功能 |
|---|---|---|---|
| 1 | ImageEncoder | image_encoder.onnx |
图像编码,输出 backbone 特征与位置编码 |
| 2 | ImageDecoder | image_decoder.onnx |
点 prompt + mask decoder 一体化,输出 mask、obj_ptr 等 |
| 3 | MemAttention | memory_attention.onnx |
记忆注意力,融合当前帧与历史记忆 |
| 4 | MemEncoder | memory_encoder.onnx |
记忆编码,将 mask 编码为 memory 特征 |
1.1 为什么要拆成这 4 个 ONNX?
从 SAM2 的原始架构来看,视频分割主干可以自然地拆成 4 个阶段:
- ImageEncoder(图像编码)
- 对每一帧做 backbone + FPN 编码,得到多尺度特征(pix_feat、high_res_feats_0/1、vision_feats、vision_pos_embed)。
- 计算量大、结构规则,最适合作为 单独的 TRT INT8/FP16 引擎。
- MemEncoder(记忆编码)
- 把 support/历史帧上的 mask + pix_feat 编码为 memory_0/memory_1(maskmem_features 等)。
- 输入输出固定,接口简单,适合作为 独立的小引擎。
- MemAttention(记忆注意力)
- 将当前帧特征 + 历史 memory 做 cross-attention,输出 image_embed。
- 本质是一个带 RoPE 的 cross-attention block,算子相对集中,但有 RoPE 复杂度(需额外补丁,见后文)。
- ImageDecoder(解码 + prompt)
- 内部打包了 prompt encoder(点/框)+ mask decoder,一次性输出 mask_for_mem + pred_mask + obj_ptr。
- 这样推理 pipeline 只需传入点坐标/标签和 encoder/memory 的输出,而不用再单独处理 prompt_encoder。
拆成这 4 个子图的好处:
- 完全覆盖 SAM2 推理主路径 :
image_encoder → mem_encoder → mem_attention → image_decoder就是视频分割 pipeline 的骨架。 - 易于部署与量化:每个子模型的输入输出都清晰、规模适中,便于做 INT8/FP16 量化和独立 TRT 编译。
- 避免"超大单图"问题:若把所有模块导成一个巨型 ONNX,不利于调试、量化、TRT tactic 搜索,也不方便后续做分模块优化。
二、核心思路
2.1 模块封装(src/Module.py)
通过 Wrapper 类 将 SAM2 原始子模块封装为可单独导出的 nn.Module,每个 Wrapper 只暴露必要的输入输出,便于 ONNX 导出。
2.2 模型加载
python
from sam2.build_sam import build_sam2
sam2_model = build_sam2(args.config, args.checkpoint, device="cpu")
- 使用 Hydra 配置 + 实例化
- 支持 tiny / small / large / base+ 四种配置
2.3 导出顺序
plain
1. export_image_encoder(image_encoder, outdir)
2. export_image_decoder(image_decoder, outdir)
3. export_memory_attention(mem_attention, outdir)
4. export_memory_encoder(mem_encoder, outdir)
三、各模块导出详情
3.1 ImageEncoder
输入:
image:[1, 3, 1024, 1024]float
输出:
pix_feat:[1, 256, 64, 64]像素特征high_res_feat0:[1, 32, 256, 256]高分辨率特征 0high_res_feat1:[1, 64, 128, 128]高分辨率特征 1vision_feats: 视觉特征(与 pix_feat 相关)vision_pos_embed:[4096, 1, 256]位置编码
实现要点:
- ImageEncoder 内部:
image_encoder→_prepare_backbone_features→ conv_s0/conv_s1 处理 - 使用
do_constant_folding=True,opset_version=17 - 不使用 onnxsim(注释说明 simplify 会简化掉部分输出)
3.2 ImageDecoder
输入:
point_coords:[num_labels, num_points, 2]点坐标point_labels:[num_labels, num_points]点标签image_embed:[1, 256, 64, 64]来自 memory_attention 的输出high_res_feats_0:[1, 32, 256, 256]high_res_feats_1:[1, 64, 128, 128]
输出:
obj_ptr:[1, 256]对象指针 tokenmask_for_mem:[1, 1, 1024, 1024]供 memory encoder 使用的 maskpred_mask:[1, 1, 1024, 1024]最终预测 mask
实现要点:
- 内部调用
_forward_sam_heads,包含 prompt encoder + mask decoder - 对
high_res_masks做 sigmoid + scale + bias,再fill_holes_in_mask_scores - 使用
dynamic_axes支持 point_coords / point_labels 的 batch 维度 - 使用 onnxsim 简化
3.3 MemAttention
输入:
current_vision_feat:[1, 256, 64, 64]当前帧视觉特征current_vision_pos_embed:[4096, 1, 256]当前帧位置编码memory_0:[num_obj_ptr, 256]对象指针 tokenmemory_1:[buff_size, 64, 64, 64]历史帧 mask 特征memory_pos_embed:[dynamic, 1, 64]记忆位置编码
输出:
image_embed:[1, 256, 64, 64]融合记忆后的图像 embedding
实现要点:
- 内部 reshape:
memory_0→[4*num_obj_ptr, 1, 64],memory_1→[buff_size*4096, 1, 64] - memory 与 curr 做 cross-attention,输出 reshape 为
[1, 256, 64, 64] dynamic_axes:memory_0(num),memory_1(buff_size),memory_pos_embed(dynamic)- 使用 onnxsim 简化
3.4 MemEncoder
输入:
mask_for_mem:[1, 1, 1024, 1024]高分辨率 maskpix_feat:[1, 256, 64, 64]像素特征(来自 image_encoder)
输出:
maskmem_features: 记忆特征maskmem_pos_enc: 记忆位置编码temporal_code: 时间编码
实现要点:
- 内部调用
_encode_new_memory,传入pred_masks_high_res、current_vision_feats等 - 使用
opset_version=17,do_constant_folding=True - 不启用 onnxsim
五、目录结构
plain
ONNXExport/
├── export_onnx.py # 主导出脚本
├── src/
│ └── Module.py # ImageEncoder, MemAttention, MemEncoder, ImageDecoder
├── sam2/
│ ├── build_sam.py # 模型构建
│ ├── sam2_video_predictor.py
│ ├── sam2_image_predictor.py
│ └── modeling/
│ ├── sam2_base.py
│ ├── memory_attention.py
│ ├── memory_encoder.py
│ └── ...
├── sam2_configs/
│ ├── sam2_hiera_tiny.yaml
│ ├── sam2_hiera_small.yaml
│ ├── sam2_hiera_large.yaml
│ └── sam2_hiera_base+.yaml
└── checkpoints/
└── download_ckpts.sh
六、使用方式
bash
## 1. 下载 checkpoint
cd checkpoints && ./download_ckpts.sh
mkdir base+ large small tiny
## 2. 导出
python export_onnx.py --outdir checkpoints/base+/ --config sam2_configs/sam2_hiera_base+.yaml --checkpoint checkpoints/sam2_hiera_base_plus.pt
七、关键要点
- Wrapper 封装:每个导出模块需独立封装,forward 只保留必要输入输出,便于 ONNX 追踪。
- onnxsim 使用:image_encoder 与 memory_encoder 不 simplify,避免输出被简化;image_decoder 与 memory_attention 使用 simplify。
- dynamic_axes:memory_attention 支持可变 num_obj_ptr、buff_size;image_decoder 支持可变 point 数量。
- dummy 输入 :使用
torch.randn构造 dummy 输入,保证 shape 正确,便于 ONNX 导出时正确推断。
八、LoRA 导出路径:model_loader_lora.py 的作用与原理
当使用 --lora 参数时,export_onnx.py 会走 FS-SAM2(LoRA 微调) 的导出路径:
bash
python ONNXExport/export_onnx.py \
--lora logs/coco/0000/fold0.log/best_model.pt \
--base-ckpt ./sam2.1_hiera_base_plus.pt \
--outdir ./onnx_export/fs_sam2_output
底层逻辑由 model_loader_lora.py 提供:
- 加载 base SAM2.1
python
model = build_sam2(config_file, base_checkpoint, device=device)
- 在关键子模块上挂 LoRA
image_encoder:对 self-attention 的qkv、proj加 LoRAmemory_attention:对 cross-attention 的q_proj/k_proj/v_proj/out_proj加 LoRAmemory_encoder:对out_proj加 LoRA
→ 这样 LoRA 只作用于注意力层,既能调优性能,又不破坏总体结构。
- 加载 LoRA checkpoint
- 将训练好的 LoRA 权重(如
best_model.pt)按 key 重映射到上述 LoRA 模块上。 - 处理
state_dict中的model.前缀,忽略与当前结构无关的多余 key。
- 将训练好的 LoRA 权重(如
- merge_and_unload(推荐)
- 默认
merge_lora=True,会执行:
- 默认
python
model.image_encoder = model.image_encoder.merge_and_unload()
model.memory_attention = model.memory_attention.merge_and_unload()
model.memory_encoder = model.memory_encoder.merge_and_unload()
- 数学上即 (W' = W + \Delta W),将 LoRA 增量权重"烤进" base 参数里,得到一个**纯静态的 FS-SAM2 模型**。
这么做的原因:
- ONNX / TensorRT 对动态 LoRA 分支(A/B 矩阵、Adapter 等)不友好,kernel 融合困难,部署复杂。
- merge 后的模型在结构上与原生 SAM2 一致,只是权重不同,最适合做 ONNX 导出和 TRT 量化/编译。
总结 :model_loader_lora.py 负责把"base SAM2.1 + LoRA checkpoint"变成一个 已合并 LoRA 的 FS-SAM2 模型 ,然后再交给 ImageEncoder / MemAttention / MemEncoder / ImageDecoder 这 4 个 Wrapper 导出为 ONNX。
九、RoPE 补丁:rope_onnx_patch.py 的原因与原理
9.1 为什么需要 RoPE 补丁?
SAM2 的 MemAttention 内部使用了 RoPE(Rotary Positional Embedding),原实现里 RoPE 使用 ComplexFloat(复数) 存储频率表 freqs_cis。
但是:
- ONNX 标准对 complex 类型支持非常有限(很多后端完全不支持);
- TensorRT 目前不支持 ComplexFloat,含 complex 的 ONNX 图在 TRT 中无法正常解析。
因此,在导出 memory_attention.onnx 前,必须把 complex RoPE 表达式改写为纯实数形式(cos/sin buffer)。
9.2 patch_rope_for_onnx
export_onnx.py 中的关键几行:
python
from rope_onnx_patch import patch_rope_for_onnx
patch_rope_for_onnx(model.memory_attention)
rope_onnx_patch.py 的主要步骤:
- 预计算 cos/sin buffer
- 对每个
RoPEAttention实例,读取原有freqs_cis(complex),分离出实部/虚部:
- 对每个
python
cos = fc.real.clone().detach().float()
sin = fc.imag.clone().detach().float()
mod.register_buffer("cos_cis", cos)
mod.register_buffer("sin_cis", sin)
- 删除 complex 版本的 `freqs_cis`,避免进入 ONNX 图。
- 定义纯实数版 RoPE 应用函数
_rope_single_real:把最后一维拆成 2 通道(x0, x1),按 cos/sin 做二维旋转,还原到原 shape;_apply_rotary_enc_real:对 q/k 应用上一步逻辑,并处理rope_k_repeat、dynamic seq 等情况。
- 替换
RoPEAttention.forward为 ONNX 友好版_rope_attn_forward_onnx内部:- 做 q/k/v projection 和多头拆分;
- 用
_apply_rotary_enc_real对 q/k 应用 cos/sin 版 RoPE; - 使用
torch.nn.functional.scaled_dot_product_attention做注意力; - 还原 heads、投影输出。
- 最后通过
partial把这个 forward 绑定到每个RoPEAttention上。
- 提供动态 shape 下的 cos/sin 计算
- 当序列长度变化时,会用
compute_axial_cis_real动态计算 cos/sin,以保持与原实现数学等价。
- 当序列长度变化时,会用
结果:memory_attention.onnx 中不再出现 complex 类型,只包含 FP16/FP32 的 cos/sin buffer 与常规 attention 算子,完全兼容 ONNX Runtime 和 TensorRT。
FS-SAM2 性能优化:从 ONNX 导出到 TensorRT 部署
1. 概述
本文档记录了 FS-SAM2(Few-Shot SAM2)模型从 PyTorch 训练产物到 TensorRT 高性能部署的完整优化过程,包括 ONNX 导出、INT8 量化、TRT 编译及推理,以及过程中遇到的关键问题与解决方案。
测试环境:NVIDIA H100 GPU,视频
kidney_1_001347_001405.mp4(540 帧,1920x1080)
最终量化方案
| 子模型 | 精度 | TRT 编译方式 | 说明 |
|---|---|---|---|
| image_encoder | INT8 | --stronglyTyped |
质量与 FP32 一致 |
| memory_encoder | INT8 | --stronglyTyped |
质量与 FP32 一致 |
| memory_attention | FP16 | --fp16 (从 FP32 ONNX) |
INT8 导致 mask 空洞 + 速度退化 |
| image_decoder | FP16 | --fp16 (从 FP32 ONNX) |
INT8 导致 mask 完全丢失 |
2. 模型架构与数据流
FS-SAM2 由 4 个子模型串联构成 few-shot 视频分割 pipeline:
plain
image_encoder ──┬──> memory_encoder ──> memory_attention ──> image_decoder ──> pred_mask
│ ↑ (support mask) ↑ (memory) ↑ (prompt)
└─────────┴────────────────────────┘
| 子模型 | 功能 | FP32 大小 | 输入 | 输出 |
|---|---|---|---|---|
| image_encoder | 图像特征提取 (Hiera backbone) | 265 MB | image [1,3,1024,1024] | pix_feat, high_res_feats, vision_feats, vision_pos_embed |
| memory_encoder | 编码 support memory | 5.4 MB | mask_for_mem [1,1,1024,1024], pix_feat [1,256,64,64] | maskmem_features, maskmem_pos_enc, temporal_code |
| memory_attention | 融合当前帧特征与 memory | 31 MB | current_vision_feat, pos_embed, memory_0/1, memory_pos_embed | image_embed [1,256,64,64] |
| image_decoder | 生成分割 mask | 21 MB | point_coords, point_labels, image_embed, high_res_feats | obj_ptr, mask_for_mem, pred_mask |
3. Step 1 --- ONNX 导出
3.1 导出流程
执行导出,得到 4 个 ONNX:
bash
python export_onnx.py \
--outdir /path/to/output \
--config sam2_configs/sam2_hiera_base+.yaml \
--checkpoint checkpoints/sam2_hiera_base_plus.pt
导出顺序:image_encoder → image_decoder → memory_attention → memory_encoder。
导出完成后,将生成的 4 个 .onnx 放到本项目的 onnx_export/fs_sam2_output/(或任意目录),供后续校准、量化与推理使用。
3.2 导出产物(4 个 ONNX)
在本项目中,这 4 个文件通常放在:
plain
onnx_export/fs_sam2_output/
├── image_encoder.onnx (265 MB)
├── memory_encoder.onnx (5.4 MB)
├── memory_attention.onnx (31 MB)
└── image_decoder.onnx (21 MB)
3.3 导出中的关键决策
- 4 模块划分:image_encoder、image_decoder、memory_attention、memory_encoder
- memory_attention 导出为 ONNX:便于与其余 3 个 ONNX 一起做 TRT 部署,无需单独加载 pth。
4. Step 2 --- INT8 量化
4.1 工具链
- nvidia-modelopt 0.37.0:NVIDIA 官方 ONNX PTQ 工具包
- ONNX Runtime 1.24.2:校准数据生成时的推理后端
- 校准方法:entropy(128-bin 直方图 KL 散度)
4.2 校准数据生成(Step 1)
执行的脚本 :onnx_quant/prepare_calibration_data.py
作用 :跑 FP32 的 4 子模型 pipeline,把每个子模型真实看到的输入缓存成 4 份 .npz,供后续量化使用。
建议 :在 modelopt 环境下运行;非交互式 shell 需先 export PS1=x 再 source ~/.bashrc 以加载 CUDA 等环境;--datapath/--onnx_dir 建议用绝对路径,避免脚本内相对路径解析到错误目录。
完整命令示例 (工作目录为项目根目录时,先 cd onnx_quant):
bash
python prepare_calibration_data.py \
--datapath /path/to/aie-kidney-all-blood/dataset_blood \
--onnx_dir /path/to/aie-kidney-all-blood/onnx_export/fs_sam2_output \
--output_dir build \
--calib_size 128 \
--device cuda
核心原则:1 套源校准集 + 4 份派生校准输入缓存
不是对每个子模型用随机数据校准,而是用 FP32 pipeline 跑完整推理流程,把每个子模型实际看到的输入 缓存为 .npz 文件:
plain
build/
├── calib_image_encoder.npz (304 MB) --- 128 张 [3,1024,1024] 图像
├── calib_memory_encoder.npz (474 MB) --- 128 对 (mask, pix_feat)
├── calib_memory_attention.npz (788 MB) --- 128 组 5 路 memory 输入
└── calib_image_decoder.npz (1.9 GB) --- 128 组 decoder 输入
4.3 量化执行(Step 2)
执行的脚本 :onnx_quant/quantize_onnx_int8.py
作用 :用 modelopt 对 4 个 FP32 ONNX 做 INT8 PTQ,读取上一步生成的 build/calib_*.npz,输出到 build/quantized_int8/。image_encoder 最慢(约 36 分钟),4 个模型合计约 40 分钟。
完整命令示例 (工作目录为 onnx_quant):
bash
python quantize_onnx_int8.py \
--onnx_dir /path/to/aie-kidney-all-blood/onnx_export/fs_sam2_output \
--calib_dir build \
--output_dir build/quantized_int8
量化结果:
| 子模型 | 总节点数 | INT8 量化节点 | 量化比例 | dq_only |
|---|---|---|---|---|
| image_encoder | 5603 | 266 | 4.7% | False (QDQ) |
| memory_encoder | 300 | 16 | 5.3% | True (DQ-only) |
| memory_attention | 908 | 62 | 6.8% | True (DQ-only) |
| image_decoder | 898 | 88 | 9.8% | True (DQ-only) |
4.4 量化过程中遇到的问题与解决方案
问题 1:image_encoder Resize 节点不兼容 DQ-only 模式
plain
[ERROR] Unsupported op_type for real weight quantization: Resize
原因 :modelopt 的 qdq_to_dq 转换步骤在处理 image_encoder 的 pos_embed Resize 节点时失败,无法将 QDQ 格式转换为 DQ-only 格式。
解决方案 :对 image_encoder 保持 QDQ 格式 (dq_only=False),其余模型正常使用 DQ-only:
python
use_dq_only = False if model_name == "image_encoder" else True
问题 2:memory_encoder pix_feat 维度不匹配
plain
Invalid rank for input: pix_feat Got: 5 Expected: 4
原因 :FP32 pipeline 中 pix_feat 在某些情况下为 5D 张量 [1,1,256,64,64],但 ONNX 模型期望 4D [1,256,64,64]。
解决方案 :在 prepare_calibration_data.py 中添加 _ensure_pix_feat_4d() 辅助函数:
python
def _ensure_pix_feat_4d(arr: np.ndarray) -> np.ndarray:
if arr.ndim == 4:
return arr
# squeeze 到 4D
while arr.ndim > 4 and arr.shape[0] == 1:
arr = arr.squeeze(0)
return arr
问题 3:modelopt symbolic shape inference 失败
原因:动态 shape 模型(memory_attention、image_decoder)的 symbolic shape inference 无法推断所有中间张量形状,导致 GEMV 检测失败。
解决方案 :Monkey-patch modelopt 的 find_nodes_from_matmul_to_exclude 函数,在 symbolic inference 失败时自动回退到 inference-based GEMV 检测:
python
def _patched_find(...):
try:
nodes_to_exclude = _graph_utils._exclude_matmuls_by_symbolic_inference(...)
except (AssertionError, Exception) as e:
logger.warning("Symbolic shape inference failed, falling back to inference-based.")
nodes_to_exclude = _graph_utils._exclude_matmuls_by_inference(...)
return nodes_to_exclude
问题 4:TensorRT stronglyTyped output dtype 不一致
原因 :量化后 graph output 声明的 dtype 可能与 inferred dtype 不一致,导致 trtexec --stronglyTyped 报错。
解决方案 :添加 sanitize_onnx_for_trt() 后处理,强制所有 graph output 声明为 FP32:
python
def sanitize_onnx_for_trt(onnx_path, force_output_fp32=True):
# 对每个 graph.output,将 elem_type 统一设置为 FLOAT
问题 5:onnx.helper.float32_to_bfloat16 缺失
原因:新版 onnx 库移除了该函数,但 modelopt 内部仍在引用。
解决方案:启动时注入兼容实现:
python
if not hasattr(onnx.helper, "float32_to_bfloat16"):
def float32_to_bfloat16(fval):
ival = struct.unpack("=I", struct.pack("=f", float(fval)))[0]
return ival >> 16
onnx.helper.float32_to_bfloat16 = float32_to_bfloat16
5. Step 3 --- 量化精度诊断
5.1 问题发现
全部 4 个子模型 INT8 量化后推理,输出视频完全没有出血 mask。
5.2 逐模型排查方法
构建 5 个诊断目录,每个目录中只有 1 个子模型使用 INT8 量化版本,其余 3 个使用 FP32 原始版本(通过符号链接重命名实现):
bash
diag_int8_tests/
├── test_all_fp32/ # 全 FP32 baseline
├── test_image_encoder/ # 仅 image_encoder 为 INT8
├── test_memory_encoder/ # 仅 memory_encoder 为 INT8
├── test_memory_attention/ # 仅 memory_attention 为 INT8
└── test_image_decoder/ # 仅 image_decoder 为 INT8
分别运行推理并对比输出视频中的 mask 质量。
5.3 诊断结果
| 测试配置 | mask 质量 | 速度 (FPS) | 结论 |
|---|---|---|---|
| all_fp32 (baseline) | 正常 | 19.3 | 基准线 |
| image_encoder=INT8 | 与 FP32 一致 | 15.3 | 可量化 |
| memory_encoder=INT8 | 与 FP32 一致 | 20.4 | 可量化 |
| memory_attention=INT8 | mask 有空洞 | 4.3 (慢 18x) | 不可量化 |
| image_decoder=INT8 | 无 mask | 20.9 | 不可量化 |
5.4 分析
- image_decoder INT8 致命:decoder 是最终输出 mask 的模型,INT8 精度损失直接导致 sigmoid 后 mask 概率全部低于阈值,mask 完全消失。这类模型对精度极敏感。
- memory_attention INT8 有害:不仅质量下降(空洞),ORT 对 DQ-only INT8 节点的执行效率极差,mem_attn 推理时间从 7ms 飙升至 144ms。可能是因为 ORT 没有为该子图的 INT8 模式做 kernel 融合。
- image_encoder / memory_encoder 安全:这两个模型以特征提取为主,对 INT8 量化容忍度高。
5.5 确定最终方案
只量化 image_encoder + memory_encoder,memory_attention + image_decoder 保持 FP32/FP16。
6. Step 4 --- TensorRT 引擎编译
6.1 混合编译脚本 trt.sh
bash
## INT8 量化模型(QDQ 图):
## 推荐用 --fp16 --int8(让 TRT 按 layer 自动选择最优精度),避免部分 GPU 上
## 因 kernel 覆盖不足导致插入大量 reformat 节点而变慢。
trtexec --onnx=quantized_int8/image_encoder.int8.quant.onnx \
--saveEngine=trt_engines/image_encoder.int8.quant.engine \
--fp16 --int8
trtexec --onnx=quantized_int8/memory_encoder.int8.quant.onnx \
--saveEngine=trt_engines/memory_encoder.int8.quant.engine \
--fp16 --int8
## FP32 原始模型 -> FP16(TRT 自动精度优化)
trtexec --onnx=fs_sam2_output/memory_attention.onnx \
--saveEngine=trt_engines/memory_attention.int8.quant.engine \
--fp16 \
--minShapes=... --optShapes=... --maxShapes=...
trtexec --onnx=fs_sam2_output/image_decoder.onnx \
--saveEngine=trt_engines/image_decoder.int8.quant.engine \
--fp16 \
--minShapes=... --optShapes=... --maxShapes=...
6.1.1 常见坑:A100 上编译+推理只有 ~10 FPS(而 H100 有 ~23 FPS)
现象:同一套 ONNX/QDQ 量化图,在 H100 上 TRT 推理可达 ~23 FPS,但在 A100 上只有 ~10 FPS。
根因(高概率):**对 QDQ 量化图使用 **--stronglyTyped 会强制 TRT 严格按 Q/DQ 标记执行 INT8;当某些 layer 在 A100 上缺少合适的 INT8 tactic / kernel 时,TRT 会插入大量 reformat / type-cast(INT8↔FP16/FP32)来满足强类型约束,导致整体吞吐显著下降。H100(更新架构)对 INT8 kernel 覆盖更完整,因此更不容易触发该退化。
修复:对 QDQ INT8 模型优先使用 --fp16 --int8(而不是 --stronglyTyped),让 TRT 允许对没有高效 INT8 kernel 的 layer 自动回退到 FP16,从而避免 reformat 风暴。
验证建议:重新在 A100 上运行 trt.sh 生成 engine,并用 infer_video_trt.py 复测 FPS;如果 FPS 显著回升(通常接近 FP32/FP16 TRT 的水平),说明退化来自强类型 INT8 的 tactic 覆盖问题。
6.2 动态 shape 配置
memory_attention 和 image_decoder 有动态维度,需指定 min/opt/max shapes:
memory_attention(memory 数量可变):
memory_0: [1~16, 256]memory_1: [1~16, 64, 64, 64]memory_pos_embed: [1~65600, 1, 64]
image_decoder(prompt 点数可变):
point_coords: [1, 2~10, 2]point_labels: [1, 2~10]
6.3 编译产物
plain
trt_engines/
├── image_encoder.int8.quant.engine (73 MB) --- INT8
├── memory_encoder.int8.quant.engine (3.7 MB) --- INT8
├── memory_attention.int8.quant.engine (29 MB) --- FP16
└── image_decoder.int8.quant.engine (16 MB) --- FP16
总计 ~122 MB,相比 FP32 ONNX 的 322 MB 减少了 62%。
8. 文件结构总览
plain
aie-kidney-all-blood/
├── onnx_export/ # 本仓库:存放 ONNX 产物与 FS-SAM2 导出脚本
│ ├── fs_sam2_output/ # 4 个 FP32 ONNX
│ │ ├── image_encoder.onnx (265 MB)
│ │ ├── memory_encoder.onnx (5.4 MB)
│ │ ├── memory_attention.onnx (31 MB)
│ │ └── image_decoder.onnx (21 MB)
│ └── README.md # 说明 onnx_export 的 FS-SAM2/LoRA 导出方式(与当前 4 模块 pipeline 不同)
│
├── onnx_quant/ # INT8 量化
│ ├── prepare_calibration_data.py # 校准数据生成(FP32 pipeline 派生)
│ ├── quantize_onnx_int8.py # 量化主脚本(modelopt)
│ ├── evaluate_onnx.py # 量化评估(精度+延迟对比)
│ ├── run_all.sh # 全流程编排
│ └── build/
│ ├── calib_*.npz # 校准数据缓存
│ ├── quantized_int8/ # 全量 INT8 模型(诊断用)
│ └── quantized_mixed/ # 混合方案(IE+ME=INT8, MA+ID=FP32 symlink)
│
├── trt.sh # TRT 编译脚本(混合 INT8/FP16)
├── trt_engines/ # 编译产物
│ ├── image_encoder.int8.quant.engine (73 MB, INT8)
│ ├── memory_encoder.int8.quant.engine (3.7 MB, INT8)
│ ├── memory_attention.int8.quant.engine (29 MB, FP16)
│ └── image_decoder.int8.quant.engine (16 MB, FP16)
│
├── infer_video_sam2export.py # ONNX 推理脚本(FP32 / 量化混合)
└── infer_video_trt.py # TRT 推理脚本
9. 复现指南
9.1 环境
plain
conda env: modelopt
Python 3.12
nvidia-modelopt 0.37.0
onnxruntime-gpu 1.24.2
TensorRT 10.14.1.48
CUDA 12.x + cuDNN 9
9.2 完整流程
bash
## 0. 进入环境
source ~/.bashrc
conda activate modelopt
## 1. 导出 ONNX
## python export_onnx.py --outdir <outdir> --config sam2_configs/sam2_hiera_base+.yaml --checkpoint <ckpt>
## 再将生成的 4 个 .onnx 拷贝到本项目 onnx_export/fs_sam2_output/
## 2. 生成校准数据(~5 分钟)
cd onnx_quant
python prepare_calibration_data.py \
--datapath /path/to/dataset_blood \
--onnx_dir /path/to/onnx_export/fs_sam2_output \
--output_dir build --calib_size 128 --device cuda
## 3. INT8 量化(~43 分钟,image_encoder 最慢)
python quantize_onnx_int8.py \
--onnx_dir /path/to/onnx_export/fs_sam2_output \
--calib_dir build --output_dir build/quantized_int8
## 4. 编译 TRT 引擎(~7 分钟)
cd ..
bash trt.sh
## 5. TRT 推理
python infer_video_trt.py \
--video kidney_1_001347_001405.mp4 \
--engine_dir trt_engines \
--output output_trt.mp4 \
--support_dir support_images \
--support_mask_dir support_masks