深度学习23数据增广+微调+实战Kaggle+目标检测+锚框

1. 数据增广

1. 总结

2. 数据增广-----代码

python 复制代码
#让 matplotlib 绘制的图像直接内嵌显示在单元格输出下方,不会弹出独立图片窗口,是深度学习绘图标配。
%matplotlib inline
import torch                # 导入PyTorch深度学习框架,搭建、运行神经网络
import torchvision          # PyTorch配套视觉工具库,处理图像、数据集、预训练模型
from torch import nn        # 从torch单独导入神经网络模块(卷积层、全连接层、激活函数等)
from d2l import torch as d2l # 导入《动手学深度学习》配套工具包,封装了绘图、读图、训练辅助函数

d2l.set_figsize()   #调用 d2l 封装的接口,设置全局画布默认大小
img = d2l.Image.open('01_Data/02_cat.jpg') # 读取图片
d2l.plt.imshow(img) # 显示图片
复制代码
<matplotlib.image.AxesImage at 0x1d31b4730f0>
python 复制代码
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5): # 传入aug图片增广方法
    Y = [aug(img) for _ in range(num_rows * num_cols)] # 用aug方法对图片作用八次
    d2l.show_images(Y, num_rows, num_cols, scale=scale) # 生成结果用num_cols行,num_cols列展示  
python 复制代码
apply(img, torchvision.transforms.RandomHorizontalFlip()) # 水平随机翻转
python 复制代码
apply(img, torchvision.transforms.RandomVerticalFlip()) # 上下随机翻转
python 复制代码
# 随机剪裁,剪裁后的大小为(200,200)
# (0.1,1)使得随即剪裁原始图片的10%到100%区域里的大小,ratio=(0.5,2)使得高宽比为2:1,下面是显示时显示的1:1    
shape_aug = torchvision.transforms.RandomResizedCrop((200,200),scale=(0.1,1),ratio=(0.5,2))     
apply(img,shape_aug)
python 复制代码
# 随即更改图像的亮度                          亮度          对比度       饱和度      色相
apply(img,torchvision.transforms.ColorJitter(brightness=0.5,contrast=0,saturation=0,hue=0))      
python 复制代码
# 随即改变色调
apply(img,torchvision.transforms.ColorJitter(brightness=0,contrast=0,saturation=0,hue=0.5))   
python 复制代码
# 随机更改图像的亮度(brightness)、对比度(constrast)、饱和度(saturation)和色调(hue)
color_aug = torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
apply(img,color_aug)
python 复制代码
# 结合多种图像增广方法
# 先随即水平翻转,再做颜色增广,再做形状增广
augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),color_aug,shape_aug])   
apply(img,augs)
python 复制代码
# 下载图片,并显示部分图片
all_images = torchvision.datasets.CIFAR10(train=True, root='01_Data/03_CIFAR10', download=True)    
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8)
复制代码
Files already downloaded and verified
复制代码
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
      dtype=object)
python 复制代码
# 只使用最简单的随机左右翻转
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()])     
python 复制代码
# 定义一个辅助函数,以便于读取图像和应用图像增广
def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root='01_Data/03_CIFAR10',train=is_train,
                                         transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=is_train,
                                            num_workers = 0)   
    return dataloader
python 复制代码
# 定义一个函数,使用多GPU模式进行训练和评估
def train_batch_ch13(net, X, y, loss, trainer, devices):
    if isinstance(X, list):
        X = [x.to(devices[0]) for x in X] # 如果X是一个list,则把数据一个接一个都挪到devices[0]上    
    else:
        X = X.to(devices[0]) # 如果X不是一个list,则把X挪到devices[0]上
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum

def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0,1],
                           legend=['train loss', 'train acc', 'test acc'])
    # nn.DataParallel使用多GPU
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(net,features,labels,loss,trainer,devices)   
            metric.add(l,acc,labels.shape[0],labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches -1:
                animator.add(
                    epoch + (i + 1) / num_batches,
                    (metric[0] / metric[2], metric[1] / metric[3], None))              
        test_acc = d2l.evaluate_accuracy_gpu(net,test_iter)
        animator.add(epoch+1,(None,None,test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc'
         f' {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f' {metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
         f' {str(devices)}')  
python 复制代码
# 定义train_with_data_aug函数,使用图像增广来训练模型
batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10,3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)
        
net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    # Adam优化器算是一个比较平滑的SGD,它对学习率调参不是很敏感
    trainer = torch.optim.Adam(net.parameters(),lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
    
train_with_data_aug(train_augs, test_augs, net)
复制代码
loss 0.166, train acc 0.942, test acc 0.832
 1013.2 examples/sec on  [device(type='cuda', index=0)]

2. 微调

1. 总结

2. 微调---->代码

**目的:**自动下载热狗 / 非热狗二分类数据集,读取训练集、测试集,批量展示 8 张热狗图片 + 8 张非热狗图片

python 复制代码
%matplotlib inline
import os                 #处理文件路径、文件夹
import torch
import torchvision         #深度学习框架、图像数据集工具
from torch import nn
from d2l import torch as d2l  #《动手学深度学习》配套工具库,自带数据集下载、绘图工具
python 复制代码
#d2l.DATA_HUB 是一个字典,用来注册数据集信息:键名:hotdog(数据集名字)
#第一个值:数据集在线下载地址 hotdog.zip
#第二个值:文件校验码,用来判断文件有没有损坏、是否下载完整
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')    
data_dir = d2l.download_extract('hotdog')#data_dir 变量保存解压后数据集根文件夹路径
# ImageFolder 读取文件夹式数据集
#train_imgs 是整个训练数据集对象,可以用下标取单条样本
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))   
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))#
python 复制代码
# 图片的大小和纵横比各有不同
#列表推导式循环 8 次,得到包含 8 张热狗图片的列表 hotdogs
hotdogs = [train_imgs[i][0] for i in range(8)]
print(train_imgs[0]) # 图片和标签,合为一个元组
print(train_imgs[0][0]) # 元组第一个元素为图片
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
复制代码
(<PIL.Image.Image image mode=RGB size=122x144 at 0x1F2CBDF9AC8>, 0)
<PIL.Image.Image image mode=RGB size=122x144 at 0x1F2CBDF9C18>
复制代码
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
      dtype=object)

整体作用分别定义训练集、测试集两套图像预处理流水线,在喂进神经网络前统一图片格式、做数据增强防过拟合,最后标准化像素值。
执行顺序从上到下:

  1. RandomResizedCrop(224) 随机裁剪 + 缩放:随机在原图抠一块区域,再缩放到 224×224 正方形。 👉 数据增强:让模型学会关注物体不同局部位置,提升泛化能力。
  2. RandomHorizontalFlip() 随机水平翻转:默认 50% 概率左右镜像翻转图片,另一半保持原图不变。 👉 扩充样本多样性,防止模型死记图片朝向。
  3. ToTensor() PIL 图片 → PyTorch 张量;同时把像素值从 0~255 压缩到 0~1。 通道顺序从 (H,W,C) 转为模型要求的 (C,H,W)
  4. normalize 用上面定义的均值方差做标准化。
训练集 train_augs 测试集 test_augs
随机裁剪、随机翻转(数据增强) 居中裁剪、无任何随机操作
目的:扩充数据,防止过拟合 目的:统一输入尺寸,保证预测稳定可复现
python 复制代码
# 数据增广
#Normalize(均值列表, 标准差列表)
normalize = torchvision.transforms.Normalize([0.485,0.456,0.406],
                                            [0.229,0.224,0.225]) # 按该均值、方差做归一化 
#Compose = 把多个变换操作按顺序打包串联执行 

train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize ])  

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize ])

含义拆解

  1. torchvision.models.resnet18:调用 PyTorch 内置的 ResNet18 网络结构(18 层深度残差卷积神经网络,图像分类经典 backbone)
  2. pretrained=True
    • 自动下载在超大 ImageNet 数据集上训练完成的权重参数
      • 不是初始化随机参数,而是加载别人训练好、提取通用视觉特征的成熟权重
    • 该技巧叫迁移学习:用大数据学到的通用视觉能力,适配你自己的热狗二分类小数据集,收敛更快、精度更高

ResNet18 前面一大堆卷积层、残差块,用来提取图片边缘、纹理、轮廓、物体高级特征; 末尾 fc = fully connected 全连接层,是整个网络最后的分类输出头。

python 复制代码
# 定义和初始化模型
pretrained_net = torchvision.models.resnet18(pretrained=True) # 把模型和在ImageNet上定义好的参数拿过来   
pretrained_net.fc # full connection全连接层,最后一层,查看最后一层的输入和输出结构  
复制代码
Linear(in_features=512, out_features=1000, bias=True)

知识点讲解

  1. 前面卷积层是预训练好的成熟权重,不用重新初始化;
  2. 你刚刚新建替换的 fc 层是随机初始值,需要合理初始化;
  3. xavier_uniform_ 是经典权重初始化方案: 让每一层输入、输出的方差尽量稳定,避免训练时梯度爆炸 / 梯度消失,收敛更快更稳;
  4. 注意:只初始化了权重 weight,偏置 bias 默认 PyTorch 自带初始化,不用手动处理
python 复制代码
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features,2) # 最后一层修改为输出类别数为2   
nn.init.xavier_uniform_(finetune_net.fc.weight) # 只对最后一层的weight做随即初始化   
复制代码
Parameter containing:
tensor([[ 0.0004, -0.0395, -0.0163,  ...,  0.0185, -0.0238,  0.0693],
        [ 0.0307,  0.0278,  0.0082,  ..., -0.0852,  0.0642, -0.0302]],
       requires_grad=True)

整体功能:针对迁移学习微调专门设计的训练函数,支持「主干小学习率、分类头大学习率」差异化训练,适配你的热狗二分类任务

参数说明:

  • net:你改造好的 ResNet18 模型 finetune_net

  • learning_rate:基础学习率

  • batch_size=128:批次大小,每次喂入模型 128 张图

  • num_epochs=5:完整遍历整个数据集 5 轮

  • param_group=True:是否开启分组学习率(微调核心技巧)

  • ImageFolder:读取你热狗数据集,自动打标签

  • transform=train_augs / test_augs:自动套用前面写好的数据增强 + 归一化预处理

  • shuffle=True:训练集打乱顺序,防止模型记顺序作弊;测试集不需要打乱

python 复制代码
# 微调座位
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5, param_group=True):
    train_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'),transform=train_augs),
        batch_size = batch_size,shuffle=True)
    test_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'),transform=test_augs),
        batch_size=batch_size) 
#动检测电脑有没有 GPU,有就用 GPU 加速训练,没有自动切 CPU   
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")#交叉熵损失
    if param_group:          #params_lx:收集除最后全连接层 fc 以外所有卷积、残差层参数
        # 除了最后一层的learning rate外,用的是默认的learning rate
        # 最后一层的learning rate用的是十倍的learning rate
        params_lx = [
            param for name, param in net.named_parameters()
            if name not in ["fc.weight","fc.bias"] ]
        trainer = torch.optim.SGD([
            {'params': params_lx}, 
            {'params': net.fc.parameters(), 'lr': learning_rate * 10}],
            lr=learning_rate, weight_decay=0.001) #权重衰减(L2 正则),抑制过拟合
   #全部参数共用同一个学习率,适合全程从头大幅度训练,微调场景一般不用。
     else:
        trainer = torch.optim.SGD(net.parameters(),lr=learning_rate,weight_decay=0.001)
#调用 d2l 封装好的第 13 章训练循环:自动完成前向传播、反向传播、参数更新、打印训练 / 测试准确率、损失变化
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)   
python 复制代码
# 使用较小的学习率
train_fine_tuning(finetune_net,5e-5)
复制代码
loss 0.163, train acc 0.932, test acc 0.935
265.9 examples/sec on [device(type='cuda', index=0)]
python 复制代码
# 为了进行比较,所有模型参数初始化为随机值
scratch_net = torchvision.models.resnet18() # 这里没有pretrained=True,没有拿预训练的参数   
scratch_net.fc = nn.Linear(scratch_net.fc.in_features,2)
train_fine_tuning(scratch_net,5e-4,param_group=False) # param_group=False使得所有层的参数都为默认的学习率   
复制代码
loss 0.349, train acc 0.847, test acc 0.806
422.1 examples/sec on [device(type='cuda', index=0)]

为什么微调要用「两头不同学习率」?💡

  1. 前面卷积层:已经学会识别边缘、纹理、物体轮廓等通用视觉知识,只需要小幅修正适配新数据集,学习率必须很小,步子大容易冲乱原有优质权重
  2. 最后 fc 层:是你全新初始化的,完全不认识热狗分类,需要更快收敛,所以设置10 倍更大学习率加速学习

3. 实战Kaggle比赛图像分类CIFAR10

① 比赛的网址是 Checking your browser - reCAPTCHA

python 复制代码
import collections
import math
import os
import shutil
import pandas as pd
import torch 
import torchvision
from torch import nn
from d2l import torch as d2l
python 复制代码
# 我们提供包含前1000个训练图像和5个随即测试图像的数据集的小规模样本
# cifar10_tiny是cifar10中每一个类把前面一千个训练图片拿出来,测试是每一个类挑五个图片
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                               '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')
demo = True

if demo:
    data_dir = d2l.download_extract('cifar10_tiny')
else:
    data_dir = '../data/cifar-10'
python 复制代码
# 整理数据集
def read_csv_labels(fname):
    """读取 'fname' 来给标签字典返回一个文件名。"""
    with open(fname, 'r') as f:
        lines = f.readlines()[1:] # 一行一行读进来,每一行为列表中一个元素
    tokens = [l.rstrip().split(',') for l in lines] # 遍历列表每一个元素,切分
    return dict(((name, label) for name, label in tokens))

