一、Caffe模型转换源码下载
GitHub 地址：https://github.com/xxradon/PytorchToCaffe
从 GitHub 拉取 PytorchToCaffe 的源码，将其中的 Caffe 文件夹和 pytorch_to_caffe.py 文件复制到项目根目录。
二、将VGG的PyTorch模型转为Caffe模型
转换脚本如下（Python，需在项目根目录下运行，以便导入 pytorch_to_caffe）：
import sys
sys.path.insert(0,'.')
import torch
from torch.autograd import Variable
from torchvision.models.vgg import vgg11_bn
import pytorch_to_caffe
def main(name='vgg11_bn'):
    """Convert torchvision's pretrained VGG11-BN model to Caffe.

    Writes ``<name>.prototxt`` (network definition) and
    ``<name>.caffemodel`` (weights) to the current working directory.

    Args:
        name: Network name recorded in the Caffe model and used as the
            base name of the two output files.
    """
    # Load the pretrained weights (same as the original positional ``True``).
    # NOTE: torchvision >= 0.13 prefers ``weights=`` and warns about
    # ``pretrained=``, but still accepts it.
    net = vgg11_bn(pretrained=True)
    # Export the inference graph: in train mode BatchNorm would normalize with
    # batch statistics and update its running stats during the tracing
    # forward pass, and Dropout would be active, so the saved Caffe model
    # would not match the pretrained network.
    net.eval()

    # Example input that fixes the exported input shape: one 3-channel
    # 224x224 image. A plain tensor is enough; torch.autograd.Variable has
    # been a no-op wrapper since PyTorch 0.4.
    dummy_input = torch.ones([1, 3, 224, 224])

    # presumably trans_net records the Caffe layers while running `net` on
    # `dummy_input` (see PytorchToCaffe) — the save_* calls then dump them.
    pytorch_to_caffe.trans_net(net, dummy_input, name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))


if __name__ == '__main__':
    main()