pytorch从python转 c++涉及到的数据保存加载问题;libtorch

pytorch 从python 转 c++涉及到的数据保存加载问题

1. torch.nn.Module 保存 state_dict 无法被 c++访问,只能转化为 python字典

python代码

python 复制代码
model = ThreeLayer_FCNN_Net()
model.load_state_dict(ret_load)
w = {k: v for k, v in model.state_dict().items()}
torch.save(w, "data.pkl")

c++代码

cpp 复制代码
std::vector<char> get_the_bytes(std::string filename) {
    std::ifstream input(filename, std::ios::binary);
    std::vector<char> bytes(
        (std::istreambuf_iterator<char>(input)),
        (std::istreambuf_iterator<char>()));
    input.close();
    return bytes;
}
cpp 复制代码
std::vector<char> f = get_the_bytes("../48_channal_guozi_2.pkl");
torch::IValue x = torch::pickle_load(f);
auto dict = x.toGenericDict();
std::cout << "dict: " << dict << std::endl;

2. 保存 Tensor张量,C++加载 Tensor张量

python代码

python 复制代码
data = torch.ones(3, 4)
torch.save(data, "data.pkl")

c++代码

cpp 复制代码
std::vector<char> f = get_the_bytes("../48_channal_guozi_2.pkl");
torch::IValue x = torch::pickle_load(f);
auto my_tensor = x.toTensor();
std::cout << "my_tensor: " << my_tensor << std::endl;
相关推荐
wljy17 分钟前
二、进制状态转换
linux·运维·服务器·c语言·c++
云泽80827 分钟前
笔试算法 -位运算篇(二):从唯一字符到消失数字
c++·算法·位运算
charlee4428 分钟前
《GIS基础原理与技术实践》配套案例(Python版)
python·conda·numpy·gis·环境配置
枫叶林FYL28 分钟前
项目十:事件溯源仓储管理系统(WMS)仿真实现
开发语言·python
繁华落尽,倾城殇?44 分钟前
[C++11] : atomic,nullptr,default/delete,enum class
开发语言·c++·c++11·nullptr·atomic·enum class·default/delete
代码村新手1 小时前
C++-二叉搜索树
开发语言·c++
智者知已应修善业2 小时前
【51单片机8位数码管动态显示日期小数点风格】2023-11-13
c++·经验分享·笔记·算法·51单片机
智者知已应修善业2 小时前
【51单片机有三个LED 分别第一个灯闪三下 再到第二个灯又闪三下 再到第三个灯又闪三下 就这样循环程序】2023-11-16
c++·经验分享·笔记·算法·51单片机
渣渣xiong3 小时前
从零开始:前端转型AI agent直到就业第五十七天-第五十八天
前端·人工智能·python
小L~~~3 小时前
基于贪心策略的混合遗传算法求解01背包问题
python·算法