Fine-tuning = starting from a pretrained model and continuing to train (update) some or all of the original parameters, plus any newly added layers, to adapt the model to a new task. The pretrained parameters are also updated, but with a smaller learning rate, so that they do not drift too far from their pretrained values.
```python
import torch
import torchvision.models as models

# Load a ResNet-18 pretrained on ImageNet
net = models.resnet18(pretrained=True)  # pretrained= is deprecated since torchvision 0.13; the weights= form shown below is preferred
print(net)
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, ...)
#   (relu): ReLU(inplace=True)
#   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, ...)
#   (layer1): Sequential(...)  # 2 BasicBlocks
#   (layer2): Sequential(...)  # 2 BasicBlocks
#   (layer3): Sequential(...)  # 2 BasicBlocks
#   (layer4): Sequential(...)  # 2 BasicBlocks
#   (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
#   (fc): Linear(in_features=512, out_features=1000, bias=True)
# )

# Replace the 1000-class ImageNet head with a 10-class head
net.fc = torch.nn.Linear(512, 10)
```
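As a quick sanity check (a minimal sketch; the dummy input is purely illustrative), a forward pass on random data confirms that the replaced head now emits 10 logits:

```python
# Dummy batch of 2 RGB images at the 224x224 resolution ResNet expects
net.eval()
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out = net(x)
print(out.shape)  # torch.Size([2, 10]) -- one logit per class
```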
A common strategy is to freeze the parameters of the lower layers and tune only the top (or otherwise task-specific) layers. This reduces the number of trainable parameters and also helps guard against overfitting.
```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
# ==============================
# 1. Load the pretrained model (recommended modern API)
# ==============================
net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
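# ResNet18_Weights.DEFAULT resolves to the best available ImageNet weights
# for this architecture (currently ResNet18_Weights.IMAGENET1K_V1)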
# ==============================
# 2. Replace the final classification layer (CIFAR-10 has 10 classes)
# ==============================
num_classes = 10
net.fc = nn.Linear(net.fc.in_features, num_classes)  # was 512 -> 1000, now 512 -> 10
# ==============================
# 3. [Optional but recommended] Freeze the lower layers and fine-tune only
#    the upper ones (reduces overfitting, speeds up training)
# ==============================
# Freeze the early layers (generic feature extractors)
for module in (net.conv1, net.bn1, net.layer1, net.layer2):
    for param in module.parameters():
        param.requires_grad = False
# layer3, layer4, and fc keep requires_grad=True by default (still trainable)
# Note: requires_grad=False stops gradient updates, but BatchNorm layers still
# update their running statistics while the model is in train() mode
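
# (Illustrative sanity check, not required for training: confirm the freeze
# took effect by counting trainable vs. total parameters.)
num_total = sum(p.numel() for p in net.parameters())
num_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'Trainable parameters: {num_trainable:,} / {num_total:,}')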
# ==============================
# 4. Define the data transforms (match ResNet's expected 224x224 input)
# ==============================
transform_train = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
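
# Note: CIFAR-10 images are only 32x32; Resize(256) upsamples them so the
# ImageNet-pretrained network sees inputs at the scale it was trained on,
# and Normalize(...) uses the standard ImageNet channel statistics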
# ==============================
# 5. Load the CIFAR-10 dataset
# ==============================
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
# ==============================
# 6. Define the loss function and optimizer (layer-wise learning rates)
# ==============================
criterion = nn.CrossEntropyLoss()
# Use different learning rates per layer group: larger lr for the new head, smaller for pretrained layers
optimizer = torch.optim.SGD([
    {'params': net.fc.parameters(), 'lr': 1e-3},      # new classification head: larger lr
    {'params': net.layer4.parameters(), 'lr': 1e-4},  # high-level features: gentle fine-tuning
    {'params': net.layer3.parameters(), 'lr': 1e-5},  # mid-level features: even more cautious
], momentum=0.9, weight_decay=1e-4)
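
# (Alternative sketch: if you do not need per-layer learning rates, you can
# simply pass only the parameters that still require gradients.)
# optimizer = torch.optim.SGD(
#     (p for p in net.parameters() if p.requires_grad),
#     lr=1e-4, momentum=0.9, weight_decay=1e-4)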
# ==============================
# 7. Training loop (simplified)
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
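# (Note: nn.Module.to() moves the existing Parameter objects in place, so the
# optimizer created above still references the correct, now-on-device tensors.)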
num_epochs = 10
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    train_acc = 100. * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}, Train Acc: {train_acc:.2f}%')

    # Validation phase
    net.eval()
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total_test += labels.size(0)
            correct_test += predicted.eq(labels).sum().item()
    test_acc = 100. * correct_test / total_test
    print(f'>>> Test Accuracy: {test_acc:.2f}%')
print('Finished Fine-tuning!')
```
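After fine-tuning, you will usually want to persist the adapted weights. A minimal sketch (the filename `finetuned_resnet18.pth` is just an example):

```python
# Save only the state dict (preferred over pickling the whole model object)
torch.save(net.state_dict(), 'finetuned_resnet18.pth')

# To reload: rebuild the same architecture, then load the saved weights
net2 = models.resnet18(weights=None)
net2.fc = nn.Linear(net2.fc.in_features, num_classes)
net2.load_state_dict(torch.load('finetuned_resnet18.pth', map_location='cpu'))
net2.eval()
```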