动手学深度学习(pytorch)学习记录9-图像分类数据集之Fashion-MNIST[学习记录]

注:本代码在jupyter notebook上运行
封面图片来源

Fashion-MNIST是一个广泛使用的图像数据集,主要用于机器学习算法的基准测试,特别是图像分类和识别任务。Fashion-MNIST由德国的时尚科技公司Zalando旗下的研究部门提供。作为MNIST手写数字集的一个直接替代品,旨在提供更具挑战性且更现代的机器学习基准测试数据集。数据集的图像结构简单,但分类难度相比MNIST有所提升,要求模型具备更强的特征提取和模式识别能力。

数据集总共包含70,000张灰度图像,分为60,000张训练图像和10,000张测试图像。其中每张图像都是28x28像素的灰度图像。涵盖了10种不同的衣物类型,包括T恤、裤子、套衫、裙子、外套、凉鞋、汗衫、运动鞋、包和踝靴。

python 复制代码
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
# from d2l import torch as d2l

# d2l.use_svg_display()

Fashion-MNIST数据集下载

python 复制代码
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="data", train=False, transform=trans, download=True)

Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。

python 复制代码
# 获取数据集长度
len(mnist_train), len(mnist_test)

(60000, 10000)

每个输入的灰度图像的高度和宽度均为28像素,通道数为1。

python 复制代码
mnist_train[0][0].shape

torch.Size([1, 28, 28])

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。 以下函数用于在数字标签索引及其文本名称之间进行转换。

python 复制代码
def get_fashion_mnist_labels(labels):  
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

创建一个函数来可视化这些样本

python 复制代码
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): 
    """绘制图像列表""" # 图片、行数、列数
    figsize = (num_cols * scale, num_rows * scale) # 画布尺寸
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy()) # 转化成张量
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False) # 设置x,y轴不可见
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

以下是训练数据集中前几个样本的图像及其相应的标签。

python 复制代码
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

小批量读取

python 复制代码
batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

读取训练数据所需的时间。

python 复制代码
# 定义计时器
import time
import numpy as np
class Timer:
    """记录多次运行时间"""
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """启动计时器"""
        self.tik = time.time()

    def stop(self):
        """停止计时器并将时间记录在列表中"""
        self.times.append(time.time() - self.tik)
        return self.times[-1]# 返回列表最后记录的时间

    def avg(self):
        """返回平均时间"""
        return sum(self.times) / len(self.times)

    def sum(self):
        """返回时间总和"""
        return sum(self.times)

    def cumsum(self):
        """返回累计时间"""
        return np.array(self.times).cumsum().tolist()
python 复制代码
timer = Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/

恳请大佬批评指正。

相关推荐
天水幼麟2 分钟前
python学习笔记(深度学习)
笔记·python·学习
巴里巴气4 分钟前
安装GPU版本的Pytorch
人工智能·pytorch·python
wt_cs25 分钟前
银行回单ocr api集成解析-图像文字识别-文字识别技术
开发语言·python
you45801 小时前
小程序学习笔记:使用 MobX 实现全局数据共享,实例创建、计算属性与 Actions 方法
笔记·学习·小程序
_WndProc1 小时前
【Python】Flask网页
开发语言·python·flask
互联网搬砖老肖1 小时前
Python 中如何使用 Conda 管理版本和创建 Django 项目
python·django·conda
测试者家园1 小时前
基于DeepSeek和crewAI构建测试用例脚本生成器
人工智能·python·测试用例·智能体·智能化测试·crewai
大模型真好玩1 小时前
准确率飙升!Graph RAG如何利用知识图谱提升RAG答案质量(四)——微软GraphRAG代码实战
人工智能·python·mcp
Brookty1 小时前
【MySQL】JDBC编程
java·数据库·后端·学习·mysql·jdbc
前端付豪1 小时前
11、打造自己的 CLI 工具:从命令行到桌面效率神器
后端·python