libtorch ITK 部署 nnUNetV2 模型

1.1 PyTorch 版本严格匹配

  • 为何如此重要?

    PyTorch 和 libtorch 在内部实现和数据结构上高度相关,版本不一致会导致二进制不兼容,模型加载失败或推理异常。

  • 如何操作?

    • • 记录训练时使用的 PyTorch 版本。

    • • 前往 PyTorch官网 下载对应版本的 libtorch。

    • • 注意操作系统及编译器版本匹配,确保环境一致。

1.2 CUDA 版本同步(仅 GPU 推理需关注)

  • • 训练和推理端的 CUDA、cuDNN 版本必须保持一致,防止驱动或库不匹配导致错误。

  • • 下载带 CUDA 支持的 libtorch 包。

1.3 开发环境准备

  • • 确保 C++17 及以上标准支持,推荐使用 CMake 管理项目。

  • • 配置好编译器,能正确链接 libtorch。

  • • 保持 Python 环境,用于模型训练和导出。


2. 模型导出:Python 端将 nnUNetV2 转换为 TorchScript

PyTorch 原生模型只能在 Python 环境使用,而 libtorch 需要的是经过 TorchScript 转换的模型文件。

2.1 为什么导出 TorchScript?

  • • TorchScript 是 PyTorch 的静态计算图中间表示,支持序列化和跨语言调用。

  • • C++ 端通过加载 TorchScript 模型,无需依赖 Python 环境,实现高效推理。

2.2 导出示例流程

复制代码
import json
from pathlib import Path
import torch
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

# 参数配置
dataset_id = 101
configuration = "3d_fullres"
fold = 0

# 路径配置
plans_file = Path(f"E:/nnunet/nnUNet_preprocessed/Dataset{dataset_id}_Temporal/nnUNetPlans.json")
dataset_json_file = Path(f"E:/nnunet/nnUNet_preprocessed/Dataset{dataset_id}_Temporal/dataset.json")
checkpoint_file = Path(f"E:/nnunet/nnUNet_results/Dataset{dataset_id}_Temporal/nnUNetTrainer__nnUNetPlans__{configuration}/fold_{fold}/checkpoint_best.pth")

# 加载配置文件
withopen(plans_file) as f:
    plans = json.load(f)
withopen(dataset_json_file) as f:
    dataset_json = json.load(f)

# 初始化 Trainer
trainer = nnUNetTrainer(plans, configuration, fold, dataset_json, device=torch.device("cuda"))
trainer.initialize()

# 加载模型权重
trainer.load_checkpoint(str(checkpoint_file))

# 导出 TorchScript 模型
model = trainer.network
model.eval()

patch_size = plans["configurations"][configuration]["patch_size"]
dummy_input = torch.randn([1, 1] + patch_size).cuda()

traced = torch.jit.trace(model, dummy_input)
traced.save("nnunet_model_traced.pt")

print("✅ TorchScript 模型导出完成: nnunet_model_traced.pt")

注意事项

  • • 若模型中有动态控制流,建议使用 torch.jit.script

  • • 输入尺寸需严格与训练阶段保持一致。

  • • 可以导出多版本模型,便于不同场景测试。


3. libtorch C++ 推理:从加载到输出的完整流程

导出 TorchScript 模型后,即可使用 libtorch 在 C++ 中进行推理。

3.1 基础推理示范代码

复制代码
try {
  std::cout << "开始加载 TorchScript 模型..." << std::endl;
auto start_load = std::chrono::high_resolution_clock::now();

  torch::jit::script::Module model = torch::jit::load("model.pt");

auto end_load = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> load_time = end_load - start_load;
  std::cout << "模型加载成功,耗时: " << load_time.count() << " 秒" << std::endl;

  model.eval();
  model.to(torch::kCPU);  // 如需GPU,请切换到 torch::kCUDA

// 构造输入张量,尺寸需与训练时相同
int imageSize[3];
  imageSize[0] = transformed->GetLargestPossibleRegion().GetSize()[0];
  imageSize[1] = transformed->GetLargestPossibleRegion().GetSize()[1];
  imageSize[2] = transformed->GetLargestPossibleRegion().GetSize()[2];

  torch::Tensor input_tensor = torch::from_blob(
    transformed->GetBufferPointer(),
    {1, 1, (int)imageSize[2], (int)imageSize[1], (int)imageSize[0]}
  ).toType(torch::kFloat32);

  std::cout << "\n输入张量信息:" << std::endl
            << "- 形状: " << input_tensor.sizes() << std::endl
            << "- 类型: " << input_tensor.dtype() << std::endl
            << "- 设备: " << input_tensor.device() << std::endl;

  std::cout << "\n开始推理..." << std::endl;
auto start_infer = std::chrono::high_resolution_clock::now();

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(input_tensor);

  torch::jit::IValue output = model.forward(inputs);

auto end_infer = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> infer_time = end_infer - start_infer;
  std::cout << "推理完成,耗时: " << infer_time.count() << " 秒" << std::endl;

// 解析输出
  torch::Tensor output_tensor;
if (output.isTensorList()) {
    auto output_list = output.toTensorList();
    output_tensor = output_list[0];
  } elseif (output.isTuple()) {
    auto output_tuple = output.toTuple();
    output_tensor = output_tuple->elements()[0].toTensor();
  } elseif (output.isTensor()) {
    output_tensor = output.toTensor();
  } else {
    std::cerr << "未知输出类型!" << std::endl;
    return-1;
  }

  std::cout << "\n输出张量信息:" << std::endl
            << "- 形状: " << output_tensor.sizes() << std::endl
            << "- 类型: " << output_tensor.dtype() << std::endl
            << "- 设备: " << output_tensor.device() << std::endl;

// 打印部分输出值
if (output_tensor.numel() > 0) {
    std::cout << "\n输出张量前10个值: ";
    auto flattened = output_tensor.flatten();
    for (int i = 0; i < std::min(10, int(flattened.size(0))); ++i) {
      std::cout << flattened[i].item<float>() << " ";
    }
    std::cout << std::endl;
  }

  torch::Tensor pred = torch::argmax(output_tensor, 1).toType(torch::kInt16);
  pred = pred.squeeze(0);
  std::cout << "预测结果尺寸: " << pred.sizes() << std::endl;

} catch (const c10::Error& e) {
  std::cerr << "LibTorch 错误: " << e.what() << std::endl;
return-1;
} catch (const std::exception& e) {
  std::cerr << "异常: " << e.what() << std::endl;
return-1;
}

