13 计算机视觉-代码详解

13.2 微调

为了防止在训练集上过拟合,有两种办法,第一种是扩大训练集数量,但是需要大量的成本;第二种就是应用迁移学习,将源数据学习到的知识迁移到目标数据集,即在把在源数据训练好的参数和模型(除去输出层)直接复制到目标数据集训练。

python 复制代码
# IPython魔法函数,可以不用执行plt .show()
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

13.2.1 获取数据集

python 复制代码
#@save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i-1][0] for i in range(8)]
# 展示2行8列矩阵的图片,共16张
d2l.show_images(hotdogs+not_hotdogs,2,8,scale=1.5)
python 复制代码
# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# 图像增广
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])

13.2.2 初始化模型

python 复制代码
# 自动下载网上的训练模型
finetune_net = torchvision.models.resnet18(pretrained=True)
# 输入张量的形状还是源输入张量大小,输入张量大小改为2
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

13.2.3 微调模型

python 复制代码
# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
# 如果param_group=False,输出层中模型参数为随机值
# 训练模型
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        # params_1x的参数使用learning_rate学习率, net.fc.parameters()的参数使用0.001的学习率
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)
train_fine_tuning(finetune_net, 5e-5)

13.3 目标检测和边界框

有时候不仅要识别图像的类别,还需要识别图像的位置。在计算机视觉中叫做目标识别或者目标检测。这小节是介绍目标检测的深度学习方法。

python 复制代码
%matplotlib inline
import torch
from d2l import torch as d2l
#@save
def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽度,高度)"""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # cx,xy,w,h的维度是n
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    # torch.stack()沿着新维度对张量进行链接。boxes最开始维度是(n,4),axis=-1表示倒数第一个维度
    # torch.stack()将(cx, cy, w, h)的维度n将其沿着倒数第一个维度拼接在一起,又是(n,4)
    boxes = torch.stack((cx, cy, w, h), axis=-1)
    return boxes

#@save
def box_center_to_corner(boxes):
    """从(中间,宽度,高度)转换到(左上,右下)"""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    boxes = torch.stack((x1, y1, x2, y2), axis=-1)
    return boxes
相关推荐
Keep_Trying_Go3 分钟前
LightningCLI教程 + 视频讲解
人工智能·pytorch·语言模型·大模型·多模态·lightning
1***s6323 分钟前
Java语音识别开发
人工智能·语音识别
模型启动机4 分钟前
DeepSeek OCR vs Qwen-3 VL vs Mistral OCR:谁更胜一筹?
人工智能·ai·大模型·ocr·deepseek
Chef_Chen10 分钟前
数据科学每日总结--Day26--数据挖掘
人工智能·数据挖掘
胡琦博客12 分钟前
21天开源鸿蒙训练营|Day1 拒绝环境配置焦虑:AI 辅助下的 OpenHarmony 跨平台环境搭建全实录
人工智能·开源·harmonyos
一泽Eze15 分钟前
飞书没走 AI Coding 路线,它做好了另一种 AI 应用模式
人工智能
大任视点15 分钟前
科技赋能健康未来,守护生命青春活力
大数据·人工智能·科技
光影341523 分钟前
微调检测页面操作
人工智能
虎头金猫35 分钟前
随时随地处理图片文档!Reubah 加cpolar的实用体验
linux·运维·人工智能·python·docker·开源·visual studio
九鼎创展科技1 小时前
九鼎创展发布X3588SCV4核心板,集成LPDDR5内存,提升RK3588S平台性能边界
android·人工智能·嵌入式硬件·硬件工程