++直接在pytorch中,用torch.save保存的张量,可能因格式差异无法在C++中加载。++
以下是一个最简单的例子,展示如何在 Pytorch中保存张量到 TorchScript 模块,并在 C++ 中使用 LibTorch 加载。
Python 代码 (save_tensor.py)
import torch
# 定义一个简单的 TorchScript 模块来包装张量
class TensorWrapper(torch.jit.ScriptModule):
def __init__(self, tensor):
super().__init__()
self.tensor = torch.jit.Attribute(tensor, torch.Tensor)
# 创建一个张量
tensor = torch.randn(2, 3)
# 包装张量到模块
module = TensorWrapper(tensor)
# 保存模块到文件
torch.jit.save(module, "tensor.pt")
C++ 代码 (load_tensor.cpp)
#include <torch/script.h>
#include <iostream>
int main() {
// 加载 TorchScript 模块
torch::jit::script::Module module = torch::jit::load("tensor.pt");
// 获取张量(假设我们知道属性名为 tensor)
torch::Tensor tensor = module.attr("tensor").toTensor();
// 打印张量
std::cout << tensor << std::endl;
return 0;
}
这种方法可靠,因为 TorchScript 提供了跨语言的序列化支持,保证张量数据一致性。