13 计算机视觉-代码详解

13.2 微调

为了防止在训练集上过拟合,有两种办法,第一种是扩大训练集数量,但是需要大量的成本;第二种就是应用迁移学习,将源数据学习到的知识迁移到目标数据集,即在把在源数据训练好的参数和模型(除去输出层)直接复制到目标数据集训练。

python 复制代码
# IPython魔法函数,可以不用执行plt .show()
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

13.2.1 获取数据集

python 复制代码
#@save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i-1][0] for i in range(8)]
# 展示2行8列矩阵的图片,共16张
d2l.show_images(hotdogs+not_hotdogs,2,8,scale=1.5)
python 复制代码
# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# 图像增广
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])

13.2.2 初始化模型

python 复制代码
# 自动下载网上的训练模型
finetune_net = torchvision.models.resnet18(pretrained=True)
# 输入张量的形状还是源输入张量大小,输入张量大小改为2
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

13.2.3 微调模型

python 复制代码
# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
# 如果param_group=False,输出层中模型参数为随机值
# 训练模型
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        # params_1x的参数使用learning_rate学习率, net.fc.parameters()的参数使用0.001的学习率
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)
train_fine_tuning(finetune_net, 5e-5)

13.3 目标检测和边界框

有时候不仅要识别图像的类别,还需要识别图像的位置。在计算机视觉中叫做目标识别或者目标检测。这小节是介绍目标检测的深度学习方法。

python 复制代码
%matplotlib inline
import torch
from d2l import torch as d2l
#@save
def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽度,高度)"""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # cx,xy,w,h的维度是n
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    # torch.stack()沿着新维度对张量进行链接。boxes最开始维度是(n,4),axis=-1表示倒数第一个维度
    # torch.stack()将(cx, cy, w, h)的维度n将其沿着倒数第一个维度拼接在一起,又是(n,4)
    boxes = torch.stack((cx, cy, w, h), axis=-1)
    return boxes

#@save
def box_center_to_corner(boxes):
    """从(中间,宽度,高度)转换到(左上,右下)"""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    boxes = torch.stack((x1, y1, x2, y2), axis=-1)
    return boxes
相关推荐
BulingQAQ1 小时前
论文阅读:PET/CT Cross-modal medical image fusion of lung tumors based on DCIF-GAN
论文阅读·深度学习·生成对抗网络·计算机视觉·gan
hsling松子1 小时前
使用PaddleHub智能生成,献上浓情国庆福
人工智能·算法·机器学习·语言模型·paddlepaddle
正在走向自律2 小时前
机器学习框架
人工智能·机器学习
好吃番茄2 小时前
U mamba配置问题;‘KeyError: ‘file_ending‘
人工智能·机器学习
CV-King3 小时前
opencv实战项目(三十):使用傅里叶变换进行图像边缘检测
人工智能·opencv·算法·计算机视觉
禁默3 小时前
2024年计算机视觉与艺术研讨会(CVA 2024)
人工智能·计算机视觉
whaosoft-1434 小时前
大模型~合集3
人工智能
Dream-Y.ocean4 小时前
文心智能体平台AgenBuilder | 搭建智能体:情感顾问叶晴
人工智能·智能体
丶21364 小时前
【CUDA】【PyTorch】安装 PyTorch 与 CUDA 11.7 的详细步骤
人工智能·pytorch·python
春末的南方城市5 小时前
FLUX的ID保持项目也来了! 字节开源PuLID-FLUX-v0.9.0,开启一致性风格写真新纪元!
人工智能·计算机视觉·stable diffusion·aigc·图像生成