【NLP】一个使用PyTorch实现图像分类的迁移学习实例

一个使用PyTorch实现图像分类的迁移学习实例

  • [1. 导入模块](#1. 导入模块)
  • [2. 加载数据](#2. 加载数据)
  • [3. 模型处理](#3. 模型处理)
  • [4. 训练及验证模型](#4. 训练及验证模型)
  • [5. 微调](#5. 微调)
  • [6. 其他代码](#6. 其他代码)

在特征提取中,可以在预先训练好的网络结构后修改或添加一个简单的分类器,然后将源任务上预先训练好的网络作为另一个目标任务的特征提取器,只对最后增加的分类器参数重新学习,而预先训练好的网络参数不被修改或冻结。

在完成新任务的特征提取时使用的是源任务中学习到的参数,而不用重新学习所有参数。下面的示例用一个实例具体说明如何通过特征提取的方法进行图像分类。

1. 导入模块

python 复制代码
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torchvision import models

2. 加载数据

这里需要事先将CIFAR10数据下载到本地,因为比较耗时,因此,将download=False。除此之外,还增加了一些预处理功能,比如数据标准化、对图片进行裁剪等。

python 复制代码
def load_data(data, batch_size=64, num_workers=2, mean=None, std=None):
    if std is None:
        std = [0.229, 0.224, 0.225]
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    trans_train = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    trans_valid = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
                                      transforms.Normalize(mean=mean, std=std)])

    train_set = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=trans_train)
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    test_set = torchvision.datasets.CIFAR10(root=data, train=False, download=True, transform=trans_valid)
    testloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader

3. 模型处理

这个部分包含三个操作:

  • 下载预训练模型:使用的预训练模型为resnet18,且已经在ImageNet大数据集上训练好了
  • 冻结模型参数:使其在反向传播时,不会更新
  • 修改最后一层的输出类别数:该数据集中有1000个类别,即原始输出为512×1000,现将其修改为512×10,因为这里使用的新数据集有10个类别
python 复制代码
def freeze_net(num_class=10):
    # 下载预训练模型
    net = models.resnet18(pretrained=True)
    # 冻结模型参数
    for params in net.parameters():
        params.requires_grad = False
    # 修改最后一层的输出类别数
    net.fc = nn.Linear(512, num_class)
    # 查看冻结前后的参数情况
    total_params = sum(p.numel() for p in net.parameters())
    print(f'原总参数个数:{total_params}')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'需训练参数个数:{total_trainable_params}')
    return net

原总参数个数:11181642

需训练参数个数:5130

从输出上可知,如果不冻结,需要更新的参数太多了,冻结之后只需要更新全连接层的参数即可。

4. 训练及验证模型

这里选用交叉熵作为损失函数,使用SGD作为优化器,学习率为1e-3,权重衰减设为1e-3,代码如下:

python 复制代码
# 训练及验证模型
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)  # (bs, 3, h, w)
            label = label.to(device)  # (bs, h, w)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            for im, label in valid_data:
                im = im.to(device)  # (bs, 3, h, w)
                label = label.to(device)  # (bs, h, w)
                output = net(im)
                loss = criterion(output, label)
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)

运行结果:

Epoch 0. Train Loss: 1.474121, Train Acc: 0.498322, Valid Loss: 0.901339, Valid Acc: 0.713177, Time 00:03:26

Epoch 1. Train Loss: 1.222752, Train Acc: 0.576946, Valid Loss: 0.818926, Valid Acc: 0.730494, Time 00:04:35

Epoch 2. Train Loss: 1.172832, Train Acc: 0.592651, Valid Loss: 0.777265, Valid Acc: 0.737759, Time 00:04:23

Epoch 3. Train Loss: 1.158157, Train Acc: 0.596228, Valid Loss: 0.761969, Valid Acc: 0.746517, Time 00:04:28

Epoch 4. Train Loss: 1.143113, Train Acc: 0.600643, Valid Loss: 0.757134, Valid Acc: 0.742138, Time 00:04:24

Epoch 5. Train Loss: 1.128991, Train Acc: 0.607797, Valid Loss: 0.745840, Valid Acc: 0.747014, Time 00:04:24

Epoch 6. Train Loss: 1.131602, Train Acc: 0.603561, Valid Loss: 0.740176, Valid Acc: 0.748109, Time 00:04:21

Epoch 7. Train Loss: 1.127840, Train Acc: 0.608336, Valid Loss: 0.738235, Valid Acc: 0.751990, Time 00:04:19