labels = read_csv_labels(os.path.join(data_dir,'trainLabels.csv'))
labels
复制代码
{'1': 'frog',
 '2': 'truck',
 '3': 'truck',
 '4': 'deer',
 '5': 'automobile',
 '6': 'automobile',
 '7': 'bird',
 '8': 'horse',
 '9': 'ship',
 '10': 'cat',
 '11': 'deer',
 '12': 'horse',
 '13': 'horse',
 '14': 'bird',
 '15': 'truck',
 '16': 'truck',
 '17': 'truck',
 '18': 'cat',
 '19': 'bird',
 '20': 'frog',
 '21': 'deer',
 '22': 'cat',
 '23': 'frog',
 '24': 'frog',
 '25': 'bird',
 '26': 'frog',
 '27': 'cat',
 '28': 'dog',
 '29': 'deer',
 '30': 'airplane',
 '31': 'airplane',
 '32': 'truck',
 '33': 'automobile',
 '34': 'cat',
 '35': 'deer',
 '36': 'airplane',
 '37': 'cat',
 '38': 'horse',
 '39': 'cat',
 '40': 'cat',
 '41': 'dog',
 '42': 'bird',
 '43': 'bird',
 '44': 'horse',
 '45': 'automobile',
 '46': 'automobile',
 '47': 'automobile',
 '48': 'bird',
 '49': 'bird',
 '50': 'airplane',
 '51': 'truck',
 '52': 'dog',
 '53': 'horse',
 '54': 'truck',
 '55': 'bird',
 '56': 'bird',
 '57': 'dog',
 '58': 'bird',
 '59': 'deer',
 '60': 'cat',
 '61': 'automobile',
 '62': 'automobile',
 '63': 'ship',
 '64': 'bird',
 '65': 'automobile',
 '66': 'automobile',
 '67': 'deer',
 '68': 'truck',
 '69': 'horse',
 '70': 'ship',
 '71': 'dog',
 '72': 'truck',
 '73': 'frog',
 '74': 'horse',
 '75': 'cat',
 '76': 'automobile',
 '77': 'truck',
 '78': 'airplane',
 '79': 'cat',
 '80': 'automobile',
 '81': 'cat',
 '82': 'dog',
 '83': 'deer',
 '84': 'dog',
 '85': 'horse',
 '86': 'horse',
 '87': 'deer',
 '88': 'horse',
 '89': 'truck',
 '90': 'deer',
 '91': 'bird',
 '92': 'cat',
 '93': 'ship',
 '94': 'airplane',
 '95': 'automobile',
 '96': 'frog',
 '97': 'automobile',
 '98': 'automobile',
 '99': 'deer',
 '100': 'automobile',
 '101': 'ship',
 '102': 'cat',
 '103': 'truck',
 '104': 'frog',
 '105': 'frog',
 '106': 'automobile',
 '107': 'ship',
 '108': 'dog',
 '109': 'bird',
 '110': 'truck',
 '111': 'truck',
 '112': 'ship',
 '113': 'automobile',
 '114': 'horse',
 '115': 'horse',
 '116': 'airplane',
 '117': 'airplane',
 '118': 'frog',
 '119': 'truck',
 '120': 'automobile',
 '121': 'bird',
 '122': 'bird',
 '123': 'truck',
 '124': 'bird',
 '125': 'frog',
 '126': 'frog',
 '127': 'automobile',
 '128': 'truck',
 '129': 'dog',
 '130': 'airplane',
 '131': 'deer',
 '132': 'horse',
 '133': 'frog',
 '134': 'horse',
 '135': 'automobile',
 '136': 'ship',
 '137': 'automobile',
 '138': 'automobile',
 '139': 'bird',
 '140': 'ship',
 '141': 'automobile',
 '142': 'cat',
 '143': 'cat',
 '144': 'frog',
 '145': 'bird',
 '146': 'deer',
 '147': 'truck',
 '148': 'truck',
 '149': 'dog',
 '150': 'deer',
 '151': 'cat',
 '152': 'frog',
 '153': 'horse',
 '154': 'deer',
 '155': 'frog',
 '156': 'ship',
 '157': 'dog',
 '158': 'dog',
 '159': 'deer',
 '160': 'cat',
 '161': 'automobile',
 '162': 'ship',
 '163': 'deer',
 '164': 'horse',
 '165': 'frog',
 '166': 'airplane',
 '167': 'truck',
 '168': 'dog',
 '169': 'automobile',
 '170': 'cat',
 '171': 'ship',
 '172': 'bird',
 '173': 'horse',
 '174': 'dog',
 '175': 'cat',
 '176': 'deer',
 '177': 'automobile',
 '178': 'dog',
 '179': 'horse',
 '180': 'airplane',
 '181': 'deer',
 '182': 'horse',
 '183': 'dog',
 '184': 'dog',
 '185': 'automobile',
 '186': 'airplane',
 '187': 'truck',
 '188': 'frog',
 '189': 'truck',
 '190': 'airplane',
 '191': 'ship',
 '192': 'horse',
 '193': 'ship',
 '194': 'ship',
 '195': 'bird',
 '196': 'dog',
 '197': 'bird',
 '198': 'cat',
 '199': 'dog',
 '200': 'airplane',
 '201': 'frog',
 '202': 'automobile',
 '203': 'truck',
 '204': 'cat',
 '205': 'frog',
 '206': 'truck',
 '207': 'automobile',
 '208': 'cat',
 '209': 'truck',
 '210': 'frog',
 '211': 'frog',
 '212': 'horse',
 '213': 'automobile',
 '214': 'airplane',
 '215': 'truck',
 '216': 'dog',
 '217': 'ship',
 '218': 'dog',
 '219': 'bird',
 '220': 'truck',
 '221': 'airplane',
 '222': 'ship',
 '223': 'ship',
 '224': 'airplane',
 '225': 'frog',
 '226': 'truck',
 '227': 'automobile',
 '228': 'automobile',
 '229': 'frog',
 '230': 'cat',
 '231': 'horse',
 '232': 'frog',
 '233': 'frog',
 '234': 'airplane',
 '235': 'frog',
 '236': 'frog',
 '237': 'automobile',
 '238': 'horse',
 '239': 'automobile',
 '240': 'dog',
 '241': 'ship',
 '242': 'cat',
 '243': 'frog',
 '244': 'frog',
 '245': 'ship',
 '246': 'frog',
 '247': 'ship',
 '248': 'deer',
 '249': 'frog',
 '250': 'frog',
 '251': 'automobile',
 '252': 'cat',
 '253': 'ship',
 '254': 'cat',
 '255': 'deer',
 '256': 'automobile',
 '257': 'horse',
 '258': 'automobile',
 '259': 'cat',
 '260': 'ship',
 '261': 'dog',
 '262': 'automobile',
 '263': 'automobile',
 '264': 'deer',
 '265': 'airplane',
 '266': 'truck',
 '267': 'cat',
 '268': 'horse',
 '269': 'deer',
 '270': 'truck',
 '271': 'truck',
 '272': 'bird',
 '273': 'deer',
 '274': 'truck',
 '275': 'truck',
 '276': 'automobile',
 '277': 'airplane',
 '278': 'dog',
 '279': 'truck',
 '280': 'airplane',
 '281': 'ship',
 '282': 'bird',
 '283': 'automobile',
 '284': 'bird',
 '285': 'airplane',
 '286': 'dog',
 '287': 'frog',
 '288': 'cat',
 '289': 'bird',
 '290': 'horse',
 '291': 'ship',
 '292': 'ship',
 '293': 'frog',
 '294': 'airplane',
 '295': 'horse',
 '296': 'truck',
 '297': 'deer',
 '298': 'dog',
 '299': 'frog',
 '300': 'deer',
 '301': 'bird',
 '302': 'automobile',
 '303': 'automobile',
 '304': 'bird',
 '305': 'automobile',
 '306': 'dog',
 '307': 'truck',
 '308': 'truck',
 '309': 'airplane',
 '310': 'ship',
 '311': 'deer',
 '312': 'automobile',
 '313': 'automobile',
 '314': 'frog',
 '315': 'cat',
 '316': 'cat',
 '317': 'truck',
 '318': 'airplane',
 '319': 'horse',
 '320': 'truck',
 '321': 'horse',
 '322': 'horse',
 '323': 'truck',
 '324': 'automobile',
 '325': 'dog',
 '326': 'automobile',
 '327': 'frog',
 '328': 'frog',
 '329': 'ship',
 '330': 'horse',
 '331': 'automobile',
 '332': 'cat',
 '333': 'airplane',
 '334': 'cat',
 '335': 'cat',
 '336': 'bird',
 '337': 'deer',
 '338': 'dog',
 '339': 'horse',
 '340': 'dog',
 '341': 'truck',
 '342': 'airplane',
 '343': 'cat',
 '344': 'deer',
 '345': 'airplane',
 '346': 'deer',
 '347': 'deer',
 '348': 'frog',
 '349': 'airplane',
 '350': 'airplane',
 '351': 'frog',
 '352': 'frog',
 '353': 'airplane',
 '354': 'ship',
 '355': 'automobile',
 '356': 'frog',
 '357': 'bird',
 '358': 'truck',
 '359': 'bird',
 '360': 'dog',
 '361': 'truck',
 '362': 'frog',
 '363': 'horse',
 '364': 'deer',
 '365': 'automobile',
 '366': 'ship',
 '367': 'horse',
 '368': 'cat',
 '369': 'frog',
 '370': 'truck',
 '371': 'cat',
 '372': 'airplane',
 '373': 'deer',
 '374': 'airplane',
 '375': 'dog',
 '376': 'automobile',
 '377': 'airplane',
 '378': 'cat',
 '379': 'deer',
 '380': 'ship',
 '381': 'dog',
 '382': 'deer',
 '383': 'horse',
 '384': 'bird',
 '385': 'cat',
 '386': 'truck',
 '387': 'horse',
 '388': 'frog',
 '389': 'horse',
 '390': 'automobile',
 '391': 'deer',
 '392': 'horse',
 '393': 'airplane',
 '394': 'automobile',
 '395': 'horse',
 '396': 'cat',
 '397': 'automobile',
 '398': 'ship',
 '399': 'deer',
 '400': 'deer',
 '401': 'bird',
 '402': 'airplane',
 '403': 'bird',
 '404': 'bird',
 '405': 'airplane',
 '406': 'airplane',
 '407': 'truck',
 '408': 'airplane',
 '409': 'truck',
 '410': 'frog',
 '411': 'ship',
 '412': 'bird',
 '413': 'horse',
 '414': 'horse',
 '415': 'deer',
 '416': 'airplane',
 '417': 'cat',
 '418': 'airplane',
 '419': 'ship',
 '420': 'truck',
 '421': 'deer',
 '422': 'bird',
 '423': 'horse',
 '424': 'bird',
 '425': 'dog',
 '426': 'bird',
 '427': 'dog',
 '428': 'automobile',
 '429': 'truck',
 '430': 'deer',
 '431': 'ship',
 '432': 'dog',
 '433': 'automobile',
 '434': 'horse',
 '435': 'deer',
 '436': 'deer',
 '437': 'airplane',
 '438': 'frog',
 '439': 'truck',
 '440': 'airplane',
 '441': 'horse',
 '442': 'ship',
 '443': 'ship',
 '444': 'truck',
 '445': 'truck',
 '446': 'cat',
 '447': 'cat',
 '448': 'deer',
 '449': 'airplane',
 '450': 'deer',
 '451': 'dog',
 '452': 'frog',
 '453': 'frog',
 '454': 'airplane',
 '455': 'automobile',
 '456': 'airplane',
 '457': 'ship',
 '458': 'airplane',
 '459': 'deer',
 '460': 'ship',
 '461': 'ship',
 '462': 'automobile',
 '463': 'dog',
 '464': 'bird',
 '465': 'frog',
 '466': 'ship',
 '467': 'automobile',
 '468': 'airplane',
 '469': 'airplane',
 '470': 'horse',
 '471': 'horse',
 '472': 'dog',
 '473': 'truck',
 '474': 'frog',
 '475': 'bird',
 '476': 'ship',
 '477': 'cat',
 '478': 'deer',
 '479': 'horse',
 '480': 'cat',
 '481': 'truck',
 '482': 'airplane',
 '483': 'automobile',
 '484': 'bird',
 '485': 'deer',
 '486': 'ship',
 '487': 'automobile',
 '488': 'ship',
 '489': 'frog',
 '490': 'deer',
 '491': 'deer',
 '492': 'dog',
 '493': 'horse',
 '494': 'automobile',
 '495': 'cat',
 '496': 'truck',
 '497': 'ship',
 '498': 'airplane',
 '499': 'automobile',
 '500': 'horse',
 '501': 'dog',
 '502': 'ship',
 '503': 'bird',
 '504': 'ship',
 '505': 'airplane',
 '506': 'deer',
 '507': 'automobile',
 '508': 'ship',
 '509': 'truck',
 '510': 'ship',
 '511': 'bird',
 '512': 'truck',
 '513': 'truck',
 '514': 'bird',
 '515': 'horse',
 '516': 'dog',
 '517': 'horse',
 '518': 'cat',
 '519': 'ship',
 '520': 'ship',
 '521': 'deer',
 '522': 'deer',
 '523': 'bird',
 '524': 'horse',
 '525': 'automobile',
 '526': 'frog',
 '527': 'deer',
 '528': 'airplane',
 '529': 'deer',
 '530': 'frog',
 '531': 'truck',
 '532': 'horse',
 '533': 'frog',
 '534': 'bird',
 '535': 'dog',
 '536': 'dog',
 '537': 'automobile',
 '538': 'horse',
 '539': 'bird',
 '540': 'bird',
 '541': 'bird',
 '542': 'truck',
 '543': 'dog',
 '544': 'deer',
 '545': 'bird',
 '546': 'horse',
 '547': 'ship',
 '548': 'automobile',
 '549': 'cat',
 '550': 'deer',
 '551': 'cat',
 '552': 'horse',
 '553': 'frog',
 '554': 'truck',
 '555': 'ship',
 '556': 'airplane',
 '557': 'frog',
 '558': 'airplane',
 '559': 'bird',
 '560': 'bird',
 '561': 'bird',
 '562': 'automobile',
 '563': 'ship',
 '564': 'deer',
 '565': 'airplane',
 '566': 'automobile',
 '567': 'ship',
 '568': 'ship',
 '569': 'automobile',
 '570': 'dog',
 '571': 'horse',
 '572': 'frog',
 '573': 'deer',
 '574': 'dog',
 '575': 'ship',
 '576': 'horse',
 '577': 'automobile',
 '578': 'truck',
 '579': 'automobile',
 '580': 'truck',
 '581': 'ship',
 '582': 'deer',
 '583': 'horse',
 '584': 'cat',
 '585': 'ship',
 '586': 'ship',
 '587': 'bird',
 '588': 'frog',
 '589': 'frog',
 '590': 'horse',
 '591': 'automobile',
 '592': 'frog',
 '593': 'ship',
 '594': 'automobile',
 '595': 'truck',
 '596': 'horse',
 '597': 'ship',
 '598': 'cat',
 '599': 'airplane',
 '600': 'automobile',
 '601': 'airplane',
 '602': 'ship',
 '603': 'ship',
 '604': 'cat',
 '605': 'airplane',
 '606': 'airplane',
 '607': 'automobile',
 '608': 'dog',
 '609': 'airplane',
 '610': 'ship',
 '611': 'ship',
 '612': 'horse',
 '613': 'truck',
 '614': 'truck',
 '615': 'airplane',
 '616': 'truck',
 '617': 'deer',
 '618': 'automobile',
 '619': 'cat',
 '620': 'frog',
 '621': 'frog',
 '622': 'deer',
 '623': 'deer',
 '624': 'horse',
 '625': 'dog',
 '626': 'frog',
 '627': 'airplane',
 '628': 'ship',
 '629': 'airplane',
 '630': 'cat',
 '631': 'bird',
 '632': 'ship',
 '633': 'deer',
 '634': 'frog',
 '635': 'truck',
 '636': 'truck',
 '637': 'horse',
 '638': 'airplane',
 '639': 'cat',
 '640': 'cat',
 '641': 'frog',
 '642': 'horse',
 '643': 'deer',
 '644': 'truck',
 '645': 'automobile',
 '646': 'frog',
 '647': 'bird',
 '648': 'horse',
 '649': 'bird',
 '650': 'bird',
 '651': 'airplane',
 '652': 'frog',
 '653': 'horse',
 '654': 'dog',
 '655': 'horse',
 '656': 'frog',
 '657': 'ship',
 '658': 'truck',
 '659': 'airplane',
 '660': 'truck',
 '661': 'deer',
 '662': 'deer',
 '663': 'horse',
 '664': 'airplane',
 '665': 'truck',
 '666': 'deer',
 '667': 'truck',
 '668': 'frog',
 '669': 'truck',
 '670': 'deer',
 '671': 'dog',
 '672': 'horse',
 '673': 'truck',
 '674': 'bird',
 '675': 'deer',
 '676': 'dog',
 '677': 'automobile',
 '678': 'deer',
 '679': 'cat',
 '680': 'truck',
 '681': 'frog',
 '682': 'dog',
 '683': 'frog',
 '684': 'truck',
 '685': 'cat',
 '686': 'cat',
 '687': 'dog',
 '688': 'airplane',
 '689': 'horse',
 '690': 'bird',
 '691': 'automobile',
 '692': 'cat',
 '693': 'frog',
 '694': 'deer',
 '695': 'airplane',
 '696': 'airplane',
 '697': 'bird',
 '698': 'dog',
 '699': 'airplane',
 '700': 'automobile',
 '701': 'airplane',
 '702': 'bird',
 '703': 'cat',
 '704': 'truck',
 '705': 'ship',
 '706': 'deer',
 '707': 'truck',
 '708': 'ship',
 '709': 'airplane',
 '710': 'bird',
 '711': 'frog',
 '712': 'deer',
 '713': 'deer',
 '714': 'airplane',
 '715': 'automobile',
 '716': 'ship',
 '717': 'ship',
 '718': 'cat',
 '719': 'frog',
 '720': 'truck',
 '721': 'frog',
 '722': 'frog',
 '723': 'horse',
 '724': 'ship',
 '725': 'bird',
 '726': 'deer',
 '727': 'dog',
 '728': 'horse',
 '729': 'frog',
 '730': 'dog',
 '731': 'cat',
 '732': 'airplane',
 '733': 'dog',
 '734': 'airplane',
 '735': 'dog',
 '736': 'airplane',
 '737': 'ship',
 '738': 'bird',
 '739': 'frog',
 '740': 'horse',
 '741': 'cat',
 '742': 'ship',
 '743': 'bird',
 '744': 'automobile',
 '745': 'horse',
 '746': 'frog',
 '747': 'horse',
 '748': 'automobile',
 '749': 'airplane',
 '750': 'truck',
 '751': 'dog',
 '752': 'dog',
 '753': 'airplane',
 '754': 'automobile',
 '755': 'horse',
 '756': 'frog',
 '757': 'truck',
 '758': 'airplane',
 '759': 'deer',
 '760': 'horse',
 '761': 'horse',
 '762': 'automobile',
 '763': 'dog',
 '764': 'truck',
 '765': 'deer',
 '766': 'airplane',
 '767': 'ship',
 '768': 'dog',
 '769': 'truck',
 '770': 'truck',
 '771': 'frog',
 '772': 'horse',
 '773': 'automobile',
 '774': 'ship',
 '775': 'cat',
 '776': 'bird',
 '777': 'cat',
 '778': 'ship',
 '779': 'bird',
 '780': 'bird',
 '781': 'deer',
 '782': 'frog',
 '783': 'airplane',
 '784': 'airplane',
 '785': 'dog',
 '786': 'cat',
 '787': 'ship',
 '788': 'bird',
 '789': 'cat',
 '790': 'horse',
 '791': 'bird',
 '792': 'truck',
 '793': 'cat',
 '794': 'ship',
 '795': 'horse',
 '796': 'ship',
 '797': 'bird',
 '798': 'horse',
 '799': 'truck',
 '800': 'airplane',
 '801': 'bird',
 '802': 'cat',
 '803': 'bird',
 '804': 'bird',
 '805': 'bird',
 '806': 'cat',
 '807': 'cat',
 '808': 'frog',
 '809': 'bird',
 '810': 'cat',
 '811': 'bird',
 '812': 'ship',
 '813': 'airplane',
 '814': 'dog',
 '815': 'dog',
 '816': 'automobile',
 '817': 'deer',
 '818': 'dog',
 '819': 'frog',
 '820': 'frog',
 '821': 'bird',
 '822': 'horse',
 '823': 'airplane',
 '824': 'automobile',
 '825': 'horse',
 '826': 'horse',
 '827': 'ship',
 '828': 'bird',
 '829': 'truck',
 '830': 'bird',
 '831': 'bird',
 '832': 'deer',
 '833': 'bird',
 '834': 'automobile',
 '835': 'automobile',
 '836': 'automobile',
 '837': 'frog',
 '838': 'frog',
 '839': 'frog',
 '840': 'dog',
 '841': 'automobile',
 '842': 'automobile',
 '843': 'horse',
 '844': 'airplane',
 '845': 'deer',
 '846': 'cat',
 '847': 'cat',
 '848': 'horse',
 '849': 'automobile',
 '850': 'bird',
 '851': 'cat',
 '852': 'dog',
 '853': 'dog',
 '854': 'dog',
 '855': 'frog',
 '856': 'automobile',
 '857': 'deer',
 '858': 'cat',
 '859': 'horse',
 '860': 'ship',
 '861': 'ship',
 '862': 'cat',
 '863': 'frog',
 '864': 'frog',
 '865': 'bird',
 '866': 'cat',
 '867': 'airplane',
 '868': 'truck',
 '869': 'deer',
 '870': 'cat',
 '871': 'ship',
 '872': 'airplane',
 '873': 'airplane',
 '874': 'automobile',
 '875': 'automobile',
 '876': 'dog',
 '877': 'deer',
 '878': 'truck',
 '879': 'cat',
 '880': 'automobile',
 '881': 'ship',
 '882': 'truck',
 '883': 'cat',
 '884': 'truck',
 '885': 'truck',
 '886': 'bird',
 '887': 'truck',
 '888': 'deer',
 '889': 'ship',
 '890': 'bird',
 '891': 'truck',
 '892': 'ship',
 '893': 'ship',
 '894': 'automobile',
 '895': 'dog',
 '896': 'cat',
 '897': 'frog',
 '898': 'ship',
 '899': 'horse',
 '900': 'frog',
 '901': 'truck',
 '902': 'ship',
 '903': 'airplane',
 '904': 'frog',
 '905': 'deer',
 '906': 'airplane',
 '907': 'airplane',
 '908': 'bird',
 '909': 'dog',
 '910': 'ship',
 '911': 'bird',
 '912': 'airplane',
 '913': 'bird',
 '914': 'horse',
 '915': 'frog',
 '916': 'truck',
 '917': 'horse',
 '918': 'automobile',
 '919': 'dog',
 '920': 'dog',
 '921': 'frog',
 '922': 'frog',
 '923': 'cat',
 '924': 'frog',
 '925': 'bird',
 '926': 'deer',
 '927': 'horse',
 '928': 'airplane',
 '929': 'dog',
 '930': 'frog',
 '931': 'deer',
 '932': 'frog',
 '933': 'dog',
 '934': 'bird',
 '935': 'deer',
 '936': 'frog',
 '937': 'automobile',
 '938': 'frog',
 '939': 'airplane',
 '940': 'deer',
 '941': 'airplane',
 '942': 'cat',
 '943': 'automobile',
 '944': 'ship',
 '945': 'dog',
 '946': 'deer',
 '947': 'deer',
 '948': 'automobile',
 '949': 'horse',
 '950': 'cat',
 '951': 'truck',
 '952': 'deer',
 '953': 'horse',
 '954': 'truck',
 '955': 'horse',
 '956': 'cat',
 '957': 'horse',
 '958': 'bird',
 '959': 'ship',
 '960': 'deer',
 '961': 'frog',
 '962': 'frog',
 '963': 'automobile',
 '964': 'bird',
 '965': 'truck',
 '966': 'airplane',
 '967': 'deer',
 '968': 'ship',
 '969': 'horse',
 '970': 'cat',
 '971': 'truck',
 '972': 'ship',
 '973': 'horse',
 '974': 'horse',
 '975': 'airplane',
 '976': 'bird',
 '977': 'deer',
 '978': 'automobile',
 '979': 'automobile',
 '980': 'deer',
 '981': 'automobile',
 '982': 'dog',
 '983': 'deer',
 '984': 'airplane',
 '985': 'dog',
 '986': 'frog',
 '987': 'bird',
 '988': 'ship',
 '989': 'dog',
 '990': 'airplane',
 '991': 'bird',
 '992': 'automobile',
 '993': 'cat',
 '994': 'dog',
 '995': 'horse',
 '996': 'cat',
 '997': 'dog',
 '998': 'automobile',
 '999': 'cat',
 '1000': 'dog'}
