【Pytorch】模型权重保存与上传

1.模型权重保存 torch.save

python 复制代码
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
	net = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
	
torch.save(net.state_dict(), weights_dir + '/' + model_name + '_train_loss_min_numCls{}.pth'.format(num_classes))

2.模型权重上传 load_state_dict

python 复制代码
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
    model = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
    
model.load_state_dict(torch.load(model_path), strict=False)
相关推荐
lqqjuly2 分钟前
推荐系统技术解析(Recommendation Systems)
深度学习·推荐算法
老鱼说AI22 分钟前
统计学习方法第八章:Boosting
人工智能·深度学习·神经网络·机器学习·学习方法·集成学习·boosting
钓了猫的鱼儿24 分钟前
基于深度学习+AI的无人机森林火灾目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·无人机
无负今日_tq1 小时前
【无标题】
人工智能·深度学习·条纹
爱吃肉的鹏1 小时前
基于深度学习的电缆异常检测
人工智能·深度学习
钓了猫的鱼儿1 小时前
基于深度学习+AI的茶叶病害目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·目标检测
云和数据.ChenGuang1 小时前
深度学习在鲲鹏HPC下的学习
人工智能·深度学习·学习·机器学习·数据挖掘
YOLO数据集集合1 小时前
无人机航拍+深度学习落地智慧农业:作物出苗率目标检测开源数据集工程详解|YOLO作物计数、田间苗期AI监测、农情数字化训练资源
人工智能·深度学习·yolo·目标检测·计算机视觉·无人机
烬羽2 小时前
从零搭建AIGC应用:英伟达NIM + Node.js实战
深度学习
lqqjuly2 小时前
神经网络架构设计解析(Neural Network Architecture Design)
人工智能·深度学习·神经网络