一. 封装Pytorch的Model 加载pre-trianed Model
python
import torch
import torchvision.models as models
from torchvision import transforms
# 1. 下载并加载预训练模型
model = models.resnet18(pretrained=False) # 设置pretrained=False,表示不加载预训练权重
# 2. 下载预训练权重文件并加载
pretrained_dict = torch.load("path/to/resnet18-5c106cde.pth") # 替换为实际的权重文件路径
model.load_state_dict(pretrained_dict)
# 3. 将模型设置为评估模式
model.eval()
# 4. 示例:将模型应用于输入数据
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_data = transform(Image.open("path/to/your/image.jpg")).unsqueeze(0) # 替换为实际的图像路径
output = model(input_data)
print(output)
二. 自定义Pytorch模型加载 Pre-trained Model
python
# 导入 PyTorch
import torch
# 初始化你的模型
model = faster_vit_0_224()
python
# 加载预训练权重
checkpoint = torch.load('/home/loads/vit_0_224_1k.pth.tar')
# 或者
checkpoint = torch.load('/home/loads/vit_0_224_1k.pth')
当完成这个加载以后, 可以考虑打开 checkpoint , 看看该模型保存时,包含哪些dict keys.
python
print(checkpoint.keys())
结果: dict_keys(['epoch', 'arch', 'state_dict',
'optimizer', 'version', 'args', 'amp_scaler', 'metric'])
根据上面的 state_dict, 可以接下来用来将权重赋予模型Model
python
# 将权重赋值给模型
model.load_state_dict(checkpoint['state_dict'])
注意:如果 在checkpoint 的dict_keys 中不是"state_dict", 是"model_state_dict", 则需要把checkpoint["state_dict"] 改成 checkpoinbt["model_state_dict"] 读取所下载的的模型的权重,并将其赋予给模型。