113_站在巨人的肩膀上:PyTorch 经典模型(VGG16)的获取与自定义修改

在处理图像识别任务时,我们不需要总是从零开始搭建网络。PyTorch 的 torchvision.models 库提供了大量已经在 ImageNet 大规模数据集上训练好的经典模型。本文将教你如何获取这些模型,并根据自己的需求(如将 1000 分类改为 10 分类)进行灵活修改。

1. 现有网络模型的获取

官方模型库提供了两种获取方式,区别在于是否加载预训练好的参数:

  • pretrained=False:只下载网络结构,参数是随机初始化的(默认方式)。
  • pretrained=True:不仅下载结构,还下载已经在 ImageNet 上训练好的参数。这通常用于迁移学习,能极大地加快收敛速度。

代码实现:


2. 为什么要修改网络模型?

像 VGG16 这样的模型,其输出层通常是为 ImageNet 设计的(输出 1000 个类别)。但如果我们处理的是 CIFAR-10(只需输出 10 个类别),我们就需要对模型的结构进行微调。


3. 实战:修改模型的两种常用方法

文件展示了如何通过"添加层"或"替换层"来适配 10 分类任务:

方法一:在现有结构后添加层 (add_module)

我们可以保持原有的 classifier 结构不变,在其最后追加一个新的线性层。

复制代码
import torchvision
from torch import nn

dataset = torchvision.datasets.CIFAR10("./dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)       
vgg16_true = torchvision.models.vgg16(pretrained=True) # 下载卷积层对应的参数是多少、池化层对应的参数时多少,这些参数时ImageNet训练好了的
vgg16_true.add_module('add_linear',nn.Linear(1000,10)) # 在VGG16后面添加一个线性层,使得输出为适应CIFAR10的输出,CIFAR10需要输出10个种类

print(vgg16_true)

方法二:直接修改/替换现有层

如果你觉得多加一层太麻烦,可以直接修改 classifier 中的最后一个子模块。

复制代码
import torchvision
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False) # 没有预训练的参数     
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)

4. 迁移学习的意义

通过修改现有模型,我们实际上是在进行迁移学习(Transfer Learning)

  1. 特征提取器:利用 VGG 前半部分强大的特征提取能力(这部分在 ImageNet 上学到了识别线条、形状、纹理的通用能力)。
  2. 自定义分类器:只针对我们特定的数据集训练最后的几层全连接层。

这种方法在数据集较小时效果尤为显著,能避免过拟合且大幅节省算力。


5. 总结

分析该文件后,我们可以掌握以下技巧:

  1. 加载模型 :利用 torchvision.models 快速调用经典结构。
  2. 查看结构 :直接 print(model) 找到需要修改的层名称或索引。
  3. 动态修改 :使用 add_module 或直接索引赋值来改变网络层级,使其适配你的任务。

💡 学习小结

学会修改官方模型是迈向中高级开发者的重要一步。你不再受限于简单的 3 层卷积,而是可以自由调用 ResNet、VGG、MobileNet 等工业级模型来解决复杂的视觉问题。

相关推荐
SilentSamsara几秒前
命令行工具开发:Click/Typer + 打包为独立二进制
linux·服务器·开发语言·前端·python·青少年编程·fastapi
Ulyanov2 分钟前
深入QML滑块与进度控制:构建动态数据可视化界面:QML+PySide6现代开发入门(六)
开发语言·python·算法·ui·信息可视化·雷达电子对抗仿真
扫地僧9853 分钟前
一个基于 PyTorch 手语翻译模型Xuanmen_Net
人工智能·pytorch·python
zyl837213 分钟前
Python 函数、模块、异常处理 超详细入门教程
开发语言·windows·python
搬砖的小码农_Sky3 分钟前
Windows环境下OpenClaw本地部署完整指南
人工智能·windows·ai·人机交互·agi
风舞雪凌月7 分钟前
【总结】国产AI大模型公司汇总
人工智能
Hali_Botebie8 分钟前
【光流】自动驾驶光流任务 DeFlow: Decoder of Scene Flow Network in Autonomous Driving
人工智能·机器学习·自动驾驶
IT_陈寒11 分钟前
被Vite的HMR坑惨了,原来这样配置才能用对!
前端·人工智能·后端
“码”力全开14 分钟前
解耦安防碎片化:基于 Docker 与边缘计算的 AI 视频中台架构设计(支持 GB28181/RTSP 与源码交付)
人工智能·docker·边缘计算
sali-tec15 分钟前
C# 基于OpenCv的视觉工作流-章80-长短脚
图像处理·人工智能·opencv·算法·计算机视觉