迁移学习|代码实现

还记得我们之前实现的猫狗分类器 吗?在哪里,我们设计了一个网络,这个网络接受一张图片,最后输出这张图片属于猫还是狗。实现分类器的过程比较复杂,准备的数据也比较少。所以我们是否可以使用一种方法,在数据很少的情况下仍然可以训练出较好的模型。

借助已经训练好的模型是个不错的想法。因此我们将学习如何使用预训练好的模型来构建只需要很少数据的先进的猫狗图像分类器。

首先,加载一个预训练的模型,例如ResNet18。

借助torchvision库,我们很容易获得一组已经训练好的模型。这些模型大多数接受一个称为pretrained的参数,当这个参数为True时,它会下载为ImageNet分类问题调整好的权重。就像这样:

复制代码
from torchvision import modelsnetwork1=models.resnet18(pretrained=True)

当代码第一次运行时,需要一点时间...

接着,我们需要冻结所有层,所有权重不会随训练而更新。​​​​​​​

复制代码
for param in network1.parameters():    param.requires_grad=False

当然,这个模型并不是针对2分类问题,所以,我们需要将其最后一层的输出特征从1000改为2

首先我们要知道最后一层的名字:​​​​​​​

复制代码
network1ResNet(  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)...  ...  ...  (fc): Linear(in_features=512, out_features=1000, bias=True)

最后一层是个全连接层,名为fc。

所以,我们就可将最后一层替换为输出特征为2的全连接层

复制代码
network1.fc=nn.Linear(512,2)

注:此时,因为该层为新的层,所以其requires_grad=True,这样整个网络仅有这一层可以更新权重

打印网络​​​​​​​

复制代码
network1ResNet(  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  (relu): ReLU(inplace=True)  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)...  ...  ...  (fc): Linear(in_features=512, out_features=2, bias=True)

此时,network1就是一个符合猫狗分类问题的模型

最后,既然我们只对最后一层训练,那么我们只需要将最后一层的参数传入优化器

复制代码
optimizer=optim.SGD(network1.fc.parameters(),lr=...,momentnum=...)

总结一下代码:​​​​​​​

复制代码
from torchvision import modelsimport torch.nn as nnimport torch.optim as optim#网络搭建network1=models.resnet18(pretrained=True)
for param in network1.parameters():    param.requires_grad=False
network1.fc=nn.Linear(512,2)#损失函数criterion=nn.CrossEntropyLoss()#优化器optimizer=optim.SGD(network1.fc.parameters(),lr=...,momentnum=...)

其实,我们就是利用已经训练好的模型的主要目的就是它已经能够提取出非常好的特征 ,最后一层接受前面层提取的特征,然后误差反向传播,仅更新这一层的权重,不断迭代,最后达到一个非常好的效果。

我们这里只对最后一层进行了调整 ,只训练这一层,主要原因就是数据太少 ;如果数据较多 ,可以把预训练的前面一些层权重固定住,后面层不固定,修改最后一层以满足任务,然后训练;如果数据很多,算力充沛,那么可以对所有层进行精调,只把预训练的模型的参数作为初始化参数。

相关推荐
2501_9411491124 分钟前
人工智能驱动下的边缘物联网革新,打造未来全球智能互联新格局
人工智能·物联网
没头脑的男大27 分钟前
Unet+Transformer脑肿瘤分割检测
人工智能·深度学习·transformer
AI即插即用33 分钟前
即插即用涨点系列(十四)2025 SOTA | Efficient ViM:基于“隐状态混合SSD”与“多阶段融合”的轻量级视觉 Mamba 新标杆
人工智能·pytorch·深度学习·计算机视觉·视觉检测·transformer
1***81531 小时前
免费的自然语言处理教程,NLP入门
人工智能·自然语言处理
算家计算1 小时前
Gemini 3.0重磅发布!技术全面突破:百万上下文、全模态推理与开发者生态重构
人工智能·资讯·gemini
说私域1 小时前
“开源链动2+1模式AI智能名片S2B2C商城小程序”赋能同城自媒体商家营销创新研究
人工智能·小程序·开源
m0_635129261 小时前
内外具身智能VLA模型深度解析
人工智能·机器学习
zhougoo1 小时前
AI驱动代码开之Vs Code Cline插件集成
人工智能
minhuan2 小时前
构建AI智能体:九十五、YOLO视觉大模型入门指南:从零开始掌握目标检测
人工智能·yolo·目标检测·计算机视觉·视觉大模型
双翌视觉2 小时前
机器视觉的车载显示器玻璃覆膜应用
人工智能·机器学习·计算机外设