【Pytorch 第一讲】 如何加载预训练模型

一. 封装Pytorch的Model 加载pre-trianed Model

python 复制代码
import torch
import torchvision.models as models
from torchvision import transforms

# 1. 下载并加载预训练模型
model = models.resnet18(pretrained=False)  # 设置pretrained=False,表示不加载预训练权重

# 2. 下载预训练权重文件并加载
pretrained_dict = torch.load("path/to/resnet18-5c106cde.pth")  # 替换为实际的权重文件路径
model.load_state_dict(pretrained_dict)

# 3. 将模型设置为评估模式
model.eval()

# 4. 示例:将模型应用于输入数据
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_data = transform(Image.open("path/to/your/image.jpg")).unsqueeze(0)  # 替换为实际的图像路径
output = model(input_data)
print(output)

二. 自定义Pytorch模型加载 Pre-trained Model

python 复制代码
# 导入 PyTorch
import torch

# 初始化你的模型
model = faster_vit_0_224()
python 复制代码
# 加载预训练权重
checkpoint = torch.load('/home/loads/vit_0_224_1k.pth.tar')

# 或者

checkpoint = torch.load('/home/loads/vit_0_224_1k.pth')

当完成这个加载以后, 可以考虑打开 checkpoint , 看看该模型保存时,包含哪些dict keys.

python 复制代码
print(checkpoint.keys())

结果: dict_keys(['epoch', 'arch', 'state_dict', 

'optimizer', 'version', 'args', 'amp_scaler', 'metric'])

根据上面的 state_dict, 可以接下来用来将权重赋予模型Model

python 复制代码
# 将权重赋值给模型

model.load_state_dict(checkpoint['state_dict'])

注意:如果 在checkpoint 的dict_keys 中不是"state_dict", 是"model_state_dict", 则需要把checkpoint["state_dict"] 改成 checkpoinbt["model_state_dict"] 读取所下载的的模型的权重,并将其赋予给模型。

相关推荐
一点媛艺16 分钟前
Kotlin函数由易到难
开发语言·python·kotlin
qzhqbb1 小时前
基于统计方法的语言模型
人工智能·语言模型·easyui
冷眼看人间恩怨1 小时前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发
2401_883041081 小时前
新锐品牌电商代运营公司都有哪些?
大数据·人工智能
魔道不误砍柴功1 小时前
Java 中如何巧妙应用 Function 让方法复用性更强
java·开发语言·python
_.Switch2 小时前
高级Python自动化运维:容器安全与网络策略的深度解析
运维·网络·python·安全·自动化·devops
AI极客菌2 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭2 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^2 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
测开小菜鸟3 小时前
使用python向钉钉群聊发送消息
java·python·钉钉