python 复制代码
# 将验证集从原始的训练集中拆分出来
# train文件夹下有所有train的图片,test文件夹下有所有test图片
# 把train文件夹下所有类的图片创建一个类名文件夹,然后搬到对应文件夹下
def copyfile(filename, target_dir):
    """将文件复制到目标目录"""
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(filename, target_dir)
python 复制代码
def reorg_train_valid(data_dir, labels, valid_ratio):
    n = collections.Counter(labels.values()).most_common()[-1][1]
    n_valid_per_label = max(1,math.floor(n * valid_ratio))
    label_count = {}
    for train_file in os.listdir(os.path.join(data_dir,'train')):
        label = labels[train_file.split('.')[0]]
        fname = os.path.join(data_dir,'train',train_file)
        copyfile(fname,os.path.join(data_dir,'train_valid_test','train_valid',label)) 
        if label not in label_count or label_count[label] < n_valid_per_label:                           
            copyfile(fname,os.path.join(data_dir,'train_valid_test','valid',label))
            label_count[label] = label_count.get(label,0) + 1
        else:
            copyfile(fname,os.path.join(data_dir,'train_valid_test','train',label))  
    return n_valid_per_label
python 复制代码
# 在预测期间整理测试集,以方便读取
def reorg_test(data_dir):
    for test_file in os.listdir(os.path.join(data_dir,'test')):
        copyfile(os.path.join(data_dir,'test',test_file),
                os.path.join(data_dir,'train_valid_test','test','unknown')) # unknown为 test文件夹里面的一个文件夹   
python 复制代码
# 调用前面定义的函数,前面只是定义函数,这里是调用
def reorg_cifar10_data(data_dir,valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir,'trainLabels.csv'))
    reorg_train_valid(data_dir,labels,valid_ratio)
    reorg_test(data_dir)
    
batch_size = 32 if demo else 128
valid_ratio = 0.1 # train 数据里面百分之九十用来训练,剩下百分之十用来验证  
reorg_cifar10_data(data_dir, valid_ratio)
python 复制代码
# 图像增广
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.Resize(40),
    torchvision.transforms.RandomResizedCrop(32,scale=(0.64,1.0),ratio=(1.0,1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.4914,0.4822,0.4465],
                                   [0.2023,0.1994,0.2010]) ])

transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.4914,0.4822,0.4465],
                                    [0.2023,0.1994,0.2010]) ])
python 复制代码
# 读取由原始图像组成的数据集
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir,'train_valid_test',folder),
        transform=transform_train) for folder in ['train','train_valid'] ]

valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir,'train_valid_test',folder),
        transform=transform_test) for folder in ['valid','test'] ]
python 复制代码
# 指定上面定义的所有图像增广操作
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(dataset,batch_size,shuffle=True,drop_last=True)
    for dataset in (train_ds, train_valid_ds) ]

valid_iter = torch.utils.data.DataLoader(valid_ds,batch_size,shuffle=False,drop_last=True)    
test_iter = torch.utils.data.DataLoader(test_ds,batch_size,shuffle=False,drop_last=False)    
python 复制代码
# 模型
def get_net():
    num_classes = 10
    net = d2l.resnet18(num_classes,3) # 3表示数值三通道,彩色图片
    return net

loss = nn.CrossEntropyLoss(reduction="none") # reduction="none" 表示不要把loss加起来sum
python 复制代码
# 训练函数
def train(net, train_iter, valid_iter, num_epoch, lr, wd, devices, lr_period, lr_decay):  # 每隔一段时间的lr_period把学习率lr_decay降低点     
    trainer = torch.optim.SGD(net.parameters(),lr=lr,momentum=0.9,weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay) 
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss','train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=legend)   
    net = nn.DataParallel(net,device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)
        for i,(features,labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net,features,labels,loss,trainer,devices)    
            metric.add(l,acc,labels.shape[0])
            timer.stop()
            if (i+1) % (num_batches // 5) == 0 or i == num_batches -1:
                animator.add(epoch + (i + 1) / num_batches, (metric[0]/metric[2], metric[1]/metric[2],None))      
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net,valid_iter)
            animator.add(epoch+1,(None,None,valid_acc))
        scheduler.step()
    measures = (f'train loss {metric[0] / metric[2]:.3f},'
               f'train acc {metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
         f' examples/sec on {str(devices)}')
python 复制代码
# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)
复制代码
train loss 0.618,train acc 0.790, valid acc 0.359
623.1 examples/sec on [device(type='cuda', index=0)]
python 复制代码
# 对测试集进行分类并提交结果
net, preds = get_net(), []
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)    
for X, _ in test_iter:
    y_hat = net(X.to(devices[0]))
    preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1,len(test_ds)+1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id':sorted_ids,'label':preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv',index=False)
复制代码
train loss 0.560,train acc 0.805
859.2 examples/sec on [device(type='cuda', index=0)]

4. 实战Kaggle比赛狗的品种识别ImageNetDogs

整体项目目标:120 类狗狗图像分类比赛 ,用 ** 迁移学习(预训练 ResNet 冻结微调)** 训练,最终生成 Kaggle 可提交的预测 csv 文件;demo=True用小数据集快速调试,正式跑关闭 demo 用完整数据集。

① 比赛网址是 Checking your browser - reCAPTCHA

python 复制代码
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

