(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10

  1. 导入相关库
python 复制代码
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
  1. 下载数据集
python 复制代码
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')

# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True

if demo:
    data_dir = d2l.download_extract('cifar10_tiny')
else:
    data_dir = '../data/kaggle/cifar-10/'
  1. 整理数据集
python 复制代码
# 查看数据集
def read_csv_labels(fname):
    """读取'fname'来给标签字典返回一个文件名"""
    with open(fname, 'r') as f:
        lines = f.readlines()[1:]  # readlines(): 每次读文档的一行,以后还需要逐步循环
        tokens = [l.rstrip().split(',') for l in lines]  # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)
        return dict((name, label) for name, label in tokens)

labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values())))  # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

将验证集从原始的训练集钟拆分出来

python 复制代码
# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):
    """将文件复制到目标目录"""
    os.makedirs(target_dir, exist_ok=True)  # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。
    shutil.copy(filename, target_dir)  #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹

def reorg_train_valid(data_dir, labels, valid_ratio):
    """将验证集从原始训练集钟拆分出来"""
    # 训练数据集中样本数量最少的类别中的样本数
    # Counter: 计数器,返回一个字典,键为元素,值为元素个数;
    # .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序
    # [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # 验证集中每个类别的样本数
    n_valid_per_label= max(1, math.floor((n * valid_ratio)))  # math.floor(): 向下取整  math.ceil(): 向上取整
    label_count = {}

    # 遍历原始训练集中的每个样本
    for train_file in os.listdir(os.path.join(data_dir, 'train')):
        label = labels[train_file.split('.')[0]]  # 从文件名中提取标签
        fname = os.path.join(data_dir, 'train', train_file)
        copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))
        # 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))

    return n_valid_per_label

# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):
    """在预测期间整理测试集,以方便读取"""
    # 遍历测试集中的每个样本
    for test_file in os.listdir(os.path.join(data_dir, 'test')):
        # 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
python 复制代码
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, labels, valid_ratio)
    reorg_test(data_dir)
  • 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
  • 将10%的训练样本作为调整超参数的验证集
python 复制代码
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
python 复制代码
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
  1. 图像增广
python 复制代码
transform_train = torchvision.transforms.Compose([
    # 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强
    torchvision.transforms.Resize(40),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # 标准化图像的每个通道 : 消除评估结果中的随机性
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
  1. 加载数据集
python 复制代码
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train
    ) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test
    ) for folder in ['valid', 'test']
]
  1. 定义迭代器,方便快速迭代数据
python 复制代码
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(
    valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(
    test_ds, batch_size, shuffle=False, drop_last=False
)
  1. 定义模型与损失函数
python 复制代码
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():
    num_classes = 10
    net = d2l.resnet18(num_classes, in_channels=3)
    return net
python 复制代码
# 查看网络模型
get_net()
python 复制代码
# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
  1. 定义训练函数
python 复制代码
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0]/ metric[2], metric[1] / metric[2], None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch+1, (None, None, valid_acc))
        scheduler.step()
    measures = (f'train loss {metric[0] / metric[2]:.3f},'
                f'train acc{metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'
                     f'example/sec on {str(devices)}')
  1. 训练模型
    • (数据集太小,导致精度不高)
python 复制代码
import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')
  1. 对测试集进行分类并提交结果
python 复制代码
net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:
    y_hat = net(X.to(devices[0]))
    preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)
相关推荐
励志不掉头发的内向程序员5 分钟前
从零开始的python学习——文件
开发语言·python·学习
l1t12 分钟前
张泽鹏先生手搓的纯ANSI处理UTF-8与美团龙猫调用expat库读取Excel xml对比测试
xml·人工智能·excel·utf8·expat
THMAIL14 分钟前
量化基金从小白到大师 - 金融数据获取大全:从免费API到Tick级数据实战指南
人工智能·python·深度学习·算法·机器学习·金融·kafka
zzywxc78714 分钟前
AI在金融、医疗、教育、制造业等领域的落地案例(含代码、流程图、Prompt示例与图表)
人工智能·spring·机器学习·金融·数据挖掘·prompt·流程图
周末程序猿1 小时前
谈谈Vibe编程(氛围编程)
人工智能
悠哉悠哉愿意1 小时前
【数学建模学习笔记】无监督聚类模型:分层聚类
笔记·python·学习·数学建模
Tiger Z1 小时前
《动手学深度学习v2》学习笔记 | 2.4 微积分 & 2.5 自动微分
pytorch·深度学习·ai
软件算法开发1 小时前
基于LSTM深度学习的网络流量测量算法matlab仿真
深度学习·matlab·lstm·网络流量测量
北冥电磁电子智能2 小时前
江协科技STM32学习笔记补充之004
笔记·科技·学习
水印云2 小时前
AI配音工具哪个好用?7款热门配音软件推荐指南!
人工智能·语音识别