(动手学习深度学习)第13章 计算机视觉---微调

文章目录

微调






总结

  • 微调通过使用在大数据上的恶道的预训练好的模型来初始化模型权重来完成提升精度。
  • 预训练模型质量很重要
  • 微调通常速度更快、精确度更高

微调代码实现

  1. 导入相关库
python 复制代码
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib as plt
  1. 获取数据集
python 复制代码
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
python 复制代码
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))
python 复制代码
print(train_imgs)
print(train_imgs[0])
train_imgs[0][0]

查看数据集中图像的形状

python 复制代码
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs= [train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2 ,8, scale=1.4)
  1. 数据增强
python 复制代码
# 图像增广
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224,0.225]
)
train_augs = torchvision.transforms.Compose(  # 训练集数据增强
    [torchvision.transforms.RandomResizedCrop(224),
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor(),
     normalize]
)
test_augs = torchvision.transforms.Compose(  # 验证集不做数据增强
    [torchvision.transforms.Resize(256),
     torchvision.transforms.CenterCrop(224),
     torchvision.transforms.ToTensor(),
     normalize]
)
  1. 定义和初始化模型
python 复制代码
# 下载resnet18,
# 老:pretrain=True: 也下载预训练的模型参数
# 新:weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
python 复制代码
print(pretrained_net.fc)
  1. 微调模型
  • (1)直接修改网络层(如最后全连接层:512--->1000,改成512--->2)
  • (2)在增加一层分类层(如:512--->1000, 改成512--->1000, 1000--->2)

本次选择(1):将resnet18最后全连接层的输出,改成自己训练集的类别,并初始化最后全连接层的权重参数

python 复制代码
finetune_net = pretrained_net
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)
python 复制代码
print(finetune_net)
  1. 训练模型
  • 特征提取层(预训练层):使用较小的学习率
  • 输出全连接层(微调层):使用较大的学习率
python 复制代码
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=10, param_group=True):
    train_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(
            os.path.join(data_dir,'train'), transform=train_augs
        ),
        batch_size=batch_size,
        shuffle=True
    )
    test_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(
            os.path.join(data_dir, 'test'), transform=test_augs
        ),
        batch_size=batch_size
    )
    device = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction='none')
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ['fc.weight', 'fc.bias']]
        trainer = torch.optim.SGD(
            [{'params': params_1x}, {'params': net.fc.parameters(), 'lr': learning_rate * 10}],
            lr=learning_rate, weight_decay=0.001
        )
    else:
        trainer = torch.optim.SGD(
            net.parameters(),
            lr=learning_rate,weight_decay=0.001
        )
    d2l.train_ch13(net, train_iter, test_iter, loss,trainer, num_epochs, device)

训练模型

python 复制代码
import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

train_fine_tuning(finetune_net, 5e-5, 128, 10)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

直接训练:整个模型都使用相同的学习率,重新训练

python 复制代码
scracth_net = torchvision.models.resnet18()
scracth_net.fc = nn.Linear(scracth_net.fc.in_features, 2)

import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

train_fine_tuning(scracth_net, 5e-4, param_group=False)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')
相关推荐
小言从不摸鱼14 分钟前
【AI大模型】探索GPT模型的奥秘:引领自然语言处理的新纪元
人工智能·gpt·深度学习·语言模型·自然语言处理·transformer
sp_fyf_202421 分钟前
【大语言模型】ACL2024论文-36 利用NLI和ChatGPT及编码簿知识进行零样本政治关系分类
深度学习·神经网络·机器学习·语言模型·chatgpt·分类·数据挖掘
张铁铁是个小胖子2 小时前
微服务学习
java·学习·微服务
sp_fyf_20243 小时前
【大语言模型】ACL2024论文-35 WAV2GLOSS:从语音生成插值注解文本
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·数据挖掘
AITIME论道3 小时前
论文解读 | EMNLP2024 一种用于大语言模型版本更新的学习率路径切换训练范式
人工智能·深度学习·学习·机器学习·语言模型
明明真系叻4 小时前
第二十六周机器学习笔记:PINN求正反解求PDE文献阅读——正问题
人工智能·笔记·深度学习·机器学习·1024程序员节
XianxinMao5 小时前
Transformer 架构对比:Dense、MoE 与 Hybrid-MoE 的优劣分析
深度学习·架构·transformer
青春男大6 小时前
java栈--数据结构
java·开发语言·数据结构·学习·eclipse
HyperAI超神经6 小时前
未来具身智能的触觉革命!TactEdge传感器让机器人具备精细触觉感知,实现织物缺陷检测、灵巧操作控制
人工智能·深度学习·机器人·触觉传感器·中国地质大学·机器人智能感知·具身触觉
一勺汤6 小时前
YOLO11改进-注意力-引入多尺度卷积注意力模块MSCAM
yolo·目标检测·计算机视觉·改进·魔改·yolov11·yolov11改进