配置小数据集下载(demo 调试模式

  • d2l.DATA_HUB:注册数据集地址 + MD5 校验码,防止下载文件损坏
  • demo=True:开启小样例数据集(数据量小、训练快,适合学习调试
  • download_extract:自动下载 zip、自动解压,返回解压后的文件夹路径data_dir
  • demo=False:使用你本地提前下载好的 Kaggle 完整比赛数据集
python 复制代码
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip',
                           '0cb91d09b814ecdc07b50f31f8dcad3e816a86d')

demo = True
if demo:
    data_dir = d2l.download_extract('dog_tiny')
else:
    data_dir = os.path_join('..','data','dog_breed-identification')
复制代码
Downloading ..\data\kaggle_dog_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_dog_tiny.zip...

数据集整理函数 reorg_dog_data

1. 函数形参:data_dir

  • 含义:数据集根目录路径(文件夹地址字符串)

  • 来源:前面代码 data_dir = d2l.download_extract('dog_tiny') 得到

  • 举例:

    复制代码
    'C:/xxx/data/kaggle_dog_tiny'

    这个文件夹内部一开始有:labels.csvtrain文件夹(所有训练图片)、test文件夹(所有测试图片)

2. 函数形参:valid_ratio

  • 英文直译:validation ratio,验证集比例
  • 作用:把原始训练图片拆分两部分:训练集 + 验证集
  • 比如 valid_ratio=0.1 = 拿全部训练图片的 10% 做验证集,剩下 90% 做训练集

3. 函数内部变量:labels

  • 类型:Python 字典 {key: value}

  • key:图片文件名(不带后缀 .jpg

  • value:该图片对应的狗狗品种名字

  • 来源:从 labels.csv 表格读取而来

  • 示例内容:

    复制代码
    {
        '000bec180eb03b78b2436a80fead340f': 'boston_bull',
        '001513dfcb2ffafc82712f7dcca1c318d': 'dingo',
        ...
    }

4. 全局变量:batch_size

  • 含义:批次大小

  • 深度学习不能一次性把几万张图片全部塞进显卡,会显存爆炸;一次拿 batch_size 张图片算一次梯度、更新一次参数

  • 三目表达式解释:

    复制代码
    batch_size = 32 if demo else 128
    • demo=True(小数据集调试):batch_size = 32,一次读 32 张图片
    • demo=False(完整大数据集):batch_size = 128,一次读 128 张图片

5. 全局变量:valid_ratio = 0.1

给上面函数要用到的验证集比例赋值:划分 10% 数据为验证集

二、逐行拆解函数内部每一行

复制代码
def reorg_dog_data(data_dir, valid_ratio):
  1. def 代表定义自定义函数,函数名叫 reorg_dog_data
  2. 函数接收两个输入参数:
    • 参数 1:data_dir 数据集根路径
    • 参数 2:valid_ratio 验证集划分比例
  3. 这个函数整体功能:一键重构、整理整个数据集文件夹结构

第 1 行函数内代码

复制代码
labels = d2l.read_csv_labels(os.path.join(data_dir,'labels.csv'))

os.path.join(a, b)

路径拼接函数,自动拼接文件夹 + 文件名,避免手动写斜杠 / 反斜杠 \ 出错 示例: os.path.join('C:/dogdata', 'labels.csv') → 拼接出完整路径 C:/dogdata/labels.csv

d2l.read_csv_labels(文件路径)

d2l 是李沐动手学深度学习工具库自带函数,专门适配这个狗狗比赛的 csv 文件:

  • 打开 labels.csv 文件
  • csv 内部两列:id(图片名)、breed(品种)
  • 自动解析成字典,赋值给变量 labels

第 2 行函数内代码

复制代码
d2l.reorg_train_valid(data_dir,labels,valid_ratio)

调用 d2l 库内置函数,拆分 + 整理带标签的训练图片,详细过程:

  1. 读取原始 train 文件夹里所有图片
  2. 借助 labels 字典知道每张图片是什么品种
  3. 按照 valid_ratio 比例随机拆分:
    • 90% 图片 → 放到 train_valid_test/train/对应品种文件夹(训练集)
    • 10% 图片 → 放到 train_valid_test/valid/对应品种文件夹(验证集)
    • 全部原始训练图合并一份 → train_valid_test/train_valid/对应品种文件夹(全集,后期用来最终训练)
python 复制代码
# 整理数据集
def reorg_dog_data(data_dir, valid_ratio):
    labels = d2l.read_csv_labels(os.path.join(data_dir,'labels.csv'))
    d2l.reorg_train_valid(data_dir,labels,valid_ratio)
    d2l.reorg_test(data_dir)
    
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

图像预处理 + 图像增广 ,分为训练集变换 transform_train、测试集变换 transform_test

  1. transform 是什么 图片不能直接丢进神经网络,必须统一尺寸、转成张量、归一化;transform 就是图片预处理流水线
  2. 为什么要分 train 和 test 两套 transform?
  • 训练集:要做随机增广(随机裁剪、翻转、调色),增加数据多样性,防止过拟合
  • 测试集:不能带任何随机操作,保证每次预测结果稳定

逐行拆解 transform_train(训练集预处理)

1. 外层容器 Compose

复制代码
transform_train = torchvision.transforms.Compose([操作1,操作2,操作3...])
  • 变量名:transform_train,专门给训练集图片使用的预处理规则
  • Compose([]):把括号里面多个图像处理步骤按顺序串行执行,打包成一个整体变换对象

操作 1:RandomResizedCrop 随机缩放裁剪

python

运行

复制代码
torchvision.transforms.RandomResizedCrop(224,scale=(0.08,1.0),ratio=(3.0/4.0, 4.0/3.0))

逐个参数拆解:

  1. 224:最终裁剪输出图片尺寸 224×224(ResNet 模型固定输入尺寸)
  2. scale=(0.08, 1.0): 先在原图上随机截取一块区域,这块区域面积是原图面积的 8% ~ 100% 截取小区域迫使模型学习局部特征,提升泛化能力
  3. ratio=(3/4, 4/3): 随机截取区域宽高比在 0.75 ~ 1.33 之间,避免裁剪出极端细长图片
  4. 整体流程:随机裁一块区域 → 拉伸缩放到 224×224

操作 2:RandomHorizontalFlip 随机水平翻转

python

运行

复制代码
torchvision.transforms.RandomHorizontalFlip()
  • 默认概率 0.5:每张图片有 50% 概率左右镜像翻转
  • 狗朝左、朝右都是同一个品种,扩充样本多样性,最简单有效的数据增广

操作 3:ColorJitter 随机色彩抖动

python

运行

复制代码
torchvision.transforms.ColorJitter(brightness=0.4,contrast=0.4,saturation=0.4)

随机轻微修改图片属性,模拟拍照光线不同场景:

  • brightness=0.4:亮度随机浮动 ±40%
  • contrast=0.4:对比度随机浮动 ±40%
  • saturation=0.4:饱和度随机浮动 ±40% 不修改色相(色调),避免改变狗狗本身颜色特征。

操作 4:ToTensor () 转张量

python

运行

复制代码
torchvision.transforms.ToTensor()
  1. 原始图片是 PIL 图片,像素范围 [0, 255] 整数
  2. 转为 PyTorch 张量 Tensor,维度 (通道C, 高度H, 宽度W)
  3. 像素值自动除以 255,归一到 [0, 1] 浮点数

维度变化举例:(H,W,C) → (C,H,W),符合卷积网络输入格式要求

操作 5:Normalize 标准化(最重要)

python

运行

复制代码
torchvision.transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])

公式: \(x_{out}=\frac{x-\mu}{\sigma}\)

  • 第一个列表 [0.485,0.456,0.406]:RGB 三通道均值 mean
  • 第二个列表 [0.229,0.224,0.225]:RGB 三通道标准差 std

为什么固定这组数值?

这是 ImageNet 数据集全局统计均值方差,我们要用预训练 ResNet,预训练时人家就是用这个归一化; 前后归一化规则必须一模一样,否则模型提取特征完全错乱,准确率暴跌。

python 复制代码
# 图像增广
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224,scale=(0.08,1.0),ratio=(3.0/4.0, 4.0/3.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.4,contrast=0.4,saturation=0.4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485,0.456,0.406],
                                    [0.229,0.224,0.225])])

transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485,0.456,0.406],
                                     [0.229,0.224,0.225])])
变量名 类型 用途 使用哪个 transform
train_ds Dataset 数据集对象 划分后的训练集 transform_train(带增广)
train_valid_ds Dataset 数据集对象 训练 + 验证全集 transform_train
valid_ds Dataset 数据集对象 验证集(看过拟合) transform_test(无增广)
test_ds Dataset 数据集对象 比赛测试集,用来预测 transform_test
train_iter DataLoader 迭代器 分批读取 train_ds -
train_valid_iter DataLoader 迭代器 分批读取全集 -
valid_iter DataLoader 迭代器 分批读取验证集 -
test_iter DataLoader 迭代器 分批读取测试集 -

创建数据集 Dataset + 数据加载器 DataLoader

python 复制代码
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir,'train_valid_test',folder),
        transform=transform_train) for folder in ['train','train_valid']]       

valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
    os.path.join(data_dir,'train_valid_test',folder),
    transform=transform_test) for folder in ['valid','test']]

train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(dataset,batch_size,shuffle=True,drop_last=True) for dataset in (train_ds, train_valid_ds)]

valid_iter = torch.utils.data.DataLoader(valid_ds,batch_size,shuffle=False,drop_last=True)
test_iter = torch.utils.data.DataLoader(test_ds,batch_size,shuffle=False,drop_last=False)      
python 复制代码
# 微调预训练模型
# 除了最后一层外,前面的层固定住参数不变
def get_net(device):
    finetune_net = nn.Sequential()
    finetune_net.features = torchvision.models.resnet34(pretrained=True)
    print("finetune_net:", finetune_net)
    finetune_net.output_new = nn.Sequential(nn.Linear(1000,256),nn.ReLU(),nn.Linear(256,120)) #  在原始网络后又加了一层 
    print("finetune_net:", finetune_net)
    finetune_net = finetune_net.to(devices[0])
    for param in finetune_net.features.parameters(): # 遍历features的所有参数
        param.requires_grad = False  
    return finetune_net # 返回整个网络,这个网络中原始层的参数固定住了,保持不变
python 复制代码
# 计算损失
loss = nn.CrossEntropyLoss(reduction='none')

def evaluate_loss(data_iter, net, devices):
    l_sum, n = 0.0, 0
    for features, labels in data_iter:
        features, labels = features.to(devices[0]), labels.to(devices[0])
        outputs = net(features)
        l = loss(outputs, labels)
        l_sum += l.sum()
        n += labels.numel()
    return l_sum / n
python 复制代码
# 训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,lr_decay):
    net = nn.DataParallel(net,device_ids=devices).to(devices[0])
    trainer = torch.optim.SGD(
        (param for param in net.parameters() if param.requires_grad), 
        lr = lr, momentum = 0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss']
    if valid_iter is not None:
        legend.append('valid loss')
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=legend)
    
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(2)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            features, labels = features.to(devices[0]), labels.to(devices[0])
            trainer.zero_grad()
            output = net(features)
            l = loss(output, labels).sum()
            l.backward()
            trainer.step()
            metric.add(l,labels.shape[0])
            timer.stop()
            if (i+1) % (num_batches // 5) == 0 or i == num_batches -1:
                animator.add(epoch + (i+1) / num_batches,
                             (metric[0] / metric[1], None))
        measures = f'train loss {metric[0] / metric[1]:.3f}'
        if valid_iter is not None:
            valid_loss = evaluate_loss(valid_iter, net, devices)
            animator.add(epoch + 1, (None, valid_loss.detach()))
        scheduler.step()
    if valid_iter is not None:
        measures += f', valid loss {valid_loss:.3f}'
    print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'
         f' examples/sec on {str(devices)}')
python 复制代码
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net = 2, 0.9, get_net(devices)
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)  
复制代码
train loss 1.119, valid loss 1.561
569.0 examples/sec on [device(type='cuda', index=0)]
python 复制代码
net = get_net(devices)
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)  
复制代码
train loss 1.068
761.3 examples/sec on [device(type='cuda', index=0)]
python 复制代码
preds = []
for data, label in test_iter:
    # 计算每一个样本对每一类的概率是多少
    output = torch.nn.functional.softmax(net(data.to(devices[0])), dim=0)
    preds.extend(output.cpu().detach().numpy())
    print(len(preds))
