libtorch ITK 部署 nnUNetV2 模型

1.1 PyTorch 版本严格匹配

  • 为何如此重要?

    PyTorch 和 libtorch 在内部实现和数据结构上高度相关,版本不一致会导致二进制不兼容,模型加载失败或推理异常。

  • 如何操作?

    • • 记录训练时使用的 PyTorch 版本。

    • • 前往 PyTorch官网 下载对应版本的 libtorch。

    • • 注意操作系统及编译器版本匹配,确保环境一致。

1.2 CUDA 版本同步(仅 GPU 推理需关注)

  • • 训练和推理端的 CUDA、cuDNN 版本必须保持一致,防止驱动或库不匹配导致错误。

  • • 下载带 CUDA 支持的 libtorch 包。

1.3 开发环境准备

  • • 确保 C++17 及以上标准支持,推荐使用 CMake 管理项目。

  • • 配置好编译器,能正确链接 libtorch。

  • • 保持 Python 环境,用于模型训练和导出。


2. 模型导出:Python 端将 nnUNetV2 转换为 TorchScript

PyTorch 原生模型只能在 Python 环境使用,而 libtorch 需要的是经过 TorchScript 转换的模型文件。

2.1 为什么导出 TorchScript?

  • • TorchScript 是 PyTorch 的静态计算图中间表示,支持序列化和跨语言调用。

  • • C++ 端通过加载 TorchScript 模型,无需依赖 Python 环境,实现高效推理。

2.2 导出示例流程

复制代码
import json
from pathlib import Path
import torch
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

# 参数配置
dataset_id = 101
configuration = "3d_fullres"
fold = 0

# 路径配置
plans_file = Path(f"E:/nnunet/nnUNet_preprocessed/Dataset{dataset_id}_Temporal/nnUNetPlans.json")
dataset_json_file = Path(f"E:/nnunet/nnUNet_preprocessed/Dataset{dataset_id}_Temporal/dataset.json")
checkpoint_file = Path(f"E:/nnunet/nnUNet_results/Dataset{dataset_id}_Temporal/nnUNetTrainer__nnUNetPlans__{configuration}/fold_{fold}/checkpoint_best.pth")

# 加载配置文件
withopen(plans_file) as f:
    plans = json.load(f)
withopen(dataset_json_file) as f:
    dataset_json = json.load(f)

# 初始化 Trainer
trainer = nnUNetTrainer(plans, configuration, fold, dataset_json, device=torch.device("cuda"))
trainer.initialize()

# 加载模型权重
trainer.load_checkpoint(str(checkpoint_file))

# 导出 TorchScript 模型
model = trainer.network
model.eval()

patch_size = plans["configurations"][configuration]["patch_size"]
dummy_input = torch.randn([1, 1] + patch_size).cuda()

traced = torch.jit.trace(model, dummy_input)
traced.save("nnunet_model_traced.pt")

print("✅ TorchScript 模型导出完成: nnunet_model_traced.pt")

注意事项

  • • 若模型中有动态控制流,建议使用 torch.jit.script

  • • 输入尺寸需严格与训练阶段保持一致。

  • • 可以导出多版本模型,便于不同场景测试。


3. libtorch C++ 推理:从加载到输出的完整流程

导出 TorchScript 模型后,即可使用 libtorch 在 C++ 中进行推理。

3.1 基础推理示范代码

复制代码
try {
  std::cout << "开始加载 TorchScript 模型..." << std::endl;
auto start_load = std::chrono::high_resolution_clock::now();

  torch::jit::script::Module model = torch::jit::load("model.pt");

auto end_load = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> load_time = end_load - start_load;
  std::cout << "模型加载成功,耗时: " << load_time.count() << " 秒" << std::endl;

  model.eval();
  model.to(torch::kCPU);  // 如需GPU,请切换到 torch::kCUDA

// 构造输入张量,尺寸需与训练时相同
int imageSize[3];
  imageSize[0] = transformed->GetLargestPossibleRegion().GetSize()[0];
  imageSize[1] = transformed->GetLargestPossibleRegion().GetSize()[1];
  imageSize[2] = transformed->GetLargestPossibleRegion().GetSize()[2];

  torch::Tensor input_tensor = torch::from_blob(
    transformed->GetBufferPointer(),
    {1, 1, (int)imageSize[2], (int)imageSize[1], (int)imageSize[0]}
  ).toType(torch::kFloat32);

  std::cout << "\n输入张量信息:" << std::endl
            << "- 形状: " << input_tensor.sizes() << std::endl
            << "- 类型: " << input_tensor.dtype() << std::endl
            << "- 设备: " << input_tensor.device() << std::endl;

  std::cout << "\n开始推理..." << std::endl;
auto start_infer = std::chrono::high_resolution_clock::now();

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(input_tensor);

  torch::jit::IValue output = model.forward(inputs);

auto end_infer = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> infer_time = end_infer - start_infer;
  std::cout << "推理完成,耗时: " << infer_time.count() << " 秒" << std::endl;

// 解析输出
  torch::Tensor output_tensor;
if (output.isTensorList()) {
    auto output_list = output.toTensorList();
    output_tensor = output_list[0];
  } elseif (output.isTuple()) {
    auto output_tuple = output.toTuple();
    output_tensor = output_tuple->elements()[0].toTensor();
  } elseif (output.isTensor()) {
    output_tensor = output.toTensor();
  } else {
    std::cerr << "未知输出类型!" << std::endl;
    return-1;
  }

  std::cout << "\n输出张量信息:" << std::endl
            << "- 形状: " << output_tensor.sizes() << std::endl
            << "- 类型: " << output_tensor.dtype() << std::endl
            << "- 设备: " << output_tensor.device() << std::endl;

// 打印部分输出值
if (output_tensor.numel() > 0) {
    std::cout << "\n输出张量前10个值: ";
    auto flattened = output_tensor.flatten();
    for (int i = 0; i < std::min(10, int(flattened.size(0))); ++i) {
      std::cout << flattened[i].item<float>() << " ";
    }
    std::cout << std::endl;
  }

  torch::Tensor pred = torch::argmax(output_tensor, 1).toType(torch::kInt16);
  pred = pred.squeeze(0);
  std::cout << "预测结果尺寸: " << pred.sizes() << std::endl;

} catch (const c10::Error& e) {
  std::cerr << "LibTorch 错误: " << e.what() << std::endl;
return-1;
} catch (const std::exception& e) {
  std::cerr << "异常: " << e.what() << std::endl;
return-1;
}