3.2 GPU 推理

启用 GPU 推理非常简单:

复制代码
model.to(torch::kCUDA);
input_tensor = input_tensor.to(torch::kCUDA);

请确保:

  • • 使用的是支持 CUDA 的 libtorch。

  • • CUDA 和驱动版本匹配训练环境。

  • • GPU 可用且资源充足。

3.3 预处理与后处理关键点

预处理:

  • • 必须完全复刻训练时的数据处理流程(归一化、裁剪、重采样、格式转换)。

  • • 以 nnUNet 的 plan.json 为准,确认 patch size、spacing 和归一化参数。

  • • 例如 CT 图像归一化示例:

    "patch_size": [56,112,112],
    "spacing":[1.0,1.0,1.0],
    "normalization_schemes":["CTNormalization"],
    "foreground_intensity_properties_per_channel":{
    "0":{
    "mean":655.56,
    "std":736.38,
    "percentile_00_5":-249.0,
    "percentile_99_5":2202.0
    }
    }

在 C++ 使用 ITK 做归一化:

复制代码
// 读取图像
auto reader = itk::ImageFileReader<FloatImage>::New();
reader->SetFileName("Test001_0000.nii");
reader->SetImageIO(itk::NiftiImageIO::New());
reader->Update();

// 重采样与填充
double spacing[3]{1.0,1.0,1.0};
int roi_size[3]{112,112,56};
auto transformed = GeneralTransform::SpatialPad(
  GeneralTransform::Orientation(
    GeneralTransform::Spacing(reader->GetOutput(), spacing),
    itk::SpatialOrientation::ITK_COORDINATE_ORIENTATION_RAI),
  roi_size, -1000);

// 归一化参数
double mean_intensity = 655.56;
double std_intensity = 736.38;
double p00_5 = -249.0;
double p99_5 = 2202.0;

// 归一化处理
itk::ImageRegionIterator<FloatImage> it(transformed, transformed->GetLargestPossibleRegion());
for (it.GoToBegin(); !it.IsAtEnd(); ++it) {
double v = it.Get();
  v = std::min(std::max(v, p00_5), p99_5);      // 裁剪
  v = (v - mean_intensity) / std::max(std_intensity, 1e-8);  // 标准化
  it.Set(static_cast<float>(v));
}

后处理:

  • • 阈值分割将模型输出概率转换为掩码。

  • • 连通域分析剔除小块噪声。

  • • 根据实际需求映射不同标签类别。


4. 常见问题及解决方案

问题 解决方案
模型加载失败 确认 libtorch 版本与训练 PyTorch 版本是否匹配
推理结果全零或异常 检查输入数据预处理是否严格一致,输入尺寸是否正确
GPU 推理时报错 确认 CUDA 和驱动版本匹配,使用支持 CUDA 的 libtorch
编译链接失败 确认 CMakeLists.txt 配置正确,库路径完整

https://mp.weixin.qq.com/s/9fjmydyFsnOwOAkuRoyGyw

相关推荐
m0_748708058 分钟前
C++中的观察者模式实战
开发语言·c++·算法
qq_5375626720 分钟前
跨语言调用C++接口
开发语言·c++·算法
wjs202430 分钟前
DOM CDATA
开发语言
Tingjct32 分钟前
【初阶数据结构-二叉树】
c语言·开发语言·数据结构·算法
猷咪1 小时前
C++基础
开发语言·c++
IT·小灰灰1 小时前
30行PHP,利用硅基流动API,网页客服瞬间上线
开发语言·人工智能·aigc·php
快点好好学习吧1 小时前
phpize 依赖 php-config 获取 PHP 信息的庖丁解牛
android·开发语言·php
秦老师Q1 小时前
php入门教程(超详细,一篇就够了!!!)
开发语言·mysql·php·db
烟锁池塘柳01 小时前
解决Google Scholar “We‘re sorry... but your computer or network may be sending automated queries.”的问题
开发语言
是誰萆微了承諾1 小时前
php 对接deepseek
android·开发语言·php