迁移学习实战:基于 ResNet18 的食物分类

一、迁移学习简介

迁移学习是一种高效的机器学习方法,它利用在大规模数据集上预训练好的模型,在新的任务上进行微调。这样做的优势十分显著:

  • 加速训练:无需从零开始训练模型,节省大量时间。
  • 提升性能:预训练模型已经学习到了通用的特征表示,能为新任务提供良好的基础。
  • 数据高效:在新任务数据稀缺时,也能取得不错的效果。

二、迁移学习步骤

1. 选择预训练模型和适当的层

通常会选择在大规模图像数据集(如 ImageNet)上预训练的模型,像 VGG、ResNet 等。对于不同的任务,选择的层也有所不同:

  • 若任务是低级特征提取(如边缘检测),适合使用浅层模型的层。
  • 若任务是高级特征相关(如分类),则应选择更深层次的模型。

2. 冻结预训练模型的参数

保持预训练模型的权重不变,只训练新增加的层或者微调部分层。这样做是为了避免预训练模型在新数据集上过度拟合,同时也能减少计算量。

3. 在新数据集上训练新增加的层

在冻结预训练模型参数的情况下,训练新增加的层,使新模型能够适应新的任务,从而提升性能。

4. 微调预训练模型的层

在新层训练完成后,解冻一些已经训练过的层并进行微调,进一步提高模型在新数据集上的性能。

5. 评估和测试

训练完成后,使用测试集对模型进行评估。若模型性能不佳,可调整超参数或更改微调层。

三、基于 ResNet18 的食物分类实战

使用上节课所说的残差网络的18层结构来对其进行微调,该残差网络结构如下图所示:

此时我们可以发现输入图像的特征大小为3*224*224,输出特征图格式为512*1*1,然后将其进行全连接层处理后变成输入512张特征图,输出1000个预测结果,这个结果的种类太多,我们不需要使用这么多的预测类别,所以当下需要对其微调,调整最后输出时的全连接层输出结果个数及其全连接层中的权重参数。

1. 导入预训练模型

我们选择在 ImageNet 上预训练好的 ResNet18 模型,代码如下:

python 复制代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

# 导入预训练的ResNet18模型
resent_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

2. 冻结预训练模型参数

通过设置参数的requires_grad属性为False,冻结预训练模型的参数,使其在训练过程中不参与梯度更新:

python 复制代码
for param in resent_model.parameters():
    param.requires_grad = False  # 冻结所有预训练模型参数

3. 修改全连接层

原 ResNet18 模型是为 ImageNet 的 1000 类分类任务设计的,我们要将其适配为 20 类食物分类任务,所以需要修改全连接层,并收集需要训练的参数:

python 复制代码
in_features = resent_model.fc.in_features  # 获取原全连接层的输入特征数
resent_model.fc = nn.Linear(in_features, 20)  # 替换为输出为20类的全连接层

param_to_update = []  # 收集需要训练的参数(仅新的全连接层)
for param in resent_model.parameters():
    if param.requires_grad:
        param_to_update.append(param)

4. 自定义数据集类与数据增强

创建food_dataset类来加载食物图像数据,并通过数据增强来提升模型的泛化能力:

python 复制代码
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        label = torch.from_numpy(np.array(label, dtype=np.int64))
        return image, label

# 数据增强与预处理
data_transforms = {
    'train':
        transforms.Compose([
            transforms.Resize([300, 300]),
            transforms.RandomRotation(45),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    'test':
        transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
}

# 加载训练集和测试集
train_data = food_dataset(file_path=r'train.1txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'test.1txt', transform=data_transforms['test'])

# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

train.1txt,test.1txt如下:

5. 定义训练和测试函数

python 复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model.forward(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if batch_size_num % 40 == 0:
            print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1

best_acc = 0
acc_s = []
loss_s = []

def test(dataloader, model, loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}\n")
    acc_s.append(correct)
    loss_s.append(test_loss)
    if correct > best_acc:
        best_acc = correct

6. 模型设备部署与优化器设置

python 复制代码
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
model = resent_model.to(device)

loss_fn = nn.CrossEntropyLoss()  # 多分类损失函数
optimizer = torch.optim.Adam(param_to_update, lr=0.001)  # 仅优化新全连接层参数
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 学习率调度器

7. 训练与测试

python 复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n--------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    scheduler.step()
    test(test_dataloader, model, loss_fn)
print('最优测试结果为:', best_acc)

训练结果如下:

相关推荐
THMAIL4 小时前
深度学习从入门到精通 - LSTM与GRU深度剖析:破解长序列记忆遗忘困境
人工智能·python·深度学习·算法·机器学习·逻辑回归·lstm
悠哉悠哉愿意4 小时前
【数学建模学习笔记】机器学习分类:随机森林分类
学习·机器学习·数学建模
玉木子4 小时前
机器学习(七)决策树-分类
决策树·机器学习·分类
悠哉悠哉愿意5 小时前
【数学建模学习笔记】机器学习分类:KNN分类
学习·机器学习·数学建模
ningmengjing_5 小时前
理解损失函数:机器学习的指南针与裁判
人工智能·深度学习·机器学习
荒野饮冰室5 小时前
分类、目标检测、实例分割的评估指标
目标检测·计算机视觉·分类·实例分割
nju_spy6 小时前
Kaggle - LLM Science Exam 大模型做科学选择题
人工智能·机器学习·大模型·rag·南京大学·gpu分布计算·wikipedia 维基百科
中國龍在廣州6 小时前
GPT-5冷酷操盘,游戏狼人杀一战封神!七大LLM狂飙演技,人类玩家看完沉默
人工智能·gpt·深度学习·机器学习·计算机视觉·机器人
THMAIL7 小时前
深度学习从入门到精通 - 神经网络核心原理:从生物神经元到数学模型蜕变
人工智能·python·深度学习·神经网络·算法·机器学习·逻辑回归