3.2 GPU 推理

启用 GPU 推理非常简单:

复制代码
model.to(torch::kCUDA);
input_tensor = input_tensor.to(torch::kCUDA);

请确保:

  • • 使用的是支持 CUDA 的 libtorch。

  • • CUDA 和驱动版本匹配训练环境。

  • • GPU 可用且资源充足。

3.3 预处理与后处理关键点

预处理:

  • • 必须完全复刻训练时的数据处理流程(归一化、裁剪、重采样、格式转换)。

  • • 以 nnUNet 的 plan.json 为准,确认 patch size、spacing 和归一化参数。

  • • 例如 CT 图像归一化示例:

    "patch_size": [56,112,112],
    "spacing":[1.0,1.0,1.0],
    "normalization_schemes":["CTNormalization"],
    "foreground_intensity_properties_per_channel":{
    "0":{
    "mean":655.56,
    "std":736.38,
    "percentile_00_5":-249.0,
    "percentile_99_5":2202.0
    }
    }

在 C++ 使用 ITK 做归一化:

复制代码
// 读取图像
auto reader = itk::ImageFileReader<FloatImage>::New();
reader->SetFileName("Test001_0000.nii");
reader->SetImageIO(itk::NiftiImageIO::New());
reader->Update();

// 重采样与填充
double spacing[3]{1.0,1.0,1.0};
int roi_size[3]{112,112,56};
auto transformed = GeneralTransform::SpatialPad(
  GeneralTransform::Orientation(
    GeneralTransform::Spacing(reader->GetOutput(), spacing),
    itk::SpatialOrientation::ITK_COORDINATE_ORIENTATION_RAI),
  roi_size, -1000);

// 归一化参数
double mean_intensity = 655.56;
double std_intensity = 736.38;
double p00_5 = -249.0;
double p99_5 = 2202.0;

// 归一化处理
itk::ImageRegionIterator<FloatImage> it(transformed, transformed->GetLargestPossibleRegion());
for (it.GoToBegin(); !it.IsAtEnd(); ++it) {
double v = it.Get();
  v = std::min(std::max(v, p00_5), p99_5);      // 裁剪
  v = (v - mean_intensity) / std::max(std_intensity, 1e-8);  // 标准化
  it.Set(static_cast<float>(v));
}

后处理:

  • • 阈值分割将模型输出概率转换为掩码。

  • • 连通域分析剔除小块噪声。

  • • 根据实际需求映射不同标签类别。


4. 常见问题及解决方案

问题 解决方案
模型加载失败 确认 libtorch 版本与训练 PyTorch 版本是否匹配
推理结果全零或异常 检查输入数据预处理是否严格一致,输入尺寸是否正确
GPU 推理时报错 确认 CUDA 和驱动版本匹配,使用支持 CUDA 的 libtorch
编译链接失败 确认 CMakeLists.txt 配置正确,库路径完整

https://mp.weixin.qq.com/s/9fjmydyFsnOwOAkuRoyGyw

相关推荐
asyxchenchong8886 小时前
OpenLCA、GREET、R语言的生命周期评价方法、模型构建
开发语言·r语言
没有梦想的咸鱼185-1037-16636 小时前
【生命周期评价(LCA)】基于OpenLCA、GREET、R语言的生命周期评价方法、模型构建
开发语言·数据分析·r语言
程序猿20237 小时前
Python每日一练---第三天:删除有序数组中的重复项
开发语言·python
一只游鱼7 小时前
Springboot+BannerBanner(启动横幅)
java·开发语言·数据库
一只游鱼7 小时前
抖音上的用python实现激励弹窗
开发语言·python
行走在电子领域的工匠7 小时前
2.2 常用控件
开发语言·python
散峰而望7 小时前
Dev-C++一些问题的处理
c语言·开发语言·数据库·c++·编辑器
进击的大海贼7 小时前
QT/C++ 消息定时管理器
开发语言·c++·qt
lly2024067 小时前
TypeScript 基础类型
开发语言