在处理图像识别任务时,我们不需要总是从零开始搭建网络。PyTorch 的 torchvision.models 库提供了大量已经在 ImageNet 大规模数据集上训练好的经典模型。本文将教你如何获取这些模型,并根据自己的需求(如将 1000 分类改为 10 分类)进行灵活修改。
1. 现有网络模型的获取
官方模型库提供了两种获取方式,区别在于是否加载预训练好的参数:
- pretrained=False:只下载网络结构,参数是随机初始化的(默认方式)。
- pretrained=True:不仅下载结构,还下载已经在 ImageNet 上训练好的参数。这通常用于迁移学习,能极大地加快收敛速度。
代码实现:
2. 为什么要修改网络模型?
像 VGG16 这样的模型,其输出层通常是为 ImageNet 设计的(输出 1000 个类别)。但如果我们处理的是 CIFAR-10(只需输出 10 个类别),我们就需要对模型的结构进行微调。
3. 实战:修改模型的两种常用方法
文件展示了如何通过"添加层"或"替换层"来适配 10 分类任务:
方法一:在现有结构后添加层 (add_module)
我们可以保持原有的 classifier 结构不变,在其最后追加一个新的线性层。
import torchvision
from torch import nn
dataset = torchvision.datasets.CIFAR10("./dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)
vgg16_true = torchvision.models.vgg16(pretrained=True) # 下载卷积层对应的参数是多少、池化层对应的参数时多少,这些参数时ImageNet训练好了的
vgg16_true.add_module('add_linear',nn.Linear(1000,10)) # 在VGG16后面添加一个线性层,使得输出为适应CIFAR10的输出,CIFAR10需要输出10个种类
print(vgg16_true)
方法二:直接修改/替换现有层
如果你觉得多加一层太麻烦,可以直接修改 classifier 中的最后一个子模块。
import torchvision
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False) # 没有预训练的参数
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)
4. 迁移学习的意义
通过修改现有模型,我们实际上是在进行迁移学习(Transfer Learning):
- 特征提取器:利用 VGG 前半部分强大的特征提取能力(这部分在 ImageNet 上学到了识别线条、形状、纹理的通用能力)。
- 自定义分类器:只针对我们特定的数据集训练最后的几层全连接层。
这种方法在数据集较小时效果尤为显著,能避免过拟合且大幅节省算力。
5. 总结
分析该文件后,我们可以掌握以下技巧:
- 加载模型 :利用
torchvision.models快速调用经典结构。 - 查看结构 :直接
print(model)找到需要修改的层名称或索引。 - 动态修改 :使用
add_module或直接索引赋值来改变网络层级,使其适配你的任务。
💡 学习小结
学会修改官方模型是迈向中高级开发者的重要一步。你不再受限于简单的 3 层卷积,而是可以自由调用 ResNet、VGG、MobileNet 等工业级模型来解决复杂的视觉问题。