基于PyTorch的深度学习——迁移学习1

声明一下内容来源于基于PyTorch的深度学习

迁移学习是一种机器学习方法,简单来说,就是把任务A开发的模型作为初始点,重新使用在任务B中,比如,A任务可以是识别图像中车辆,而B任务可以是识别卡车、识别轿车、识别公交车等

在神经网络迁移学习中,主要有两个应用场景:特征提取和微调。

• 特征提取(Feature Extraction):冻结除最终完全连接层之外的所有网络的权重。最后一个全连接层被替换为具有随机权重的新层,并且仅训练该层。

• 微调(Fine Tuning):使用预训练网络初始化网络,而不是随机初始化。用新数据训练部分或整个网络。

先来讲特征提取部分

python 复制代码
import torchvision.models as models

resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
  • 模型结构被创建 ,但所有参数(权重)是随机初始化的。
  • 相当于从零开始训练一个新模型(scratch training)。
  • 适用于:
    • 你不打算用迁移学习;
    • 或你有大量自有数据,想完全重新训练;
    • 或只是做模型结构测试。

如果要获取预训练模型应该这么做

python 复制代码
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
  • 不仅创建了模型结构,还自动下载并加载了在 ImageNet 数据集上训练好的权重
  • 这些权重是官方提供的、经过充分训练的高质量参数。
  • 模型开箱即用 ,可直接用于:
    • 图像分类推理(对自然图像效果很好);
    • 迁移学习(如特征提取或微调)。

✅ 这是迁移学习的标准起点

注意,有的模型训练和测试阶段用到了不同的模块,比如说batch normalization, dropout层等。使用model.train()或model.eval()可以切换到相应的模式。

所有的预训练模型都要求输入图片以相同的方式进行标准化,即:小批(Mini-Batch)3通道RGB格式(3×H×W),其中H和W应小于224。图片加载时像素值的范围应在0,1内,然后通过指定mean=0.485,0.456,0.406和std=0.229,0.224,0.225进行标准化,例如:

python 复制代码
normalize = transforms.Normalize(mean=[0.485,0.456,0.406]
                                 std=[0.229,0.224,0.225])

如何冻结某些层?

如果需要冻结最后一层之外的所有层,设置requires_grad==False, 反向传播中不计算梯度了。

python 复制代码
model = torchvision.models.resnet18(pretrained=True)
# 冻结所有参数
for param in resnet18.parameters():
    param.requires_grad = False

# 替换最后的分类层(适配自己的任务,比如 10 类)
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 10)
相关推荐
大江东去浪淘尽千古风流人物1 小时前
【HaMeR】全Transformer架构的单目3D手部网格重建:ViT-H骨干+跨注意力MANO解码器源码深度解析
深度学习·3d·transformer·vit·手部重建·mano
MRDONG11 小时前
从机器学习到大语言模型:一文讲清 AI、Transformer、Embedding 和向量数据库
人工智能·机器学习·语言模型
钓了猫的鱼儿1 小时前
基于深度学习+AI的红外电力设备故障目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·目标检测
LaughingZhu1 小时前
Product Hunt 每日热榜 | 2026-05-30
人工智能·经验分享·深度学习·神经网络·产品运营
蒟蒻的贤2 小时前
深度学习底层核心原理:损失函数、梯度与参数更新
人工智能·深度学习
谷哥的小弟2 小时前
大模型核心基础知识(14)—神经网络的结构
人工智能·深度学习·神经网络·大模型·大语言模型
城事漫游Molly2 小时前
AI与质性研究的融合(三):AI赋能质性数据分析——从编码到理论构建的新范式
大数据·人工智能·机器学习·prompt·ai for science·智能体·定性研究
大模型最新论文速读2 小时前
SkillOpt:把 skill 文档当成模型权重来训练
论文阅读·人工智能·深度学习·机器学习·自然语言处理
Omics Pro3 小时前
基因泰克:检测级虚拟细胞基准!大语言模型+智能体
大数据·数据库·人工智能·机器学习·语言模型·自然语言处理·r语言
z小猫不吃鱼3 小时前
15 InstructGPT 论文精读:SFT + RLHF 如何让模型听懂指令?
人工智能·深度学习·算法·机器学习·语言模型·自然语言处理·gpt-3