基于PyTorch的深度学习——迁移学习2

现在将迁移学习的特征提取应用于CIFAR-10

python 复制代码
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime

加载数据

python 复制代码
import torch
import torchvision
from torchvision import transforms

# =============== 1. 定义数据变换 ===============
# 训练集:增强 + 归一化(适配 ImageNet 预训练模型)
train_transform = transforms.Compose([
    transforms.Resize(256),               # 将 32x32 放大到 256x256
    transforms.RandomCrop(224),           # 随机裁剪出 224x224 区域
    transforms.RandomHorizontalFlip(),    # 随机水平翻转(数据增强)
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],       # ImageNet 的均值
        std=[0.229, 0.224, 0.225]         # ImageNet 的标准差
    )
])

# 测试集:不增强,只做确定性变换
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),           # 中心裁剪(非随机)
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# =============== 2. 加载数据集 ===============
trainset = torchvision.datasets.CIFAR10(
    root="/data",
    train=True,
    download=False,      # 若未下载过,可设为 True(首次运行)
    transform=train_transform
)

testset = torchvision.datasets.CIFAR10(
    root="/data",
    train=False,
    download=False,
    transform=test_transform
)

# =============== 3. 创建 DataLoader ===============
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2
)

testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,       # 测试时通常不打乱
    num_workers=2
)

# =============== 4. 类别标签(可选)==============
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

接下来,下载预训练模型,冻结模型参数使得反向传播时不更新,修改最后一层输出类别(512x1000改成512x10)

python 复制代码
import torch
import torch.nn as nn
import torchvision.models as models

# =============== 1. 加载预训练的 ResNet18 模型 ===============
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# =============== 2. 冻结所有参数 ===============
for param in model.parameters():
    param.requires_grad = False

# =============== 3. 替换最后一层 ===============
# 原始的最后一层是 nn.Linear(512, 1000),现在我们将其改为 nn.Linear(512, 10)
num_ftrs = model.fc.in_features  # 获取原始最后一层输入特征的数量
model.fc = nn.Linear(num_ftrs, 10)  # 替换成新的全连接层,输出为 10 类别

# =============== 4. 确认只有新添加的层可训练 ===============
# 可选:打印模型中需要梯度计算的参数
for name, param in model.named_parameters():
    print(name, param.requires_grad)

# =============== 5. 创建损失函数和优化器 ===============
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# 只优化 model.fc 这一层的参数
# 注意:这里仅展示了如何定义损失函数和优化器,
# 实际训练过程还需要结合 DataLoader 进行迭代训练。
相关推荐
聆风吟º5 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子6 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能6 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5776 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
h64648564h7 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切7 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
学电子她就能回来吗9 小时前
深度学习速成:损失函数与反向传播
人工智能·深度学习·学习·计算机视觉·github
Coder_Boy_9 小时前
TensorFlow小白科普
人工智能·深度学习·tensorflow·neo4j
大模型玩家七七9 小时前
梯度累积真的省显存吗?它换走的是什么成本
java·javascript·数据库·人工智能·深度学习
kkzhang10 小时前
Concept Bottleneck Models-概念瓶颈模型用于可解释决策:进展、分类体系 与未来方向综述
深度学习