基于PyTorch的深度学习——迁移学习4

微调 = 在预训练模型的基础上,继续训练(更新)部分或全部原有参数 + 新加的层,以适应新任务。此外预先训练的网络参数也会被更新,但会使用较小的学习率以防止预先训练好的参数发生较大的改变。

python 复制代码
# 使用预训练模型
net = models.resnet18(pretrained=True)
print(net)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, ...)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, ...)
  (layer1): Sequential(...)   # 2 个 BasicBlock
  (layer2): Sequential(...)   # 2 个 BasicBlock
  (layer3): Sequential(...)   # 2 个 BasicBlock
  (layer4): Sequential(...)   # 2 个 BasicBlock
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

net.fc = torch.nn.Linear(512, 10)  # 改为 10 类

常用的方法是固定底层的参数,调整一些顶层或具体层的参数。这样做的好处是可以减少训练参数的数量,同时也有助于克服过拟合现象的发生。

python 复制代码
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# ==============================
# 1. 加载预训练模型(推荐新写法)
# ==============================
net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# ==============================
# 2. 修改最后的分类层(CIFAR-10 有 10 类)
# ==============================
num_classes = 10
net.fc = nn.Linear(net.fc.in_features, num_classes)  # 原来是 512 → 1000,现在改为 512 → 10

# ==============================
# 3. 【可选但推荐】冻结底层,只微调高层(减少过拟合,加快训练)
# ==============================
# 冻结前几层(通用特征提取器)
for param in net.conv1.parameters():
    param.requires_grad = False
for param in net.bn1.parameters():
    param.requires_grad = False
for param in net.layer1.parameters():
    param.requires_grad = False
for param in net.layer2.parameters():
    param.requires_grad = False

# layer3、layer4 和 fc 默认保持 requires_grad=True(可训练)

# ==============================
# 4. 定义数据变换(适配 ResNet 输入 224x224)
# ==============================
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ==============================
# 5. 加载 CIFAR-10 数据集
# ==============================
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

# ==============================
# 6. 定义损失函数和优化器(分层学习率)
# ==============================
criterion = nn.CrossEntropyLoss()

# 对不同层使用不同的学习率:新层用大 lr,预训练层用小 lr
optimizer = torch.optim.SGD([
    {'params': net.fc.parameters(),       'lr': 1e-3},   # 新分类头,学习率大些
    {'params': net.layer4.parameters(),   'lr': 1e-4},   # 高层微调
    {'params': net.layer3.parameters(),   'lr': 1e-5},   # 中层微调(更小心)
], momentum=0.9, weight_decay=1e-4)

# ==============================
# 7. 训练循环(简化版)
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

num_epochs = 10
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_acc = 100. * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}, Train Acc: {train_acc:.2f}%')

    # 验证阶段
    net.eval()
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total_test += labels.size(0)
            correct_test += predicted.eq(labels).sum().item()
    
    test_acc = 100. * correct_test / total_test
    print(f'>>> Test Accuracy: {test_acc:.2f}%')

print('Finished Fine-tuning!')
相关推荐
AI即插即用3 小时前
即插即用系列 | 2025 RestorMixer:融合 CNN、Mamba 与 Transformer 的高效图像复原的集大成者!
人工智能·深度学习·神经网络·目标检测·计算机视觉·cnn·transformer
All The Way North-3 小时前
PyTorch StepLR:等间隔学习率衰减的原理与实战
pytorch·深度学习·steplr学习率优化算法·学习率优化算法
Wis4e4 小时前
基于PyTorch的深度学习——迁移学习1
pytorch·深度学习·机器学习
北山小恐龙4 小时前
针对性模型压缩:YOLOv8n安全帽检测模型剪枝方案
人工智能·深度学习·算法·计算机视觉·剪枝
Wis4e4 小时前
基于PyTorch的深度学习——迁移学习2
pytorch·深度学习·迁移学习
从负无穷开始的三次元代码生活4 小时前
深度学习知识点概念速通——人工智能专业考试基础知识点
人工智能·深度学习
BB_CC_DD14 小时前
超简单搭建AI去水印和图像修复算法lama-cleaner二
人工智能·深度学习
高洁0115 小时前
DNN案例一步步构建深层神经网络(二)
人工智能·python·深度学习·算法·机器学习
Coding茶水间16 小时前
基于深度学习的螺栓螺母检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