resnet18下载与保存,转换为ONNX模型,导出 .wts 格式的权重文件

1.download and save to 'resnet18.pth' file:

复制代码
import torch
from torch import nn
from torch.nn import functional as F
import torchvision

def main():
    print('cuda device count: ', torch.cuda.device_count())
    net = torchvision.models.resnet18(pretrained=True)
    #net.fc = nn.Linear(512, 2)
    net = net.to('cuda:0')
    net.eval()
    print(net)
    tmp = torch.ones(2, 3, 224, 224).to('cuda:0')
    out = net(tmp)
    print('resnet18 out:', out.shape)
    torch.save(net, "resnet18.pth")

if __name__ == '__main__':
    main()

this 'resnet18.pth' file contains the model structure and weights.

2.load the .pth file and transform it to ONNX format:

复制代码
import torch

def main():
    
    model = torch.load('resnet18.pth')
    # model.eval()
    inputs = torch.randn(1,3,224,224)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = inputs.to(device)
    torch.onnx.export(model,inputs, 'resnet18_trtpose.onnx',training=2)
    
if __name__ == '__main__':
    main()

3.load and read the .pth file, extract the weights of the model to a .wts file

复制代码
import torch
from torch import nn
import torchvision
import os
import struct
from torchsummary import summary

def main():
    print('cuda device count: ', torch.cuda.device_count())
    net = torch.load('resnet18.pth')
    net = net.to('cuda:0')
    net.eval()
    print('model: ', net)
    #print('state dict: ', net.state_dict().keys())
    tmp = torch.ones(1, 3, 224, 224).to('cuda:0')
    print('input: ', tmp)
    out = net(tmp)
    print('output:', out)

    summary(net, (3,224,224))
    #return
    f = open("resnet18.wts", 'w')
    f.write("{}\n".format(len(net.state_dict().keys())))
    for k,v in net.state_dict().items():
        print('key: ', k)
        print('value: ', v.shape)
        vr = v.reshape(-1).cpu().numpy()
        f.write("{} {}".format(k, len(vr)))
        for vv in vr:
            f.write(" ")
            f.write(struct.pack(">f", float(vv)).hex())
        f.write("\n")

if __name__ == '__main__':
    main()
相关推荐
无风听海6 分钟前
神经网络之交叉熵与 Softmax 的梯度计算
人工智能·深度学习·神经网络
java1234_小锋9 分钟前
TensorFlow2 Python深度学习 - TensorFlow2框架入门 - 神经网络基础原理
python·深度学习·tensorflow·tensorflow2
JJJJ_iii10 分钟前
【深度学习03】神经网络基本骨架、卷积、池化、非线性激活、线性层、搭建网络
网络·人工智能·pytorch·笔记·python·深度学习·神经网络
玉石观沧海15 分钟前
高压变频器故障代码解析F67 F68
运维·经验分享·笔记·分布式·深度学习
JJJJ_iii19 分钟前
【深度学习05】PyTorch:完整的模型训练套路
人工智能·pytorch·python·深度学习
DP+GISer40 分钟前
自己制作遥感深度学习数据集进行遥感深度学习地物分类-试读
人工智能·深度学习·分类
paid槮1 小时前
《深度学习》【项目】自然语言处理——情感分析 <上>
深度学习·自然语言处理·easyui
程序员小远1 小时前
常用的测试用例
自动化测试·软件测试·python·功能测试·测试工具·职场和发展·测试用例
IT学长编程1 小时前
计算机毕业设计 基于EChants的海洋气象数据可视化平台设计与实现 Python 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试】
大数据·hadoop·python·毕业设计·课程设计·毕业论文·海洋气象数据可视化平台
辣椒http_出海辣椒1 小时前
Python 数据抓取实战:从基础到反爬策略的完整指南
python