(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10

  1. 导入相关库
python 复制代码
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
  1. 下载数据集
python 复制代码
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')

# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True

if demo:
    data_dir = d2l.download_extract('cifar10_tiny')
else:
    data_dir = '../data/kaggle/cifar-10/'
  1. 整理数据集
python 复制代码
# 查看数据集
def read_csv_labels(fname):
    """读取'fname'来给标签字典返回一个文件名"""
    with open(fname, 'r') as f:
        lines = f.readlines()[1:]  # readlines(): 每次读文档的一行,以后还需要逐步循环
        tokens = [l.rstrip().split(',') for l in lines]  # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)
        return dict((name, label) for name, label in tokens)

labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values())))  # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

将验证集从原始的训练集钟拆分出来

python 复制代码
# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):
    """将文件复制到目标目录"""
    os.makedirs(target_dir, exist_ok=True)  # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。
    shutil.copy(filename, target_dir)  #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹

def reorg_train_valid(data_dir, labels, valid_ratio):
    """将验证集从原始训练集钟拆分出来"""
    # 训练数据集中样本数量最少的类别中的样本数
    # Counter: 计数器,返回一个字典,键为元素,值为元素个数;
    # .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序
    # [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # 验证集中每个类别的样本数
    n_valid_per_label= max(1, math.floor((n * valid_ratio)))  # math.floor(): 向下取整  math.ceil(): 向上取整
    label_count = {}

    # 遍历原始训练集中的每个样本
    for train_file in os.listdir(os.path.join(data_dir, 'train')):
        label = labels[train_file.split('.')[0]]  # 从文件名中提取标签
        fname = os.path.join(data_dir, 'train', train_file)
        copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))
        # 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))

    return n_valid_per_label

# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):
    """在预测期间整理测试集,以方便读取"""
    # 遍历测试集中的每个样本
    for test_file in os.listdir(os.path.join(data_dir, 'test')):
        # 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
python 复制代码
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, labels, valid_ratio)
    reorg_test(data_dir)
  • 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
  • 将10%的训练样本作为调整超参数的验证集
python 复制代码
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
python 复制代码
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
  1. 图像增广
python 复制代码
transform_train = torchvision.transforms.Compose([
    # 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强
    torchvision.transforms.Resize(40),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # 标准化图像的每个通道 : 消除评估结果中的随机性
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
  1. 加载数据集
python 复制代码
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train
    ) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test
    ) for folder in ['valid', 'test']
]
  1. 定义迭代器,方便快速迭代数据
python 复制代码
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(
    valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(
    test_ds, batch_size, shuffle=False, drop_last=False
)
  1. 定义模型与损失函数
python 复制代码
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():
    num_classes = 10
    net = d2l.resnet18(num_classes, in_channels=3)
    return net
python 复制代码
# 查看网络模型
get_net()
python 复制代码
# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
  1. 定义训练函数
python 复制代码
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0]/ metric[2], metric[1] / metric[2], None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch+1, (None, None, valid_acc))
        scheduler.step()
    measures = (f'train loss {metric[0] / metric[2]:.3f},'
                f'train acc{metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'
                     f'example/sec on {str(devices)}')
  1. 训练模型
    • (数据集太小,导致精度不高)
python 复制代码
import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')
  1. 对测试集进行分类并提交结果
python 复制代码
net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:
    y_hat = net(X.to(devices[0]))
    preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)
相关推荐
FL16238631294 分钟前
[数据集][目标检测]车油口挡板开关闭合检测数据集VOC+YOLO格式138张2类别
人工智能·yolo·目标检测
YesPMP平台官方6 分钟前
AI+教育|拥抱AI智能科技,让课堂更生动高效
人工智能·科技·ai·数据分析·软件开发·教育
李小星同志11 分钟前
高级算法设计与分析 学习笔记6 B树
笔记·学习
霜晨月c23 分钟前
MFC 使用细节
笔记·学习·mfc
FL162386312932 分钟前
AI健身体能测试之基于paddlehub实现引体向上计数个数统计
人工智能
黑客-雨35 分钟前
构建你的AI职业生涯:从基础知识到专业实践的路线图
人工智能·产品经理·ai大模型·ai产品经理·大模型学习·大模型入门·大模型教程
小江湖199436 分钟前
元数据保护者,Caesium压缩不丢重要信息
运维·学习·软件需求·改行学it
子午37 分钟前
动物识别系统Python+卷积神经网络算法+TensorFlow+人工智能+图像识别+计算机毕业设计项目
人工智能·python·cnn
大耳朵爱学习1 小时前
掌握Transformer之注意力为什么有效
人工智能·深度学习·自然语言处理·大模型·llm·transformer·大语言模型
TAICHIFEI1 小时前
目标检测-数据集
人工智能·目标检测·目标跟踪