【Pytorch】模型权重保存与上传

1.模型权重保存 torch.save

python 复制代码
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
	net = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
	
torch.save(net.state_dict(), weights_dir + '/' + model_name + '_train_loss_min_numCls{}.pth'.format(num_classes))

2.模型权重上传 load_state_dict

python 复制代码
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
    model = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
    
model.load_state_dict(torch.load(model_path), strict=False)
相关推荐
逄逄不是胖胖41 分钟前
《动手学深度学习》-69预训练bert数据集实现
人工智能·深度学习·bert
CoovallyAIHub1 小时前
2.5GB 塞进浏览器:Mistral 开源实时语音识别,延迟不到半秒
深度学习·算法·计算机视觉
mygugu1 小时前
详细分析swanlab集成mmengine底层实现机制--源码分析
python·深度学习·可视化
Hello.Reader1 小时前
词语没有位置感?用“音乐节拍“给 Transformer 装上时钟——Positional Encoding 图解
人工智能·深度学习·transformer
Rorsion1 小时前
CNN经典神经网络架构
人工智能·深度学习·cnn
Neptune12 小时前
大模型入门:从 TOKEN 到 Agent,搞懂 AI 的底层逻辑(上)
人工智能·深度学习
scott1985122 小时前
扩散模型之(十六)像素空间生成模型
人工智能·深度学习·计算机视觉·生成式
no_work2 小时前
yolo摄像头下的目标检测识别集合
人工智能·深度学习·yolo·目标检测·计算机视觉
咚咚王者2 小时前
人工智能之语言领域 自然语言处理 第十九章 深度学习框架
人工智能·深度学习·自然语言处理
OpenBayes贝式计算2 小时前
教程上新丨基于 GPU 部署 OpenClaw,轻松接入飞书/Discord 等社交软件
人工智能·深度学习·机器学习