Pytorch新版本如何使用现有网络模型,怎么修改现有模型

来自 千问

在 PyTorch 的新版本(特别是 torchvision 0.13+)中,使用和修改现有网络模型的标准方法发生了一些重要变化。主要的变化在于加载预训练权重 的方式,以及如何灵活地修改模型结构以适应你的任务(如迁移学习)。

以下是基于最新版本的详细指南:

1. 如何使用现有网络模型

在新版本中,pretrained 参数已被弃用,推荐使用 weights 参数来明确指定要加载的权重类型。

语法:

python 复制代码
import torchvision.models as models

# 1. 加载预训练模型 (推荐写法)
model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# 2. 或者,如果你不想加载预训练权重 (从头开始训练)
model = models.vgg16(weights=None)

# 3. (旧写法,已弃用) 
# model = models.vgg16(pretrained=True) # 不推荐,未来版本会报错
  • weights=models.VGG16_Weights.DEFAULT: 加载在 ImageNet 上预训练的权重。这是目前最标准的做法。
  • weights=None: 不加载预训练权重,模型参数会被随机初始化。

2. 如何修改现有模型

加载模型后,通常需要根据你的数据集(如 CIFAR-10 只有 10 个类别,而 ImageNet 有 1000 个)来修改模型结构。主要有两种常用方法:

方法一:替换层(最常用)

直接替换模型的最后一层(通常是全连接层),以改变输出维度。

示例:修改 VGG16 的输出为 10 类

python 复制代码
import torchvision.models as models
import torch.nn as nn

# 加载模型
model = models.vgg16(weights=None)

# 查看原模型结构,你会发现最后一层是 classifier[6]
# Linear(in_features=4096, out_features=1000, bias=True)

# 修改:将最后一层替换为输出 10 维的层
model.classifier[6] = nn.Linear(4096, 10)

print(model.classifier[6]) # 检查是否修改成功

示例:修改 ResNet 的输出为 10 类

python 复制代码
model = models.resnet18(weights=None)

# ResNet 的最后一层通常命名为 'fc'
# 获取原全连接层的输入特征数
num_ftrs = model.fc.in_features 

# 替换为新的全连接层
model.fc = nn.Linear(num_ftrs, 10)
方法二:添加新层

如果你不想直接替换原有层,而是想在网络末尾追加新的层(例如添加一个特定的分类头)。

python 复制代码
# 加载模型
model = models.vgg16(weights=None)

# 在 classifier 序列的末尾添加一个新层
# 参数:(名称, 层结构)
model.classifier.add_module('new_layer', nn.Linear(1000, 10))

# 或者更灵活的方式是直接重新定义 classifier
# 例如:在原有基础上加一层 Dropout 和新的 Linear
model.classifier = nn.Sequential(
    *list(model.classifier.children())[:-1], # 保留除了最后一层外的所有层
    nn.Dropout(0.5), # 可选:添加正则化
    nn.Linear(4096, 10) # 新的输出层
)

3. 完整的训练与保存流程示例

结合上述内容,这是一个从加载、修改到保存的标准流程:

python 复制代码
import torch
import torch.nn as nn
import torchvision.models as models

# 1. 加载预训练模型
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# 2. 修改模型以适应新任务 (例如 10 分类)
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)

# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --- 训练循环 (伪代码) ---
# for inputs, labels in dataloader:
#     optimizer.zero_grad()
#     outputs = model(inputs)
#     loss = criterion(outputs, labels)
#     loss.backward()
#     optimizer.step()

# 4. 保存模型 (推荐方式:仅保存参数)
# 训练完成后,保存模型参数
torch.save(model.state_dict(), 'my_model_weights.pth')

# 5. 加载模型 (推理时)
# 必须先重建完全相同的结构
loaded_model = models.resnet18(weights=None) # 注意:推理时通常不需要预训练权重,因为我们有自己的
loaded_model.fc = nn.Linear(loaded_model.fc.in_features, num_classes)
loaded_model.load_state_dict(torch.load('my_model_weights.pth'))
loaded_model.eval() # 切换到评估模式

关键变化总结

功能 旧写法 (已弃用) 新写法 (推荐)
加载预训练 models.vgg16(pretrained=True) models.vgg16(weights=models.VGG16_Weights.DEFAULT)
不加载权重 models.vgg16(pretrained=False) models.vgg16(weights=None)
访问权重 隐式 显式通过 Weights 枚举类访问

建议: 在新项目中,请务必使用 weights 参数,以确保代码的兼容性和可维护性。

相关推荐
行走__Wz11 小时前
【刘二大人】《PyTorch深度学习实践》——PyTorch实现线性回归代码(自用)
pytorch·深度学习·线性回归
一碗白开水一11 小时前
【工具相关】OpenClaw 配置使用飞书:打造智能飞书助手全流程指南(亲测有效,放心享用)
人工智能·深度学习·算法·飞书
小程故事多_8011 小时前
Vibe Coding的致命隐患,你必须知道的技术债务和扩展性危机
大数据·人工智能·aigc
童话名剑11 小时前
YOLO v3(学习笔记)
人工智能·深度学习·yolo·目标检测
康康的AI博客11 小时前
农业工业变革:如何通过DMXAPI中转提升自动化效率
运维·人工智能·自动化
实在智能RPA11 小时前
从API集成到意图驱动:深度解析实在Agent在复杂ERP/OA环境下的非标接口处理架构
人工智能·ai·架构
北京耐用通信11 小时前
协议融合的工业钥匙:耐达讯自动化网关如何打通CC-Link IE转DeviceNet的通信壁垒
人工智能·物联网·网络协议·自动化·信息与通信
EasyGBS11 小时前
GB35114+GB28181:EasyGBS视频融合平台如何构建视频监控 “联网+安全” 双重保障体系
网络·人工智能·国标gb28181·gb35114
只说证事12 小时前
中专计算机专业必考的证书清单有哪些?
人工智能
臭东西的学习笔记12 小时前
论文学习——通过蛋白质片段-环境比对实现自我监督口袋预训练
人工智能