【NLP】A Transfer Learning Example for Image Classification with PyTorch


  • 1. Importing Modules
  • 2. Loading the Data
  • 3. Preparing the Model
  • 4. Training and Validating the Model
  • 5. Fine-Tuning
  • 6. Remaining Code

In feature extraction, a simple classifier is modified or added on top of a network that has already been pretrained on a source task. The pretrained network then serves as a feature extractor for the new target task: only the parameters of the newly added classifier are relearned, while the pretrained parameters stay frozen and are never updated.

Because the features for the new task are computed with parameters learned on the source task, there is no need to relearn all the parameters. The example below walks through image classification with this feature-extraction approach step by step.

1. Importing Modules

python
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torchvision import models

2. Loading the Data

The CIFAR-10 data needs to be downloaded locally ahead of time because the download is slow, so download=False is passed below. Some preprocessing is also added, such as normalization and cropping; the default mean and std are the ImageNet statistics that the pretrained model expects.

python
def load_data(data, batch_size=64, num_workers=2, mean=None, std=None):
    if std is None:
        std = [0.229, 0.224, 0.225]
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    trans_train = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    trans_valid = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
                                      transforms.Normalize(mean=mean, std=std)])

    train_set = torchvision.datasets.CIFAR10(root=data, train=True, download=False, transform=trans_train)
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    test_set = torchvision.datasets.CIFAR10(root=data, train=False, download=False, transform=trans_valid)
    testloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader
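
As a quick sanity check (not part of the original listing), you can pull one batch and confirm that the 32×32 CIFAR-10 images come out as 224×224 tensors, the input size the pretrained ResNet-18 expects:

python
# Hypothetical check: fetch one batch and inspect its shape
train_loader, test_loader = load_data(data='./data')
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([64, 3, 224, 224])
print(labels.shape)  # expected: torch.Size([64])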

3. Preparing the Model

This part involves three steps:

  • Download the pretrained model: a resnet18 that has already been trained on the large ImageNet dataset
  • Freeze the model's parameters so they are not updated during backpropagation
  • Replace the final layer's output class count: ImageNet has 1000 classes, so the original fc layer is 512×1000; it becomes 512×10 here because the new dataset has 10 classes
python
def freeze_net(num_class=10):
    # Download the pretrained model
    net = models.resnet18(pretrained=True)
    # Freeze its parameters
    for params in net.parameters():
        params.requires_grad = False
    # Replace the final layer so it outputs num_class classes
    net.fc = nn.Linear(512, num_class)
    # Compare parameter counts before/after freezing
    total_params = sum(p.numel() for p in net.parameters())
    print(f'Total parameters: {total_params}')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'Trainable parameters: {total_trainable_params}')
    return net

Total parameters: 11181642

Trainable parameters: 5130

As the output shows, without freezing there would be more than 11 million parameters to update; after freezing, only the 5,130 parameters of the new fully connected layer need to be trained.
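
Equivalently, the optimizer can be built over just the parameters that still require gradients, which makes explicit what is being updated. Below is a small sketch, not part of the original article; the SGD hyperparameters mirror the ones used in the main script in Section 6.

python
# Collect only the trainable parameters (after freezing, just net.fc)
net = freeze_net(num_class=10)
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-3, weight_decay=1e-3, momentum=0.9)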

4. Training and Validating the Model

Cross-entropy is used as the loss function and SGD as the optimizer, with a learning rate of 1e-3, a weight decay of 1e-3, and momentum of 0.9 (the optimizer itself is constructed in the main script in Section 6). The training code is as follows:

python
# Train and validate the model
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)  # (bs, 3, h, w)
            label = label.to(device)  # (bs,)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # No gradients are needed during validation
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)  # (bs, 3, h, w)
                    label = label.to(device)  # (bs,)
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)

Training output:

Epoch 0. Train Loss: 1.474121, Train Acc: 0.498322, Valid Loss: 0.901339, Valid Acc: 0.713177, Time 00:03:26

Epoch 1. Train Loss: 1.222752, Train Acc: 0.576946, Valid Loss: 0.818926, Valid Acc: 0.730494, Time 00:04:35

Epoch 2. Train Loss: 1.172832, Train Acc: 0.592651, Valid Loss: 0.777265, Valid Acc: 0.737759, Time 00:04:23

Epoch 3. Train Loss: 1.158157, Train Acc: 0.596228, Valid Loss: 0.761969, Valid Acc: 0.746517, Time 00:04:28

Epoch 4. Train Loss: 1.143113, Train Acc: 0.600643, Valid Loss: 0.757134, Valid Acc: 0.742138, Time 00:04:24

Epoch 5. Train Loss: 1.128991, Train Acc: 0.607797, Valid Loss: 0.745840, Valid Acc: 0.747014, Time 00:04:24

Epoch 6. Train Loss: 1.131602, Train Acc: 0.603561, Valid Loss: 0.740176, Valid Acc: 0.748109, Time 00:04:21

Epoch 7. Train Loss: 1.127840, Train Acc: 0.608336, Valid Loss: 0.738235, Valid Acc: 0.751990, Time 00:04:19

Epoch 8. Train Loss: 1.122831, Train Acc: 0.609275, Valid Loss: 0.730571, Valid Acc: 0.751692, Time 00:04:18

Epoch 9. Train Loss: 1.118955, Train Acc: 0.609715, Valid Loss: 0.731084, Valid Acc: 0.751692, Time 00:04:13

Epoch 10. Train Loss: 1.111291, Train Acc: 0.612052, Valid Loss: 0.728281, Valid Acc: 0.749602, Time 00:04:09

Epoch 11. Train Loss: 1.108454, Train Acc: 0.612712, Valid Loss: 0.719465, Valid Acc: 0.752787, Time 00:04:15

Epoch 12. Train Loss: 1.111189, Train Acc: 0.612012, Valid Loss: 0.726525, Valid Acc: 0.751294, Time 00:04:09

Epoch 13. Train Loss: 1.114475, Train Acc: 0.610594, Valid Loss: 0.717852, Valid Acc: 0.754080, Time 00:04:06

Epoch 14. Train Loss: 1.112658, Train Acc: 0.608596, Valid Loss: 0.723336, Valid Acc: 0.751393, Time 00:04:14

Epoch 15. Train Loss: 1.109367, Train Acc: 0.614950, Valid Loss: 0.721230, Valid Acc: 0.752588, Time 00:04:06

Epoch 16. Train Loss: 1.107644, Train Acc: 0.614230, Valid Loss: 0.711586, Valid Acc: 0.755275, Time 00:04:08

Epoch 17. Train Loss: 1.100239, Train Acc: 0.613411, Valid Loss: 0.722191, Valid Acc: 0.749303, Time 00:04:11

Epoch 18. Train Loss: 1.108576, Train Acc: 0.611013, Valid Loss: 0.721263, Valid Acc: 0.753483, Time 00:04:08

Epoch 19. Train Loss: 1.098069, Train Acc: 0.618027, Valid Loss: 0.705413, Valid Acc: 0.757962, Time 00:04:06

These results show the validation accuracy reaching about 75%. Next, fine-tuning plus data augmentation is used to try to raise the accuracy further.

5. Fine-Tuning

Fine-tuning allows the pretrained network's own parameters to be modified while learning the target task, so training takes longer than with the feature-extraction approach, but the accuracy is higher. Roughly, fine-tuning adds new, randomly initialized layers on top of the pretrained network; the pretrained parameters are updated as well, but with a smaller learning rate so that they do not drift far from their pretrained values.

A common strategy is to keep the bottom layers fixed and adjust only some top or task-specific layers, as in the sketch below. This reduces the number of parameters to train and also helps avoid overfitting, which is particularly effective when the target task's dataset is not large.
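
The fine-tuning run later in this section unfreezes the entire network, but the layer-wise strategy just described is easy to express as well. Here is a minimal sketch for ResNet-18, assuming only layer4 and the new fc layer are left trainable; partial_freeze_net is a hypothetical helper, not part of the original code.

python
def partial_freeze_net(num_class=10):
    net = models.resnet18(pretrained=True)
    # Freeze everything first
    for p in net.parameters():
        p.requires_grad = False
    # Unfreeze the top residual block; conv1 and layer1-3 stay frozen
    for p in net.layer4.parameters():
        p.requires_grad = True
    # The freshly created classifier layer is trainable by default
    net.fc = nn.Linear(512, num_class)
    return net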

In practice, fine-tuning tends to outperform feature extraction, because it optimizes the transferred pretrained weights so that they fit the new task better.

(1) Data preprocessing

Several data augmentation methods are added for the training data, such as image cropping, rotation, and color jitter (note the pipeline crops twice: a gentle random crop to 256, then a second random crop down to 224). The test data is processed the same way as in the feature-extraction setup.

python
    # Inside load_data, whose signature gains a fine_tuning=False flag;
    # this block replaces the single trans_train assignment from Section 2.
    if fine_tuning is False:
        trans_train = transforms.Compose([transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=mean, std=std)])
    else:
        trans_train = transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
                                          transforms.RandomRotation(degrees=15),
                                          transforms.ColorJitter(),
                                          transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=mean, std=std)])

(2) Modifying the model's classifier layer

Replace the final fully connected layer, changing the number of output classes from the original 1000 to 10. When fine_tuning=True, no parameters are frozen, so the entire network is trainable.

python
def freeze_net(num_class=10, fine_tuning=False):
    # Download the pretrained model
    net = models.resnet18(pretrained=True)
    print(net)
    if fine_tuning is False:
        # Freeze the model's parameters
        for params in net.parameters():
            params.requires_grad = False
    # Replace the final layer so it outputs num_class classes
    net.fc = nn.Linear(512, num_class)
    # Compare parameter counts before/after freezing
    total_params = sum(p.numel() for p in net.parameters())
    print(f'Total parameters: {total_params}')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'Trainable parameters: {total_trainable_params}')
    # Print the dtype of the first conv layer's weights
    print(f'conv1 weight type: {net.conv1.weight.type()}')
    return net
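
As a quick check (not shown in the original listing), calling the modified function with fine_tuning=True should report every parameter as trainable:

python
# With fine_tuning=True nothing is frozen, so all 11,181,642 parameters
# (the total printed earlier) should be reported as trainable
net = freeze_net(num_class=10, fine_tuning=True)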

Training output:

Epoch 0. Train Loss: 1.455535, Train Acc: 0.488460, Valid Loss: 0.832547, Valid Acc: 0.721400, Time 00:14:48

Epoch 1. Train Loss: 1.342625, Train Acc: 0.530280, Valid Loss: 0.815430, Valid Acc: 0.723500, Time 10:31:48

Epoch 2. Train Loss: 1.319122, Train Acc: 0.535680, Valid Loss: 0.866512, Valid Acc: 0.699000, Time 00:12:02

Epoch 3. Train Loss: 1.310949, Train Acc: 0.541700, Valid Loss: 0.789511, Valid Acc: 0.728000, Time 00:12:03

Epoch 4. Train Loss: 1.313486, Train Acc: 0.538500, Valid Loss: 0.762553, Valid Acc: 0.741300, Time 00:12:19

Epoch 5. Train Loss: 1.309776, Train Acc: 0.540680, Valid Loss: 0.777906, Valid Acc: 0.736100, Time 00:11:43

Epoch 6. Train Loss: 1.302117, Train Acc: 0.541780, Valid Loss: 0.779318, Valid Acc: 0.737200, Time 00:12:00

Epoch 7. Train Loss: 1.304539, Train Acc: 0.544320, Valid Loss: 0.795917, Valid Acc: 0.726500, Time 00:13:16

Epoch 8. Train Loss: 1.311748, Train Acc: 0.542400, Valid Loss: 0.785983, Valid Acc: 0.728000, Time 00:14:48

Epoch 9. Train Loss: 1.302069, Train Acc: 0.544820, Valid Loss: 0.781665, Valid Acc: 0.734700, Time 00:14:15

Epoch 10. Train Loss: 1.298019, Train Acc: 0.547040, Valid Loss: 0.771555, Valid Acc: 0.742200, Time 00:16:11

Epoch 11. Train Loss: 1.310127, Train Acc: 0.538700, Valid Loss: 0.764313, Valid Acc: 0.739300, Time 00:17:33

Epoch 12. Train Loss: 1.300172, Train Acc: 0.544720, Valid Loss: 0.765881, Valid Acc: 0.734200, Time 00:12:04

Epoch 13. Train Loss: 1.289607, Train Acc: 0.546980, Valid Loss: 0.753371, Valid Acc: 0.742500, Time 00:11:49

Epoch 14. Train Loss: 1.295938, Train Acc: 0.546280, Valid Loss: 0.821099, Valid Acc: 0.721900, Time 00:11:43

...

Fine-tuning clearly takes much longer per epoch than feature extraction, yet the validation accuracy has not improved. One reason is that GPU memory limits forced batch_size down to 16 for this run.
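
A common workaround for such memory limits, not used in the run above, is gradient accumulation: run several small batches per optimizer step, so the effective batch size grows without extra GPU memory. Below is a sketch of how the inner loop of train could be adapted; accum_steps is a hypothetical name.

python
accum_steps = 4  # effective batch size = 16 * 4 = 64
optimizer.zero_grad()
for i, (im, label) in enumerate(train_data):
    im, label = im.to(device), label.to(device)
    # Scale the loss so the accumulated gradients average over the steps
    loss = criterion(net(im), label) / accum_steps
    loss.backward()  # gradients add up across the small batches
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()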

6. Remaining Code

python
if __name__ == '__main__':
    data_path = './data'
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        torch.cuda.empty_cache()
    else:
        device = torch.device('cpu')
    # Load the data
    train_loader, test_loader = load_data(data=data_path, fine_tuning=True)
    # Grab a batch of training data
    data_iter = iter(train_loader)
    images, labels = next(data_iter)  # data_iter.next() was removed in newer PyTorch
    # Show the images
    imshow(torchvision.utils.make_grid(images[:4]))
    # Print their labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
    # Build the model
    net = freeze_net(num_class=len(classes), fine_tuning=True)
    net = net.to(device)

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    # Feature extraction only needs the fc parameters; for fine-tuning
    # (fine_tuning=True) pass net.parameters() instead so all weights update
    optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)
    # Train and validate the model
    train(net, train_loader, test_loader, 20, optimizer, criterion)
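
The helpers get_acc and imshow are called above but never defined in the article. The sketches below are consistent with how they are used: get_acc must return per-batch accuracy, since train divides the accumulated sum by the number of batches, and imshow assumes the ImageNet mean/std used in load_data.

python
def get_acc(output, label):
    # Fraction of correct predictions in this batch
    _, pred = output.max(1)
    num_correct = (pred == label).sum().item()
    return num_correct / output.shape[0]

def imshow(img):
    # Undo the ImageNet normalization, then display the image grid
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    npimg = img.numpy() * std + mean
    plt.imshow(np.transpose(npimg, (1, 2, 0)).clip(0, 1))
    plt.show()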