Epoch 8. Train Loss: 1.122831, Train Acc: 0.609275, Valid Loss: 0.730571, Valid Acc: 0.751692, Time 00:04:18

Epoch 9. Train Loss: 1.118955, Train Acc: 0.609715, Valid Loss: 0.731084, Valid Acc: 0.751692, Time 00:04:13

Epoch 10. Train Loss: 1.111291, Train Acc: 0.612052, Valid Loss: 0.728281, Valid Acc: 0.749602, Time 00:04:09

Epoch 11. Train Loss: 1.108454, Train Acc: 0.612712, Valid Loss: 0.719465, Valid Acc: 0.752787, Time 00:04:15

Epoch 12. Train Loss: 1.111189, Train Acc: 0.612012, Valid Loss: 0.726525, Valid Acc: 0.751294, Time 00:04:09

Epoch 13. Train Loss: 1.114475, Train Acc: 0.610594, Valid Loss: 0.717852, Valid Acc: 0.754080, Time 00:04:06

Epoch 14. Train Loss: 1.112658, Train Acc: 0.608596, Valid Loss: 0.723336, Valid Acc: 0.751393, Time 00:04:14

Epoch 15. Train Loss: 1.109367, Train Acc: 0.614950, Valid Loss: 0.721230, Valid Acc: 0.752588, Time 00:04:06

Epoch 16. Train Loss: 1.107644, Train Acc: 0.614230, Valid Loss: 0.711586, Valid Acc: 0.755275, Time 00:04:08

Epoch 17. Train Loss: 1.100239, Train Acc: 0.613411, Valid Loss: 0.722191, Valid Acc: 0.749303, Time 00:04:11

Epoch 18. Train Loss: 1.108576, Train Acc: 0.611013, Valid Loss: 0.721263, Valid Acc: 0.753483, Time 00:04:08

Epoch 19. Train Loss: 1.098069, Train Acc: 0.618027, Valid Loss: 0.705413, Valid Acc: 0.757962, Time 00:04:06

从结果上看,验证集的准确率达到75%左右。下面采用微调+数据增强的方法继续提升准确率。

5. 微调

微调允许修改预训练好的网络参数来学习目标任务,所以训练时间要比特征抽取方法长,但精度更高。微调的大致过程是再预训练的网络上添加新的随机初始化层,此外预训练的网络参数也会被更新,但会使用较小的学习率以防止预训练好的参数发生较大改变

常用的方法是固定底层的参数,调整一些顶层或具体层的参数。这样可以减少训练参数的数量,也可以避免过拟合的发生。尤其是针对目标任务的数据量不够大的时候,该方法会很有效。

实际上,微调优于特征提取,因为它能对迁移过来的预训练网络参数进行优化,使其更加适合新的任务。

(1)数据预处理

对训练数据添加了几种数据增强方法,比如图片裁剪、旋转、颜色改变等方法。测试数据与特征提取的方法一样。

python 复制代码
    if fine_tuning is False:
        trans_train = transforms.Compose([transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=mean, std=std)])
    else:
        trans_train = transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
                                          transforms.RandomRotation(degrees=15),
                                          transforms.ColorJitter(),
                                          transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=mean, std=std)])

(2)修改模型的分类器层

修改最后全连接层,把类别数由原来的1000改为10。

python 复制代码
def freeze_net(num_class=10, fine_tuning=False):
    # 下载预训练模型
    net = models.resnet18(pretrained=True)
    print(net)
    if fine_tuning is False:
        # 冻结模型参数
        for params in net.parameters():
            params.requires_grad = False
    # 修改最后一层的输出类别数
    net.fc = nn.Linear(512, num_class)
    # 查看冻结前后的参数情况
    total_params = sum(p.numel() for p in net.parameters())
    print(f'原总参数个数:{total_params}')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'需训练参数个数:{total_trainable_params}')
    # 打印出第一层的权重
    print(f'第一层的权重:{net.conv1.weight.type()}')
    return net

训练结果:

Epoch 0. Train Loss: 1.455535, Train Acc: 0.488460, Valid Loss: 0.832547, Valid Acc: 0.721400, Time 00:14:48

Epoch 1. Train Loss: 1.342625, Train Acc: 0.530280, Valid Loss: 0.815430, Valid Acc: 0.723500, Time 10:31:48

Epoch 2. Train Loss: 1.319122, Train Acc: 0.535680, Valid Loss: 0.866512, Valid Acc: 0.699000, Time 00:12:02

Epoch 3. Train Loss: 1.310949, Train Acc: 0.541700, Valid Loss: 0.789511, Valid Acc: 0.728000, Time 00:12:03