ids = sorted(os.listdir(os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')))    
with open('submission.csv','w') as f:
    f.write('id,' + ','.join(train_valid_ds.classes)+'\n')
    for i, output in zip(ids, preds):
        f.write(i.split('.')[0] + ',' + ','.join([str(num) for num in output]) + '\n')  
复制代码
10

5. 目标检测

1. 总结

2. 目标检测和边界框

python 复制代码
%matplotlib inline
import torch 
from d2l import torch as d2l

d2l.set_figsize()
img = d2l.plt.imread('01_Data/img/catdog.jpg')
d2l.plt.imshow(img)
复制代码
<matplotlib.image.AxesImage at 0x19a816c4a90>
python 复制代码
# 定义在这两种表示之间进行转换的函数
def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽度,高度)"""
    x1, y1, x2, y2 = boxes[:,0], boxes[:,1], boxes[:,2], boxes[:,3]     
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w  = x2 - x1
    h  = y2 - y1
    boxes = torch.stack((cx,cy,w,h),axis = -1)
    return boxes

def box_center_to_corner(boxes):
    """从(中间,宽度,高度)转换到(左上,右下)"""
    cx, cy, w, h = boxes[:,0], boxes[:,1], boxes[:,2], boxes[:,3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w 
    y2 = cy + 0.5 * h
    boxes = torch.stack((x1,y1,x2,y2),axis = -1)
    return boxes

构造猫狗两个边界框,测试坐标互转函数是否写对,来回转换后和原框一致,代表转换逻辑没问题。

python 复制代码
# 定义图像中狗和猫的边界框
dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]  
#把两个框打包成嵌套列表,并把普通 Python 列表,转成 PyTorch 张量(神经网络专用数组)
boxes = torch.tensor((dog_bbox,cat_bbox))
# boxes 转中间表示,再转回来,等于自己
box_center_to_corner(box_corner_to_center(boxes)) == boxes 
复制代码
tensor([[True, True, True, True],
        [True, True, True, True]])

目标:在图片上画两个方框,蓝框圈狗、红框圈猫

fill=False fill = 填充;False = 矩形内部不填充颜色,只画空心边框 如果写 True,整个框内部会被颜色糊住,挡住图片里的猫狗。

  • fig.axes axes = 坐标轴区域,一张图的绘图区域,所有线条、方框都要添加到 axes 上才看得见。
  • patch matplotlib 里,线条、矩形、圆形这类几何图形都叫 patch(图形补丁),add_patch 就是往图上加几何图形。
python 复制代码
# 将边界框在图中画出
def bbox_to_rect(bbox,color):
    return d2l.plt.Rectangle(xy=(bbox[0],bbox[1]),width=bbox[2]-bbox[0],     
                            height=bbox[3] - bbox[1], fill=False,
                            edgecolor=color,linewidth=2)

fig = d2l.plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(dog_bbox,'blue'))
fig.axes.add_patch(bbox_to_rect(cat_bbox,'red'))
复制代码
<matplotlib.patches.Rectangle at 0x19a8178f9b0>

3. 目标检测数据集

整体作用:注册香蕉目标检测数据集的下载地址 + 校验码,为后续自动下载数据集做准备

1、先搞懂 d2l.DATA_HUB 是什么

DATA_HUB 是 d2l 库内置的字典 格式:数据集名字 : (下载网址, MD5校验码) 作用:给数据集起别名,后面调用 d2l.download_extract('banana-detection') 时,程序自动去这个地址下载压缩包

2、拆分括号里两个元素

① 第一个元素:下载地址

plaintext

复制代码
d2l.DATA_URL + 'banana-detection.zip'
  • d2l.DATA_URL 是 d2l 预设好的云端根地址
  • 拼接后 = 香蕉数据集压缩包完整网络下载链接

② 第二个长字符串:MD5 校验码

'5de25c8fce5ccdea9f91267273465dc968d20d72'

  • 作用:校验下载文件完整性
  • 如果下载中途断网、文件损坏,MD5 比对不一致,d2l 会提示文件异常,避免用破损数据集训练报错

3、整行赋值含义

DATA_HUB 字典里新增一条记录:

  • 数据集别名:banana-detection
  • 对应资源:香蕉检测数据集压缩包 + 校验码
python 复制代码
%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l

d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip','5de25c8fce5ccdea9f91267273465dc968d20d72')

自动下载香蕉数据集,读取所有图片 + 每张图里香蕉的边界框标签,最后做格式整理、坐标归一化,返回给后面训练用。

download_extract 作用:

  • 先检查本地有没有下载解压好的数据集
  • 没有就自动下载 zip 包、自动解压
  • 返回解压后的总文件夹路径 ,存到变量 data_dir

**os.join(a,b,c):**自动拼接路径,避免斜杠正反写错

  • .iterrows():pandas 遍历表格每一行的方法
  • 每次循环得到两个变量:
    1. img_name:当前行索引 = 图片名字(比如 0.png1.png
    2. target:当前行剩下全部内容(类别 + x1,y1,x2,y2 五个数字)

返回两个东西:

  1. images:列表,每个元素是单张图片张量
  2. 处理好的标签张量,形状 [N, 1, 5],坐标归一化到 0~1

整段函数完整流程总结

  1. 根据 is_train 判断读训练集还是验证集
  2. 自动下载 / 定位数据集,找到标签 label.csv
  3. 读取表格,图片名和标签一一绑定
  4. 循环:读每张图片存列表、读对应标签存列表
  5. 打印一堆信息调试查看数据结构
  6. 标签转张量、增加维度适配检测网络、坐标除以 256 归一化
  7. 返回图片列表 + 规整好的标签张量

补充小疑问:为什么要 unsqueeze (1) 多一维?

如果形状 [N,5],只能代表 "每张图 5 个数字"; 改成 [N,1,5] 语义清晰:每张图有 1 个物体框,每个框 5 个参数 ,以后拓展到一张图多个物体(比如 3 根香蕉,写成 [N,3,5])代码不用大改,是目标检测通用格式。

python 复制代码
# 读取香蕉检测数据集
def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签"""
#download_extract先检查本地有没有下载解压好的数据集,没有再自动下载
    data_dir = d2l.download_extract('banana-detection')
#csv_fname 就是标签文件 label.csv 的完整路径
    csv_fname = os.path.join(data_dir,
                            'bananas_train' if is_train else 'bananas_val',
                            'label.csv')
    csv_data = pd.read_csv(csv_fname)
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []#创建两个空列表,存图片、存标签
    # 把图片、标号全部读到内存里面
    for img_name, target in csv_data.iterrows():#按行遍历
        #.append(...):把读好的图片放进 images 列表
        images.append(torchvision.io.read_image(os.path.join(data_dir,'bananas_train' if is_train else 'bananas_val','images',f'{img_name}')))
        targets.append(list(target))
    print("len(targets):",len(targets))#打印一共有多少张图片(多少组标签)
    print("len(targets[0]):",len(targets[0]))#打印第一张图片标签长度:固定是 5:[类别, x1, y1, x2, y2]
    #打印第一张标签五个数字分别是什么,直观看到类别、四个坐标
    print("targets[0][0]....targets[0][4]:",targets[0][0], targets[0][1], targets[0][2], targets[0][3], targets[0][4])    
    print("type(targets):",type(targets))    
#把嵌套列表转张量;.unsqueeze(1):在第1维新增一个长度为 1 的维度
print("torch.tensor(targets).unsqueeze(1).shape:",torch.tensor(targets).unsqueeze(1).shape) # unsqueeze函数在指定位置加上维数为一的维度   
#坐标归一化
    print("len(torch.tensor(targets).unsqueeze(1) / 256):", len(torch.tensor(targets).unsqueeze(1) / 256))   
    print("type(torch.tensor(targets).unsqueeze(1) / 256):", type(torch.tensor(targets).unsqueeze(1) / 256))   
    return images, torch.tensor(targets).unsqueeze(1) / 256 # 归一化使得收敛更快
python 复制代码
# 创建一个自定义Dataset实例
class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if is_train else f'validation examples'))   
        
    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])
    
    def __len__(self):
        return len(self.features)
