1. Download the pretrained ResNet-18 model and save it to the 'resnet18.pth' file:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
def main():
    """Download the pretrained ResNet-18, smoke-test it, and save it to disk.

    The whole module object (architecture and weights) is serialized with
    ``torch.save`` to ``resnet18.pth`` in the current working directory.
    """
    print('cuda device count: ', torch.cuda.device_count())
    # Fall back to the CPU so the script does not crash on hosts without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = torchvision.models.resnet18(pretrained=True)
    #net.fc = nn.Linear(512, 2)
    net = net.to(device)
    net.eval()
    print(net)
    tmp = torch.ones(2, 3, 224, 224, device=device)
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        out = net(tmp)
    print('resnet18 out:', out.shape)
    torch.save(net, "resnet18.pth")


if __name__ == '__main__':
    main()
This 'resnet18.pth' file contains both the model structure and the weights.
2. Load the .pth file and convert it to ONNX format:
import torch
def main():
    """Load the pickled model from ``resnet18.pth`` and export it to ONNX.

    A random 1x3x224x224 tensor is used as the example input, and the result
    is written to ``resnet18_trtpose.onnx`` in the current working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU still loads on a CPU-only host and the model ends up on
    # the same device as the example input.
    # NOTE(review): torch.load unpickles arbitrary objects; only use it on
    # checkpoint files you created yourself.
    model = torch.load('resnet18.pth', map_location=device)
    # model.eval()
    inputs = torch.randn(1, 3, 224, 224).to(device)
    # NOTE(review): training=2 presumably selects
    # torch.onnx.TrainingMode.TRAINING -- confirm this is intended for the
    # installed PyTorch version.
    torch.onnx.export(model, inputs, 'resnet18_trtpose.onnx', training=2)


if __name__ == '__main__':
    main()
3. Load the .pth file and extract the model's weights to a .wts file:
import torch
from torch import nn
import torchvision
import os
import struct
from torchsummary import summary
def main():
    """Dump every entry of the saved model's state dict to ``resnet18.wts``.

    The model is loaded from ``resnet18.pth``, moved to ``cuda:0``, run once
    on an all-ones input, and summarized before the weights are written.

    Output file layout (text):
      * first line: number of entries in the state dict
      * one line per entry: ``<name> <count> <hex> <hex> ...`` where each
        value is packed as a big-endian 32-bit float and hex-encoded.
    """
    print('cuda device count: ', torch.cuda.device_count())
    net = torch.load('resnet18.pth')
    net = net.to('cuda:0')
    net.eval()
    print('model: ', net)
    #print('state dict: ', net.state_dict().keys())
    tmp = torch.ones(1, 3, 224, 224).to('cuda:0')
    print('input: ', tmp)
    out = net(tmp)
    print('output:', out)
    summary(net, (3,224,224))
    #return
    # Build the state dict once instead of once for the count and once for
    # the iteration.
    state = net.state_dict()
    # The context manager guarantees the file is flushed and closed, even if
    # serializing a tensor raises part-way through.
    with open("resnet18.wts", 'w') as f:
        f.write("{}\n".format(len(state)))
        for k, v in state.items():
            print('key: ', k)
            print('value: ', v.shape)
            vr = v.reshape(-1).cpu().numpy()
            f.write("{} {}".format(k, len(vr)))
            # One write per tensor instead of two writes per scalar; the
            # bytes produced are identical.
            f.write("".join(
                " " + struct.pack(">f", float(vv)).hex() for vv in vr
            ))
            f.write("\n")


if __name__ == '__main__':
    main()