Epoch 4. Train Loss: 1.313486, Train Acc: 0.538500, Valid Loss: 0.762553, Valid Acc: 0.741300, Time 00:12:19

Epoch 5. Train Loss: 1.309776, Train Acc: 0.540680, Valid Loss: 0.777906, Valid Acc: 0.736100, Time 00:11:43

Epoch 6. Train Loss: 1.302117, Train Acc: 0.541780, Valid Loss: 0.779318, Valid Acc: 0.737200, Time 00:12:00

Epoch 7. Train Loss: 1.304539, Train Acc: 0.544320, Valid Loss: 0.795917, Valid Acc: 0.726500, Time 00:13:16

Epoch 8. Train Loss: 1.311748, Train Acc: 0.542400, Valid Loss: 0.785983, Valid Acc: 0.728000, Time 00:14:48

Epoch 9. Train Loss: 1.302069, Train Acc: 0.544820, Valid Loss: 0.781665, Valid Acc: 0.734700, Time 00:14:15

Epoch 10. Train Loss: 1.298019, Train Acc: 0.547040, Valid Loss: 0.771555, Valid Acc: 0.742200, Time 00:16:11

Epoch 11. Train Loss: 1.310127, Train Acc: 0.538700, Valid Loss: 0.764313, Valid Acc: 0.739300, Time 00:17:33

Epoch 12. Train Loss: 1.300172, Train Acc: 0.544720, Valid Loss: 0.765881, Valid Acc: 0.734200, Time 00:12:04

Epoch 13. Train Loss: 1.289607, Train Acc: 0.546980, Valid Loss: 0.753371, Valid Acc: 0.742500, Time 00:11:49

Epoch 14. Train Loss: 1.295938, Train Acc: 0.546280, Valid Loss: 0.821099, Valid Acc: 0.721900, Time 00:11:43

...

使用微调训练方式的时间明显大于使用特征提取方式的时间,但是验证集上的准确率并没有提高,这是因为由于GPU内存限制,这里将batch_size设为了16。

6. 其他代码

python 复制代码
if __name__ == '__main__':
    data_path = './data'
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'forg', 'horse', 'ship', 'truck')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        torch.cuda.empty_cache()
    else:
        device = torch.device('cpu')
    # 加载数据
    train_loader, test_loader = load_data(data=data_path, fine_tuning=True)
    # 随机获取部分训练数据
    data_iter = iter(train_loader)
    images, labels = data_iter.next()
    # 显示图像
    imshow(torchvision.utils.make_grid(images[:4]))
    # 打印标签
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
    # 加载模型
    net = freeze_net(num_class=len(classes), fine_tuning=True)
    net = net.to(device)

    # 定义损失函数及优化器
    criterion = nn.CrossEntropyLoss()
    # 只需要优化最后一层参数
    optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)
    # 训练及验证模型
    train(net, train_loader, test_loader, 20, optimizer, criterion)
相关推荐
AI算法工程师Moxi1 小时前
什么是迁移学习(transfer learning)
人工智能·机器学习·迁移学习
love you joyfully2 小时前
循环神经网络——pytorch实现循环神经网络(RNN、GRU、LSTM)
人工智能·pytorch·rnn·深度学习·gru·循环神经网络
大千AI助手7 小时前
InstructGPT:使用人类反馈训练语言模型以遵循指令
人工智能·gpt·语言模型·自然语言处理·rlhf·指令微调·模型对齐
躺柒9 小时前
读大语言模型08计算基础设施
人工智能·ai·语言模型·自然语言处理·大语言模型·大语言
金井PRATHAMA9 小时前
破译心智密码:神经科学如何为下一代自然语言处理绘制语义理解的蓝图
人工智能·自然语言处理
DogDaoDao1 天前
用PyTorch实现多类图像分类:从原理到实际操作
图像处理·人工智能·pytorch·python·深度学习·分类·图像分类
金井PRATHAMA1 天前
大脑的藏宝图——神经科学如何为自然语言处理(NLP)的深度语义理解绘制新航线
人工智能·自然语言处理
max5006001 天前
北京大学MuMo多模态肿瘤分类模型复现与迁移学习
人工智能·python·机器学习·分类·数据挖掘·迁移学习
樱花的浪漫1 天前
CUDA的编译与调试
人工智能·深度学习·语言模型·自然语言处理
勤劳的进取家2 天前
论文阅读:Gorilla: Large Language Model Connected with Massive APIs
论文阅读·人工智能·语言模型·自然语言处理·prompt