python 复制代码
# 为训练集和测试集返回两个数据加载器实例
def load_data_bananas(batch_size):
    """加载香蕉检测数据集"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                            batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                          batch_size)
    return train_iter, val_iter
python 复制代码
# 读取一个小批量,并打印其中的图像和标签的形状
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
# ([32,1,5]) 中的1是每张图片中有几种类别,这里只有一种香蕉要识别的类别    
# 5是类别标号、框的四个参数
batch[0].shape, batch[1].shape
复制代码
Downloading ..\data\banana-detection.zip from http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip...
len(targets): 1000
len(targets[0]): 5
targets[0][0]....targets[0][4]: 0 104 20 143 58
type(targets): <class 'list'>
torch.tensor(targets).unsqueeze(1).shape: torch.Size([1000, 1, 5])
len(torch.tensor(targets).unsqueeze(1) / 256): 1000
type(torch.tensor(targets).unsqueeze(1) / 256): <class 'torch.Tensor'>
read 1000 training examples
Downloading ..\data\banana-detection.zip from http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip...
len(targets): 100
len(targets[0]): 5
targets[0][0]....targets[0][4]: 0 183 63 241 112
type(targets): <class 'list'>
torch.tensor(targets).unsqueeze(1).shape: torch.Size([100, 1, 5])
len(torch.tensor(targets).unsqueeze(1) / 256): 100
type(torch.tensor(targets).unsqueeze(1) / 256): <class 'torch.Tensor'>
read 100validation examples
复制代码
(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))
python 复制代码
# 示例
# pytorch里permute是改变参数维度的函数,
# Dataset里读的img维度是[batch_size, RGB, h, w],
# 但是plt画图的时候要求是[h, w, RGB],所以要调整一下

# 做图片的时候,一般是会用一个ToTensor()将图片归一化到【0, 1】,这样收敛更快
print("原始图片:\n", batch[0][0])
print("原始图片:\n", (batch[0][0:10].permute(0,2,3,1)))
print("归一化后图片:\n", (batch[0][0:10].permute(0,2,3,1)) / 255 )
imgs = (batch[0][0:10].permute(0,2,3,1)) / 255
#imgs = (batch[0][0:10].permute(0,2,3,1))
# d2l.show_images输入的imgs图片参数是归一化后的图片
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
复制代码
原始图片:
 tensor([[[248., 249., 250.,  ..., 193., 194., 193.],
         [245., 244., 243.,  ..., 195., 197., 196.],
         [243., 243., 241.,  ..., 197., 200., 201.],
         ...,
         [ 17.,  10.,  13.,  ...,  92., 112., 119.],
         [ 19.,  14.,  12.,  ..., 114., 115., 113.],
         [ 13.,  22.,  12.,  ...,  98., 104., 118.]],

        [[252., 253., 252.,  ..., 206., 207., 206.],
         [249., 248., 245.,  ..., 205., 207., 206.],
         [245., 245., 243.,  ..., 206., 209., 210.],
         ...,
         [ 12.,   5.,   8.,  ...,  82., 102., 109.],
         [ 14.,   9.,   7.,  ..., 105., 106., 104.],
         [  8.,  17.,   7.,  ...,  91.,  95., 109.]],

        [[251., 252., 251.,  ..., 215., 216., 215.],
         [248., 247., 244.,  ..., 214., 216., 215.],
         [244., 244., 242.,  ..., 213., 216., 217.],
         ...,
         [  6.,   0.,   2.,  ...,  72.,  92.,  99.],
         [  8.,   3.,   1.,  ...,  96.,  97.,  95.],
         [  2.,  11.,   1.,  ...,  81.,  86., 100.]]])
原始图片:
 tensor([[[[248., 252., 251.],
          [249., 253., 252.],
          [250., 252., 251.],
          ...,
          [193., 206., 215.],
          [194., 207., 216.],
          [193., 206., 215.]],

         [[245., 249., 248.],
          [244., 248., 247.],
          [243., 245., 244.],
          ...,
          [195., 205., 214.],
          [197., 207., 216.],
          [196., 206., 215.]],

         [[243., 245., 244.],
          [243., 245., 244.],
          [241., 243., 242.],
          ...,
          [197., 206., 213.],
          [200., 209., 216.],
          [201., 210., 217.]],

         ...,

         [[ 17.,  12.,   6.],
          [ 10.,   5.,   0.],
          [ 13.,   8.,   2.],
          ...,
          [ 92.,  82.,  72.],
          [112., 102.,  92.],
          [119., 109.,  99.]],

         [[ 19.,  14.,   8.],
          [ 14.,   9.,   3.],
          [ 12.,   7.,   1.],
          ...,
          [114., 105.,  96.],
          [115., 106.,  97.],
          [113., 104.,  95.]],

         [[ 13.,   8.,   2.],
          [ 22.,  17.,  11.],
          [ 12.,   7.,   1.],
          ...,
          [ 98.,  91.,  81.],
          [104.,  95.,  86.],
          [118., 109., 100.]]],


        [[[180., 167., 132.],
          [177., 163., 128.],
          [169., 153., 117.],
          ...,
          [172., 140., 102.],
          [168., 138., 100.],
          [165., 135.,  97.]],

         [[186., 173., 138.],
          [181., 167., 130.],
          [176., 158., 122.],
          ...,
          [171., 139., 101.],
          [170., 138., 100.],
          [166., 136.,  98.]],

         [[187., 173., 136.],
          [181., 167., 128.],
          [175., 157., 119.],
          ...,
          [172., 139., 104.],
          [172., 139., 104.],
          [170., 137., 102.]],

         ...,

         [[173., 148., 118.],
          [146., 121.,  91.],
          [173., 148., 117.],
          ...,
          [182., 151., 131.],
          [138., 106.,  83.],
          [142., 110.,  85.]],

         [[ 80.,  60.,  33.],
          [151., 132., 102.],
          [193., 174., 142.],
          ...,
          [215., 194., 175.],
          [117.,  95.,  71.],
          [139., 118.,  91.]],

         [[129., 113.,  87.],
          [110.,  95.,  66.],
          [119., 102.,  72.],
          ...,
          [141., 124., 104.],
          [164., 146., 122.],
          [181., 164., 136.]]],


        [[[169., 146.,  50.],
          [182., 157.,  64.],
          [187., 160.,  69.],
          ...,
          [ 85.,  68.,  40.],
          [145., 133., 107.],
          [253., 246., 220.]],

         [[162., 139.,  45.],
          [163., 138.,  46.],
          [169., 142.,  55.],
          ...,
          [127., 107.,  80.],
          [157., 145., 119.],
          [249., 242., 216.]],

         [[163., 137.,  50.],
          [160., 134.,  47.],
          [177., 149.,  66.],
          ...,
          [138., 117.,  90.],
          [156., 141., 118.],
          [254., 243., 221.]],

         ...,

         [[ 18.,  19.,  11.],
          [ 11.,  12.,   4.],
          [ 13.,  14.,   6.],
          ...,
          [ 49.,  48.,  17.],
          [ 90.,  88.,  63.],
          [248., 246., 225.]],

         [[ 13.,  16.,   5.],
          [ 11.,  14.,   3.],
          [ 18.,  20.,   9.],
          ...,
          [ 38.,  39.,   8.],
          [ 86.,  85.,  64.],
          [245., 244., 226.]],

         [[ 11.,  15.,   1.],
          [  8.,  12.,   0.],
          [ 18.,  20.,   9.],
          ...,
          [ 35.,  38.,   9.],
          [ 88.,  87.,  67.],
          [249., 247., 232.]]],


        ...,


        [[[158., 108.,  35.],
          [153., 108.,  43.],
          [101.,  67.,  22.],
          ...,
          [129., 125.,  87.],
          [189., 184., 164.],
          [226., 220., 208.]],

         [[164., 115.,  36.],
          [106.,  62.,   0.],
          [107.,  70.,  18.],
          ...,
          [118., 115.,  80.],
          [173., 168., 146.],
          [151., 148., 131.]],

         [[203., 154.,  62.],
          [184., 137.,  55.],
          [109.,  65.,   2.],
          ...,
          [176., 172., 145.],
          [195., 195., 169.],
          [116., 116.,  92.]],

         ...,

         [[ 99.,  47.,  10.],
          [134.,  87.,  57.],
          [ 64.,  27.,   9.],
          ...,
          [201., 140.,  57.],
          [146.,  87.,   7.],
          [167., 108.,  30.]],

         [[ 71.,  28.,   0.],
          [137.,  99.,  50.],
          [ 83.,  53.,  17.],
          ...,
          [214., 153.,  70.],
          [182., 122.,  34.],
          [168., 109.,  17.]],

         [[ 89.,  51.,   0.],
          [170., 135.,  77.],
          [134., 107.,  62.],
          ...,
          [195., 134.,  51.],
          [182., 123.,  31.],
          [209., 151.,  52.]]],


        [[[196., 198.,  97.],
          [178., 180.,  79.],
          [194., 194.,  98.],
          ...,
          [116.,  74.,  34.],
          [ 76.,  42.,   4.],
          [ 61.,  33.,   0.]],

         [[198., 201.,  98.],
          [190., 193.,  90.],
          [191., 193.,  92.],
          ...,
          [108.,  67.,  23.],
          [101.,  68.,  23.],
          [103.,  75.,  28.]],

         [[206., 209., 104.],
          [195., 198.,  91.],
          [181., 185.,  75.],
          ...,
          [123.,  84.,  29.],
          [171., 136.,  80.],
          [177., 144.,  90.]],

         ...,

         [[131., 127.,  64.],
          [130., 129.,  65.],
          [125., 126.,  60.],
          ...,
          [ 93., 112.,  20.],
          [ 93., 110.,  16.],
          [101., 118.,  24.]],

         [[130., 124.,  64.],
          [132., 128.,  65.],
          [126., 126.,  62.],
          ...,
          [ 96., 115.,  23.],
          [100., 119.,  27.],
          [104., 123.,  31.]],

         [[126., 120.,  60.],
          [129., 125.,  64.],
          [127., 126.,  62.],
          ...,
          [108., 127.,  35.],
          [112., 131.,  39.],
          [110., 129.,  37.]]],


        [[[ 57.,  82.,  40.],
          [ 62.,  87.,  45.],
          [ 39.,  65.,  26.],
          ...,
          [244., 253., 232.],
          [ 94., 108.,  83.],
          [133., 149., 122.]],

         [[ 55.,  80.,  38.],
          [ 63.,  88.,  46.],
          [ 33.,  60.,  19.],
          ...,
          [207., 216., 199.],
          [118., 132., 109.],
          [ 57.,  73.,  46.]],

         [[ 41.,  66.,  24.],
          [ 39.,  64.,  22.],
          [ 41.,  66.,  24.],
          ...,
          [235., 241., 231.],
          [ 86.,  99.,  79.],
          [ 48.,  66.,  40.]],

         ...,

         [[ 68.,  90.,  44.],
          [ 63.,  85.,  38.],
          [ 53.,  73.,  22.],
          ...,
          [ 56.,  71.,  42.],
          [ 52.,  65.,  39.],
          [ 37.,  50.,  24.]],

         [[ 46.,  67.,  24.],
          [ 63.,  82.,  37.],
          [ 50.,  67.,  22.],
          ...,
          [ 45.,  59.,  33.],
          [ 52.,  64.,  40.],
          [ 35.,  47.,  23.]],

         [[ 44.,  65.,  24.],
          [ 44.,  62.,  20.],
          [ 55.,  72.,  28.],
          ...,
          [ 49.,  63.,  37.],
          [ 40.,  52.,  30.],
          [ 39.,  51.,  29.]]]])
归一化后图片:
 tensor([[[[0.9725, 0.9882, 0.9843],
          [0.9765, 0.9922, 0.9882],
          [0.9804, 0.9882, 0.9843],
          ...,
          [0.7569, 0.8078, 0.8431],
          [0.7608, 0.8118, 0.8471],
          [0.7569, 0.8078, 0.8431]],

         [[0.9608, 0.9765, 0.9725],
          [0.9569, 0.9725, 0.9686],
          [0.9529, 0.9608, 0.9569],
          ...,
          [0.7647, 0.8039, 0.8392],
          [0.7725, 0.8118, 0.8471],
          [0.7686, 0.8078, 0.8431]],

         [[0.9529, 0.9608, 0.9569],
          [0.9529, 0.9608, 0.9569],
          [0.9451, 0.9529, 0.9490],
          ...,
          [0.7725, 0.8078, 0.8353],
          [0.7843, 0.8196, 0.8471],
          [0.7882, 0.8235, 0.8510]],

         ...,

         [[0.0667, 0.0471, 0.0235],
          [0.0392, 0.0196, 0.0000],
          [0.0510, 0.0314, 0.0078],
          ...,
          [0.3608, 0.3216, 0.2824],
          [0.4392, 0.4000, 0.3608],
          [0.4667, 0.4275, 0.3882]],

         [[0.0745, 0.0549, 0.0314],
          [0.0549, 0.0353, 0.0118],
          [0.0471, 0.0275, 0.0039],
          ...,
          [0.4471, 0.4118, 0.3765],
          [0.4510, 0.4157, 0.3804],
          [0.4431, 0.4078, 0.3725]],

         [[0.0510, 0.0314, 0.0078],
          [0.0863, 0.0667, 0.0431],
          [0.0471, 0.0275, 0.0039],
          ...,
          [0.3843, 0.3569, 0.3176],
          [0.4078, 0.3725, 0.3373],
          [0.4627, 0.4275, 0.3922]]],


        [[[0.7059, 0.6549, 0.5176],
          [0.6941, 0.6392, 0.5020],
          [0.6627, 0.6000, 0.4588],
          ...,
          [0.6745, 0.5490, 0.4000],
          [0.6588, 0.5412, 0.3922],
          [0.6471, 0.5294, 0.3804]],

         [[0.7294, 0.6784, 0.5412],
          [0.7098, 0.6549, 0.5098],
          [0.6902, 0.6196, 0.4784],
          ...,
          [0.6706, 0.5451, 0.3961],
          [0.6667, 0.5412, 0.3922],
          [0.6510, 0.5333, 0.3843]],

         [[0.7333, 0.6784, 0.5333],
          [0.7098, 0.6549, 0.5020],
          [0.6863, 0.6157, 0.4667],
          ...,
          [0.6745, 0.5451, 0.4078],
          [0.6745, 0.5451, 0.4078],
          [0.6667, 0.5373, 0.4000]],

         ...,

         [[0.6784, 0.5804, 0.4627],
          [0.5725, 0.4745, 0.3569],
          [0.6784, 0.5804, 0.4588],
          ...,
          [0.7137, 0.5922, 0.5137],
          [0.5412, 0.4157, 0.3255],
          [0.5569, 0.4314, 0.3333]],

         [[0.3137, 0.2353, 0.1294],
          [0.5922, 0.5176, 0.4000],
          [0.7569, 0.6824, 0.5569],
          ...,
          [0.8431, 0.7608, 0.6863],
          [0.4588, 0.3725, 0.2784],
          [0.5451, 0.4627, 0.3569]],

         [[0.5059, 0.4431, 0.3412],
          [0.4314, 0.3725, 0.2588],
          [0.4667, 0.4000, 0.2824],
          ...,
          [0.5529, 0.4863, 0.4078],
          [0.6431, 0.5725, 0.4784],
          [0.7098, 0.6431, 0.5333]]],


        [[[0.6627, 0.5725, 0.1961],
          [0.7137, 0.6157, 0.2510],
          [0.7333, 0.6275, 0.2706],
          ...,
          [0.3333, 0.2667, 0.1569],
          [0.5686, 0.5216, 0.4196],
          [0.9922, 0.9647, 0.8627]],

         [[0.6353, 0.5451, 0.1765],
          [0.6392, 0.5412, 0.1804],
          [0.6627, 0.5569, 0.2157],
          ...,
          [0.4980, 0.4196, 0.3137],
          [0.6157, 0.5686, 0.4667],
          [0.9765, 0.9490, 0.8471]],

         [[0.6392, 0.5373, 0.1961],
          [0.6275, 0.5255, 0.1843],
          [0.6941, 0.5843, 0.2588],
          ...,
          [0.5412, 0.4588, 0.3529],
          [0.6118, 0.5529, 0.4627],
          [0.9961, 0.9529, 0.8667]],

         ...,

         [[0.0706, 0.0745, 0.0431],
          [0.0431, 0.0471, 0.0157],
          [0.0510, 0.0549, 0.0235],
          ...,
          [0.1922, 0.1882, 0.0667],
          [0.3529, 0.3451, 0.2471],
          [0.9725, 0.9647, 0.8824]],

         [[0.0510, 0.0627, 0.0196],
          [0.0431, 0.0549, 0.0118],
          [0.0706, 0.0784, 0.0353],
          ...,
          [0.1490, 0.1529, 0.0314],
          [0.3373, 0.3333, 0.2510],
          [0.9608, 0.9569, 0.8863]],

         [[0.0431, 0.0588, 0.0039],
          [0.0314, 0.0471, 0.0000],
          [0.0706, 0.0784, 0.0353],
          ...,
          [0.1373, 0.1490, 0.0353],
          [0.3451, 0.3412, 0.2627],
          [0.9765, 0.9686, 0.9098]]],


        ...,


        [[[0.6196, 0.4235, 0.1373],
          [0.6000, 0.4235, 0.1686],
          [0.3961, 0.2627, 0.0863],
          ...,
          [0.5059, 0.4902, 0.3412],
          [0.7412, 0.7216, 0.6431],
          [0.8863, 0.8627, 0.8157]],

         [[0.6431, 0.4510, 0.1412],
          [0.4157, 0.2431, 0.0000],
          [0.4196, 0.2745, 0.0706],
          ...,
          [0.4627, 0.4510, 0.3137],
          [0.6784, 0.6588, 0.5725],
          [0.5922, 0.5804, 0.5137]],

         [[0.7961, 0.6039, 0.2431],
          [0.7216, 0.5373, 0.2157],
          [0.4275, 0.2549, 0.0078],
          ...,
          [0.6902, 0.6745, 0.5686],
          [0.7647, 0.7647, 0.6627],
          [0.4549, 0.4549, 0.3608]],

         ...,

         [[0.3882, 0.1843, 0.0392],
          [0.5255, 0.3412, 0.2235],
          [0.2510, 0.1059, 0.0353],
          ...,
          [0.7882, 0.5490, 0.2235],
          [0.5725, 0.3412, 0.0275],
          [0.6549, 0.4235, 0.1176]],

         [[0.2784, 0.1098, 0.0000],
          [0.5373, 0.3882, 0.1961],
          [0.3255, 0.2078, 0.0667],
          ...,
          [0.8392, 0.6000, 0.2745],
          [0.7137, 0.4784, 0.1333],
          [0.6588, 0.4275, 0.0667]],

         [[0.3490, 0.2000, 0.0000],
          [0.6667, 0.5294, 0.3020],
          [0.5255, 0.4196, 0.2431],
          ...,
          [0.7647, 0.5255, 0.2000],
          [0.7137, 0.4824, 0.1216],
          [0.8196, 0.5922, 0.2039]]],


        [[[0.7686, 0.7765, 0.3804],
          [0.6980, 0.7059, 0.3098],
          [0.7608, 0.7608, 0.3843],
          ...,
          [0.4549, 0.2902, 0.1333],
          [0.2980, 0.1647, 0.0157],
          [0.2392, 0.1294, 0.0000]],

         [[0.7765, 0.7882, 0.3843],
          [0.7451, 0.7569, 0.3529],
          [0.7490, 0.7569, 0.3608],
          ...,
          [0.4235, 0.2627, 0.0902],
          [0.3961, 0.2667, 0.0902],
          [0.4039, 0.2941, 0.1098]],

         [[0.8078, 0.8196, 0.4078],
          [0.7647, 0.7765, 0.3569],
          [0.7098, 0.7255, 0.2941],
          ...,
          [0.4824, 0.3294, 0.1137],
          [0.6706, 0.5333, 0.3137],
          [0.6941, 0.5647, 0.3529]],

         ...,

         [[0.5137, 0.4980, 0.2510],
          [0.5098, 0.5059, 0.2549],
          [0.4902, 0.4941, 0.2353],
          ...,
          [0.3647, 0.4392, 0.0784],
          [0.3647, 0.4314, 0.0627],
          [0.3961, 0.4627, 0.0941]],

         [[0.5098, 0.4863, 0.2510],
          [0.5176, 0.5020, 0.2549],
          [0.4941, 0.4941, 0.2431],
          ...,
          [0.3765, 0.4510, 0.0902],
          [0.3922, 0.4667, 0.1059],
          [0.4078, 0.4824, 0.1216]],

         [[0.4941, 0.4706, 0.2353],
          [0.5059, 0.4902, 0.2510],
          [0.4980, 0.4941, 0.2431],
          ...,
          [0.4235, 0.4980, 0.1373],
          [0.4392, 0.5137, 0.1529],
          [0.4314, 0.5059, 0.1451]]],


        [[[0.2235, 0.3216, 0.1569],
          [0.2431, 0.3412, 0.1765],
          [0.1529, 0.2549, 0.1020],
          ...,
          [0.9569, 0.9922, 0.9098],
          [0.3686, 0.4235, 0.3255],
          [0.5216, 0.5843, 0.4784]],

         [[0.2157, 0.3137, 0.1490],
          [0.2471, 0.3451, 0.1804],
          [0.1294, 0.2353, 0.0745],
          ...,
          [0.8118, 0.8471, 0.7804],
          [0.4627, 0.5176, 0.4275],
          [0.2235, 0.2863, 0.1804]],

         [[0.1608, 0.2588, 0.0941],
          [0.1529, 0.2510, 0.0863],
          [0.1608, 0.2588, 0.0941],
          ...,
          [0.9216, 0.9451, 0.9059],
          [0.3373, 0.3882, 0.3098],
          [0.1882, 0.2588, 0.1569]],

         ...,

         [[0.2667, 0.3529, 0.1725],
          [0.2471, 0.3333, 0.1490],
          [0.2078, 0.2863, 0.0863],
          ...,
          [0.2196, 0.2784, 0.1647],
          [0.2039, 0.2549, 0.1529],
          [0.1451, 0.1961, 0.0941]],

         [[0.1804, 0.2627, 0.0941],
          [0.2471, 0.3216, 0.1451],
          [0.1961, 0.2627, 0.0863],
          ...,
          [0.1765, 0.2314, 0.1294],
          [0.2039, 0.2510, 0.1569],
          [0.1373, 0.1843, 0.0902]],

         [[0.1725, 0.2549, 0.0941],
          [0.1725, 0.2431, 0.0784],
          [0.2157, 0.2824, 0.1098],
          ...,
          [0.1922, 0.2471, 0.1451],
          [0.1569, 0.2039, 0.1176],
          [0.1529, 0.2000, 0.1137]]]])

b'\r\n\r\n\r\n

6. 锚框

目标检测任务中,框(即锚框或候选框)的生成是基于预定义的一组尺度和比例进行的,通常是在图像的每个位置生成多个框。这些框的位置和尺度是固定的,但在不同的位置可能有不同的大小和形状。

这些框的生成是为了覆盖不同尺度和形状的目标物体。 然后,生成的每个框都会与真实目标框进行匹配,通过计算它们之间的IoU(交并比)来评估它们的相似度。IoU计算量化了两个框之间的重叠程度,可以判断它们是否匹配。

根据IoU的计算结果,可以进行以下判断和处理:

  • 如果某个框与任何一个真实目标框的IoU超过阈值(通常为0.5或0.7),则认为它与一个真实目标框匹配,被标记为正样本。

  • 如果某个框与所有真实目标框的IoU都小于阈值,则认为它与背景不匹配,被标记为负样本(背景样本)。

  • 如果某个框与某个真实目标框的IoU在阈值范围内,但与其他真实目标框的IoU也很接近,则可以将它忽略,不参与训练和评估。

根据这样的匹配和判断过程,可以确定哪些锚框是与真实目标框匹配的正样本,哪些是与背景不匹配的负样本,以及哪些可以被忽略。

通过这种方式,模型可以学习到目标物体的定位和分类。 因此,IoU在目标检测中起到计算相似度和筛选锚框的作用,用于匹配和分类框,以确定模型的训练目标和样本选择。

1. 总结

2. 锚框代码

python 复制代码
%matplotlib inline
import torch
from d2l import torch as d2l
#设置 PyTorch 张量打印输出格式
#传入参数 2 = precision=2,代表:打印浮点数时,统一保留小数点后 2 位
torch.set_printoptions(2)
复制代码
python 复制代码
help(torch.set_printoptions) # 将打印的张量的精度设置为2位小数
复制代码
Help on function set_printoptions in module torch._tensor_str:

set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)
    Set options for printing. Items shamelessly taken from NumPy
    
    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.
复制代码
 Example::
    
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])

① 锚框的宽度和高度分别是和。我们只考虑组合:

② s表示锚框的大小,锚框占图片的百分之多少,r表示锚框的高宽比。

整体流程总概括(从头到尾干了啥)

  1. 拿到特征图 H、W、设备信息,计算单个像素锚框个数
  2. 算出特征图上所有像素归一化中心点坐标
  3. 根据 sizes、ratios 算出每种锚框对应的归一化宽、高
  4. 算出每个锚框相对中心的上下左右偏移
  5. 给每个像素中心点复制多份,匹配对应锚框偏移
  6. 中心坐标 + 偏移 = 所有锚框的左上角、右下角归一化坐标
  7. 扩充 batch 维度返回,给后续目标检测匹配真实框、计算损失使用
python 复制代码
def multibox_prior(data,sizes,ratios):
    """生成以每个像素为中心具有不同高宽度的锚框"""
    # data.shape的最后两个元素为宽和高,第一个元素为通道数  
    in_height, in_width = data.shape[-2:] 
    # 数据对应的设备、锚框占比个数、锚框高宽比个数      
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios) 
    # 计算每个像素点对应的锚框数量  
    boxes_per_pixel = (num_sizes + num_ratios - 1) 
    # 将锚框占比列表转为张量并将其移动到指定设备
    size_tensor = torch.tensor(sizes, device=device) 
    # 将宽高比列表转为张量并将其移动到指定设备
    ratio_tensor = torch.tensor(ratios, device=device) 
    
    # 定义锚框中心偏移量
    offset_h, offset_w = 0.5, 0.5 
    # 计算高度方向上的步长
    steps_h = 1.0 / in_height 
    # 计算宽度方向上的步长
    steps_w = 1.0 / in_width 
    
    # torch.arange(in_height, device=device)获得每一行像素
    # (torch.arange(in_height, device=device) + offset_h) 获得每一行像素的中心
    # (torch.arange(in_height, device=device) + offset_h) * steps_h 对每一行像素的中心坐标作归一化处理  
    
    # 生成归一化的高度和宽度方向上的像素点中心坐标
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w   
    # 生成坐标网格
    shift_y, shift_x = torch.meshgrid(center_h, center_w) 
    # 将坐标网格平铺为一维
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1) 
    
    # 计算每个锚框的宽度和高度
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                  sizes[0] * torch.sqrt(ratio_tensor[1:]))) \
                    * in_height / in_width
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                  sizes[0] / torch.sqrt(ratio_tensor[1:])))
    
    # 计算锚框的左上角和右下角坐标(相对于锚框中心的偏移量)
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(in_height * in_width, 1) / 2
    
    # 计算所有锚框的中心坐标,每个像素对应boxes_per_pixel个锚框
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1).repeat_interleave(boxes_per_pixel, dim=0)   
    
    # 通过中心坐标和偏移量计算所有锚框的左上角和右下角坐标
    output = out_grid + anchor_manipulations 
    
    # 增加一个维度并返回结果
    return output.unsqueeze(0) 
  • 读取一张猫狗图片,拿到图片高、宽像素尺寸
  • 构造一个和图片同分辨率的虚拟张量,用来给锚框函数提供尺寸参数
  • 设置 3 种缩放尺度、3 种长宽比,调用 multibox_prior每一个像素中心生成 5 个锚框
  • 输出所有锚框的整体张量形状,直观验证锚框总数量是否符合计算预期
python 复制代码
# 返回锚框变量Y的形状
img = d2l.plt.imread('01_Data/img/catdog.jpg')
print("img.shape:",img.shape) # 高561,宽72,3通道
h, w = img.shape[:2]
print(h,w)

X = torch.rand(size=(1,3,h,w)) # 批量大小为1,3通道
Y = multibox_prior(X, sizes=[0.75,0.5,0.25], ratios=[1,2,0.5]) # 占图片sizes尺寸的大小、高宽比ratios尺寸大小的锚框   
print(Y.shape) # 1 是批量大小,2042040是一张图片生成的锚框数量,4个元素时每个锚框对应的位置
复制代码
img.shape: (561, 728, 3)
561 728
torch.Size([1, 2042040, 4])
python 复制代码
# 访问以(250,250)为中心的第一个锚框
boxes = Y.reshape(h,w,5,4)  # 上面的sizes×sizes=3×3,3+3-1=5,故每个像素为中心生成五个锚框    
boxes[250,250,0,:] # 以250×250为中心的第一个锚框的坐标
复制代码
tensor([0.06, 0.07, 0.63, 0.82])

整体作用:

  1. 定义通用函数 show_bboxes:在图片画布上批量绘制锚框 / 真实目标框,附带文字标签
  2. 选取坐标 (250,250) 像素点对应的全部 5 个锚框,把归一化坐标还原成像素坐标,画在原图上直观查看锚框大小与形态

if labels and len(labels) > i:

  • 判断当前锚框是否配有标签,有就绘制文字
  • 文字颜色自适应:如果框是白色,文字用黑色;其余情况文字白色,保证看得清
  • rect.xy[0], rect.xy[1]:矩形左上角坐标
  • va='center', ha='center':文字居中摆放
  • 外层小色块背景和锚框同色,lw=0 去掉边框

总结 show_bboxes:通用边界框绘图工具,输入画布、一批框坐标、可选标签、可选颜色,自动批量画框 + 打标注。

python 复制代码
# 显示以图像中一个像素为中心的所有锚框
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """显示所有边界框"""
    def _make_list(obj, default_values=None):
        # 如果obj为None,使用默认值;如果obj不是列表或元组,将其转换为列表
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj
    
    # 处理labels,确保其为列表形式
    labels = _make_list(labels) 
    # 处理colors,确保其为列表形式
#绘图颜色,不传则默认五色循环:蓝、绿、红、品红、青
    colors = _make_list(colors, ['b','g','r','m','c']) 
    # 遍历所有边界框
    for i, bbox in enumerate(bboxes): 
        # 选择颜色
        color = colors[i % len(colors)] 
        # 使用边界框和颜色生成矩形框
        rect = d2l.bbox_to_rect(bbox.detach().numpy(),color) 
        # 在图像上添加矩形框
        axes.add_patch(rect) 
        # 如果存在标签
        if labels and len(labels) > i: 
            # 根据边界框的颜色选择标签的颜色
            text_color = 'k' if color == 'w' else 'w' 
            # 在边界框上添加标签
            axes.text(rect.xy[0], rect.xy[1], labels[i], va='center',
                     ha='center', fontsize=9, color=text_color,
                     bbox=dict(facecolor=color, lw=0))
            
# 设置图像大小           
d2l.set_figsize() 
# 创建一个张量来缩放边界框的尺寸
bbox_scale = torch.tensor((w,h,w,h)) 
# 在图像上显示图像
fig = d2l.plt.imshow(img) 
print("fig.axes:",fig.axes)
# 在生成锚框的时候是0-1的值,进行缩放的话就可以省点乘法运算,因为最后输出并不需要显示所有锚框,所以可能会更快一点     
print("boxes[250,250,:,:]:",boxes[250,250,:,:])
print("bbox_scale:", bbox_scale)
print("boxes[250,250,:,:] * bbox_scale:",boxes[250,250,:,:] * bbox_scale)
show_bboxes(fig.axes, boxes[250,250,:,:] * bbox_scale, ['s=0.75, r=1','s=0.5, r=1','s=0.25, r=1','s=0.75,r=2','s=0.75,r=0.5']) # 画出以250×250像素为中心的不同高宽比的五个锚框       
复制代码
fig.axes: AxesSubplot(0.125,0.125;0.775x0.755)
boxes[250,250,:,:]: tensor([[ 0.06,  0.07,  0.63,  0.82],
        [ 0.15,  0.20,  0.54,  0.70],
        [ 0.25,  0.32,  0.44,  0.57],
        [-0.06,  0.18,  0.75,  0.71],
        [ 0.14, -0.08,  0.55,  0.98]])
bbox_scale: tensor([728, 561, 728, 561])
boxes[250,250,:,:] * bbox_scale: tensor([[ 40.13,  40.12, 460.88, 460.87],
        [110.25, 110.25, 390.75, 390.75],
        [180.38, 180.38, 320.62, 320.62],
        [-47.02, 101.74, 548.02, 399.26],
        [101.74, -47.02, 399.26, 548.02]])
python 复制代码
# 交并比(IoU)
def box_iou(boxes1,boxes2):
    """计算两个锚框或边界框列表中成对的交并比"""
    # 定义一个lambda函数,计算一个锚框或边界框的面积
    box_area = lambda boxes: ((boxes[:,2] - boxes[:,0]) *
                             (boxes[:,3] - boxes[:,1]))
    # 计算boxes1中每个框的面积
    areas1 = box_area(boxes1) 
    # 计算boxes2中每个框的面积
    areas2 = box_area(boxes2) 
    # 计算交集区域的左上角坐标(对于每对框,取其左上角坐标的最大值)
    inter_upperlefts = torch.max(boxes1[:,None,:2],boxes2[:,:2]) 
    # 计算交集区域的右下角坐标(对于每对框,取其右下角坐标的最小值)
    inter_lowerrights = torch.min(boxes1[:,None,2:],boxes2[:,2:])
    # 计算交集区域的宽和高(如果交集不存在,宽和高为0)
    inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)
    # 计算交集区域的面积
    inter_areas = inters[:,:,0] * inters[:,:,1] 
    # 计算并集区域的面积(boxes1的面积 + boxes2的面积 - 交集的面积)
    union_areas = areas1[:,None] + areas2 - inter_areas 
    # 返回交并比(交集的面积除以并集的面积)
    return inter_areas / union_areas
python 复制代码
# 将真实边界框分配给锚框
def assign_anchor_to_bbox(ground_truth,anchors,device,iou_threshold=0.5):
    """将最接近的真实边界框分配给锚框"""
    
    # 获取锚框和真实边界框的数量
    num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0] 
    
    # 计算所有的锚框和真实边缘框的IOU
    jaccard = box_iou(anchors,ground_truth) 
    
    # 创建一个长度为num_anchors的张量,用-1填充,表示锚框到真实边界框的映射(初始时没有分配)
    anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long, device=device)    
    
    # 对于每个锚框,找到与其IoU最大的真实边界框
    max_ious, indices = torch.max(jaccard, dim=1)
    
    # 找到IoU大于等于阈值(如0.5)的锚框,将这些锚框分配给对应的真实边界框
    anc_i = torch.nonzero(max_ious >= 0.5).reshape(-1)
    box_j = indices[max_ious >= 0.5]
    anchors_bbox_map[anc_i] = box_j
    
    # 初始化用于删除行和列的张量
    col_discard = torch.full((num_anchors,),-1)
    row_discard = torch.full((num_gt_boxes,),-1)
    
    # 通过迭代找到IoU最大的锚框,并将其分配给对应的真实边界框
    for _ in range(num_gt_boxes):
        max_idx = torch.argmax(jaccard) # 找IOU最大的锚框
        box_idx = (max_idx % num_gt_boxes).long() # 通过取余数操作,得到该元素对应的真实边界框的索引
        anc_idx = (max_idx / num_gt_boxes).long() # 通过整除操作,得到该元素对应的锚框的索引
        
        # 更新锚框到真实边界框的映射
        anchors_bbox_map[anc_idx] = box_idx
        
        # 在jaccard矩阵中删除已分配的锚框所在的行和列,以避免重复分配
        jaccard[:,box_idx] = col_discard # 把最大Iou对应的锚框在 锚框-类别 矩阵中的一列删掉
        jaccard[anc_idx,:] = row_discard # 把最大Iou对应的锚框在 锚框-类别 矩阵中的一行删掉
    
    #函数返回一个张量anchors_bbox_map,它的长度与锚框的数量相同。
    #这个张量用于存储每个锚框分配到的真实边界框的索引。
    #如果某个锚框没有分配到真实边界框,那么在这个张量中对应的位置就会是-1。
    #如果某个锚框分配到了真实边界框,那么在这个张量中对应的位置就会是分配到的真实边界框的索引。  
    #例如,如果我们有5个锚框和3个真实边界框,那么anchors_bbox_map可能会是这样的:[0, -1, 1, 2, -1]。这表示第1个锚框被分配到了第1个真实边界框,第2个锚框没有被分配到真实边界框,第3个锚框被分配到了第2个真实边界框,第4个锚框被分配到了第3个真实边界框,第5个锚框没有被分配到真实边界框。
    return anchors_bbox_map

③ 给定框A和B,中心坐标分别为和,宽度分别为和,高度分别为和。

④ 我们可以将A的偏移量标记为

python 复制代码
def offset_boxes(anchors,assigned_bb,eps=1e-6):
    """对锚框偏移量的转换"""
    # 将锚框从(左上角, 右下角)的形式转换为(中心点, 宽度, 高度)的形式
    c_anc = d2l.box_corner_to_center(anchors) 
    # 将被分配的真实边界框从(左上角, 右下角)的形式转换为(中心点, 宽度, 高度)的形式
    c_assigned_bb = d2l.box_corner_to_center(assigned_bb) 
    # 计算中心点的偏移量,并进行缩放
    offset_xy = 10 * (c_assigned_bb[:,:2] - c_anc[:,:2] / c_anc[:,2:]) 
    # 计算宽度和高度的偏移量,并进行缩放
    offset_wh = 5 * torch.log(eps + c_assigned_bb[:,2:] / c_anc[:,2:]) 
    # 将中心点和宽高的偏移量合并在一起
    offset = torch.cat([offset_xy, offset_wh], axis=1) 
    # 返回计算得到的偏移量
    return offset 
python 复制代码
# 标记锚框的类和偏移量
def multibox_target(anchors, labels):
    """使用真实边界框标记锚框"""
    # 获取批量大小和锚框
    batch_size, anchors = labels.shape[0], anchors.squeeze(0) 
    # 初始化偏移量、掩码和类别标签列表
    batch_offset, batch_mask, batch_class_labels = [], [], [] 
    # 获取设备和锚框数量
    device, num_anchors = anchors.device, anchors.shape[0] 
    # 对于每个样本
    for i in range(batch_size): 
        # 获取该样本的标签
        label = labels[i,:,:] 
        # 将最接近的真实边界框分配给锚框  
        anchors_bbox_map = assign_anchor_to_bbox(label[:,1:],anchors,device) 
        # 生成锚框掩码,用于标记哪些锚框包含目标   
        bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(1,4) 
        # 初始化类别标签 
        class_labels = torch.zeros(num_anchors, dtype=torch.long,device=device) 
        # 初始化被分配的边界框  
        assigned_bb = torch.zeros((num_anchors,4), dtype=torch.float32,device=device) 
        # 获取包含目标的锚框的索引
        indices_true =torch.nonzero(anchors_bbox_map >= 0) 
        # 获取对应的真实边界框的索引
        bb_idx = anchors_bbox_map[indices_true] 
        # 设置包含目标的锚框的类别标签
        class_labels[indices_true] = label[bb_idx,0].long() + 1 
        # 设置被分配的边界框
        assigned_bb[indices_true] = label[bb_idx, 1:] 
        # 计算锚框的偏移量,并通过掩码进行过滤
        offset = offset_boxes(anchors, assigned_bb) * bbox_mask 
        # 将偏移量添加到列表中
        batch_offset.append(offset.reshape(-1)) 
        # 将掩码添加到列表中
        batch_mask.append(bbox_mask.reshape(-1)) 
        # 将类别标签添加到列表中
        batch_class_labels.append(class_labels) 
    # 将所有偏移量堆叠在一起
    bbox_offset = torch.stack(batch_offset) 
    # 将所有掩码堆叠在一起
    bbox_mask = torch.stack(batch_mask) 
    # 将所有类别标签堆叠在一起
    class_labels = torch.stack(batch_class_labels) 
    # 返回每一个锚框到真实标注框的offset偏移
    # bbox_mask为0表示背景锚框,就不用了,为1表示对应真实的物体
    # class_labels为锚框对应类的编号
    # 返回偏移量、掩码和类别标签
    return (bbox_offset, bbox_mask, class_labels) 
python 复制代码
# 在图像中绘制这些地面真相边界框和锚框

# 两个真实边缘框的位置信息
ground_truth = torch.tensor([[0,0.1,0.08,0.52,0.92],
                            [1,0.55,0.2,0.9,0.88]]) # 真实标注框的信息,包括类别标签(0代表狗,1代表猫)和位置信息(归一化的坐标)

# 五个锚框的位置信息
anchors = torch.tensor([[0,0.1,0.2,0.3],[0.15,0.2,0.4,0.4],
                       [0.63,0.05,0.88,0.98],[0.66,0.45,0.8,0.8],
                       [0.57,0.3,0.92,0.9]]) # 锚框的位置信息(归一化的坐标)

fig = d2l.plt.imshow(img)
# 在图像上画出真实的边界框,其中'k'代表黑色     
show_bboxes(fig.axes,ground_truth[:,1:] * bbox_scale, ['dog','cat'],'k')       
# 在图像上画出锚框,标注出锚框的索引号
show_bboxes(fig.axes,anchors * bbox_scale, ['0','1','2','3','4']) 
python 复制代码
# 根据狗和猫的真实边界框,标注这些锚框的分类和偏移量

# anchors.unsqueeze(dim=0)在0号位置加了一个批量维度,该批量维度大小为1
labels = multibox_target(anchors.unsqueeze(dim=0),ground_truth.unsqueeze(dim=0))     
# labels 对应 multibox_target 函数返回的  (bbox_offset, bbox_mask, class_labels)
print(len(labels)) 
# labels[2]有五个锚框 0表示背景、1表示狗、2表示猫 这里3号框被表示为背景是因为被2号框和四号框非极大值抑制了  
print(labels[2]) 
# 锚框是不是对应是真实物体
print(labels[1]) 
# 每一个锚框有四个值,0表示不需要预测,
print(labels[0]) 
复制代码
3
tensor([[0, 1, 2, 0, 2]])
tensor([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1.,
         1., 1.]])
tensor([[-0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, -7.90e+00, -1.00e+01,
          2.59e+00,  7.18e+00, -2.30e+01, -1.38e-01,  1.68e+00, -1.57e+00,
         -0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, -1.40e+01, -4.60e+00,
          4.17e-06,  6.26e-01]])
python 复制代码
# 应用逆偏移变换来返回预测的边界框坐标
def offset_inverse(anchors,offset_preds):
    """根据带有预测偏移量的锚框来预测边界框"""
    # 将锚框从角点表示转换为中心-宽度表示
    anc = d2l.box_corner_to_center(anchors)
    # 利用预测的偏移量和原始锚框,计算预测边界框的中心坐标
    pred_bbox_xy = (offset_preds[:,:2] * anc[:,2:] / 10) + anc[:,:2]
    # 利用预测的偏移量和原始锚框,计算预测边界框的宽度和高度
    pred_bbox_wh = torch.exp(offset_preds[:,2:] / 5) * anc[:, 2:]
    # 将预测边界框的中心坐标和宽高组合在一起,得到预测边界框的中心-宽度表示
    pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)
    # 将预测边界框从中心-宽度表示转换为角点表示
    predicted_bbox = d2l.box_center_to_corner(pred_bbox)
    # 返回预测的边界框
    return predicted_bbox # 将锚框用偏移量进行偏移,得到预测的边界框
python 复制代码
# 以下nms函数按降序对置信度进行排序并返回其索引
def nms(boxes, scores, iou_threshold):
    """对预测边界框的置信度进行排序"""
    # 按照得分降序排列预测边界框的索引
    B = torch.argsort(scores, dim = -1, descending=True)
    # 创建一个空列表,用于存储保留下来的边界框索引
    keep = []
    # 当B中还有元素时,进行循环
    while B.numel()>0: # 直到把所有框都访问过了,再退出循环
        # 取B中得分最高的边界框索引
        i = B[0] # B中的最大值,已经排好序了
        # 将这个边界框索引添加到保留列表中
        keep.append(i)
        # 如果B中只有一个元素,那么结束循环
        if B.numel() == 1: break
        # 计算剩余的边界框与当前得分最高的边界框的IoU(交并比) 
        iou = box_iou(boxes[i,:].reshape(-1,4),
                     boxes[B[1:],:].reshape(-1,4)).reshape(-1)
        # 找到所有与当前得分最高的边界框的IoU不大于阈值的边界框的索引
        inds = torch.nonzero(iou <= iou_threshold).reshape(-1)
        # 保留那些与当前得分最高的边界框的IoU不大于阈值的边界框
        B = B[inds + 1]
    # 返回保留下来的边界框索引
    return torch.tensor(keep, device=boxes.device)
python 复制代码
# 将非极大值抑制应用于预测边界框
def multibox_detection(cls_probs,offset_preds,anchors,nms_threshold=0.5,
                      pos_threshold=0.009999999):
    """使用非极大值抑制来预测边界框"""
    # 获取设备类型和批次大小
    device, batch_size = cls_probs.device, cls_probs.shape[0]
    # 将锚框数据压缩到二维
    anchors = anchors.squeeze(0)
    # 获取类别数量和锚框数量
    num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]
    # 创建一个空列表,用于存储每个批次的预测结果
    out = []
    # 对每个批次进行循环
    for i in range(batch_size): 
        
        # 获取类别概率和预测的偏移量
        cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1,4)  
        
        # 获取最大类别概率和对应的类别id
        conf, class_id = torch.max(cls_prob[1:],0)
        
        # 根据预测的偏移量和锚框得到预测的边界框
        predicted_bb = offset_inverse(anchors,offset_pred) # 把预测框拿出来
        
        # 对预测的边界框进行非极大值抑制,获取保留下来的边界框索引
        keep = nms(predicted_bb, conf, nms_threshold)
        
        # 获取所有的边界框索引
        all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)
        
        # 将保留下来的边界框索引和所有的边界框索引拼接在一起
        combined = torch.cat((keep,all_idx))
        
        # 获取唯一的索引和对应的计数
        uniques, counts = combined.unique(return_counts=True)
        
        # 获取被丢弃的边界框索引
        non_keep = uniques[counts==1]
        
        # 将保留下来的边界框索引和被丢弃的边界框索引按顺序拼接在一起
        all_id_sorted = torch.cat((keep, non_keep))
        
        # 将被丢弃的边界框的类别id设为-1
        class_id[non_keep] = -1
        class_id = class_id[all_id_sorted]
        
        # 根据索引获取对应的类别概率和预测的边界框
        conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted] 
        
        # 找到类别概率低于阈值的边界框索引
        below_min_idx = (conf < pos_threshold)
        
        # 将类别概率低于阈值的边界框的类别id设为-1
        class_id[below_min_idx] = -1
        
        # 将类别概率低于阈值的边界框的类别概率设为1减去原来的值
        conf[below_min_idx] = 1 - conf[below_min_idx]
        
        # 将类别id,类别概率和预测的边界框拼接在一起,作为预测信息
        pred_info = torch.cat((class_id.unsqueeze(1),conf.unsqueeze(1),predicted_bb),dim=1)     
        
        # 将每个批次的预测信息添加到结果列表中
        out.append(pred_info)
        
    # 将结果列表转为张量返回    
    return torch.stack(out)
python 复制代码
# 将上述算法应用到一个带有四个锚框的具体示例中

# 四个锚框的坐标
anchors = torch.tensor([[0.1,0.08,0.52,0.92],[0.08,0.2,0.56,0.95],
                       [0.15,0.3,0.62,0.91],[0.55,0.2,0.9,0.88]])

# 偏移预测值,这里假设预测值全为0,即没有预测偏移
offset_preds = torch.tensor([0] * anchors.numel())
print("offset_preds:", offset_preds) # 打印偏移预测值
print("len(offset_preds):", len(offset_preds)) # 打印偏移预测值的长度

# 类别概率,每一列对应一个锚框,每一行对应一个类别,这里有三个类别:背景、猫、狗
cls_probs = torch.tensor([[0] * 4,  # 背景类别概率
                         [0.9, 0.8, 0.7, 0.1],  # 猫类别概率
                         [0.1, 0.2, 0.3, 0.9]]) # 狗类别概率
print("cls_probs:", cls_probs) # 四个锚框对背景、猫、狗这三个类的预测值,每一列为一个锚框  
复制代码
offset_preds: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
len(offset_preds): 16
cls_probs: tensor([[0.00, 0.00, 0.00, 0.00],
        [0.90, 0.80, 0.70, 0.10],
        [0.10, 0.20, 0.30, 0.90]])
python 复制代码
# 在图像上绘制这些预测边界框和置信度
# 创建一个图像对象,并在图像上显示
fig = d2l.plt.imshow(img)
# 在图像上显示锚框,其中锚框的尺度需要进行转换以适应图像的尺度
# 每个锚框旁边的文本表示该锚框预测为某个类别的置信度
show_bboxes(fig.axes, anchors * bbox_scale, # 没有做NMS时,把四个锚框画出来
           ['dog=0.9','dog=0.8','dog=0.7','cat=0.9'])
python 复制代码
# 使用multibox_detection函数,输入类别预测概率、预测偏移量以及锚框,同时设置非极大值抑制的阈值为0.5
# 注意,这里需要先在输入数据的每个维度上添加一个维度(即批量大小的维度),然后才能传入函数
output = multibox_detection(cls_probs.unsqueeze(dim=0),
                           offset_preds.unsqueeze(dim=0),
                           anchors.unsqueeze(dim=0),nms_threshold=0.5) 
# 打印输出结果,这里的输出结果包含了每个锚框的类别预测、置信度以及经过预测偏移调整后的锚框坐标
# output[0]表示批量中的第一张图片的预测结果
print("output:",output) #output[0]为批量大小中的第一个图片
复制代码
output: tensor([[[ 0.00,  0.90,  0.10,  0.08,  0.52,  0.92],
         [ 1.00,  0.90,  0.55,  0.20,  0.90,  0.88],
         [-1.00,  0.80,  0.08,  0.20,  0.56,  0.95],
         [-1.00,  0.70,  0.15,  0.30,  0.62,  0.91]]])
python 复制代码
# 输出由非极大值抑制保存的最终预测边界框

# 在图像上绘制通过非极大值抑制筛选后的预测边界框
fig = d2l.plt.imshow(img) 

# 输出经过非极大值抑制后的预测结果
print("output[0]:", output[0])

# 遍历预测结果
for i in output[0].detach().numpy(): 
    
    # 输出当前预测结果的详细信息
    print(i)
    
    # 判断如果预测结果的类别为-1,说明这个预测结果表示的是背景或在非极大值抑制中被移除了,所以我们直接跳过这个结果
    if i[0] == -1: # 值-1表示背景或在非极大值抑制中被移除了
        continue
        
    # 打印预测的类别和置信度
    print("int(i[0]):", int(i[0]))  # i[0]=0表示狗,i[0]=1表示猫,即i的第一个元素表示框对应的类别   
    print("str(i[1]):", str(i[1]))  # i的第二元素表示该类别的置信度
    
    # 根据预测的类别和置信度生成标签
    label = ('dog=', 'cat=')[int(i[0])] + str(i[1]) # 取('dog=', 'cat=')元组的第int(i[0]位置与str(i[1])进行拼接             
    print("label:",label)
    
    # 在图像上绘制预测的边界框和标签
    show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)
复制代码
output[0]: tensor([[ 0.00,  0.90,  0.10,  0.08,  0.52,  0.92],
        [ 1.00,  0.90,  0.55,  0.20,  0.90,  0.88],
        [-1.00,  0.80,  0.08,  0.20,  0.56,  0.95],
        [-1.00,  0.70,  0.15,  0.30,  0.62,  0.91]])
[0.         0.9        0.10000001 0.07999998 0.52       0.92      ]
int(i[0]): 0
str(i[1]): 0.9
label: dog=0.9
[1.         0.9        0.5500001  0.20000002 0.9        0.88      ]
int(i[0]): 1
str(i[1]): 0.9
label: cat=0.9
[-1.          0.8         0.07999998  0.19999999  0.56        0.95      ]
[-1.          0.7         0.14999999  0.3         0.62        0.91      ]

7. 树叶分类竞赛技术总结

1. 总结