《动手学深度学习(PyTorch版)》笔记3.5

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过。

Chapter3 Linear Neural Networks

3.5 Image Classification Dataset

复制代码
import torch
import torchvision
import time
import matplotlib.pyplot as plt
import numpy as np 
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

def get_fashion_mnist_labels(labels):#@save
    """返回数据集的文本标签"""
    text_labels=['t-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot']
    return [text_labels[int(i)] for i in labels]

def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):#@save
    """绘制图像列表"""
    figsize=(num_cols*scale,num_rows*scale)
    _,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize)# The _ is a convention in Python to indicate a variable that is not going to be used. In this case, it is used to capture the first return value of subplots, which is the entire figure.
    axes=axes.flatten()
    #enumerate意为"枚举"
    for i ,(ax,img) in enumerate(zip(axes,imgs)):#The enumerate() function is used to get both the index i and the paired values.
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

batch_size=256

def get_dataloader_workers():#@save
    """使用4个子进程来读取数据"""
    """每个子进程会预加载一批数据,并将数据放入一个共享内存区域。当主进程需要数据时,它可以直接从共享内存区域中获取,而不需要等待数据的读取和预处理。这样,主进程可以在处理当前批次的数据时,子进程已经在后台加载下一批数据,从而提高数据加载的效率。"""
    return 4

#定义一个计时器
class Timer:#@save
    def __init__(self) :
        self.times=[]
        self.start()
        
    def start(self):
        """启动计时器"""
        self.tik=time.time()
        
    def stop(self):
        """停止计时器并将时间记录在列表中"""
        self.times.append(time.time()-self.tik)
        return self.times[-1]

    def avg(self):
        """返回平均时间"""
        return sum(self.times)/len(self.times)
    
    def sum(self):
        """返回总时间"""
        return sum(self.times)
    
    def cumsum(self):
        """返回累计时间"""
        return np.array(self.times).cumsum().tolist()
    
def load_fashion_mnist(batch_size,resize=None):#@save
    """下载Fashion-MNIST数据集到内存中"""
    trans = [transforms.ToTensor()]
    #将PIL图像转为tensor格式(32位浮点数),并除以255使得所有像素的数值均为0-1
    #the "transforms" module in PyTorch's torchvision library is used to define a sequence of image transformations
    if resize:
        trans.insert(0,transforms.Resize(resize))
    trans=transforms.Compose(trans)
    #"Compose" transformation allows you to apply a sequence of transformations to an input. The resulting trans object can be applied to images or datasets.
    mnist_train=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=trans,download=True)
    mnist_test=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=trans,download=True)
    #"train=True"表示加载训练集,"train=False"表示加载测试集。
    #print(len(mnist_train),len(mnist_test))
    #每个输入的图像高度和宽度均为28像素,并且是灰度图像,通道数为1,下文将高度为h像素,宽为w像素的图像的的图像的形状记为(h,w)
    #print(mnist_train[0][0].shape)
    #mnist_train[0][0]指的是第一个图像数据的张量,mnist_train[0][1]指的是第一个图像的标签
    return (data.DataLoader(mnist_train,batch_size,shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test,batch_size,shuffle=False,
                            num_workers=get_dataloader_workers()))
    # The DataLoader is responsible for loading batches of data, shuffling the data (for training), and using multiple workers for data loading ("num_workers" parameter).
    # "shuffle=False" for mnist_test ensures that the test data remains in its original order during evaluation.

#X,y=next(iter(data.DataLoader(mnist_train,batch_size=18)))
#show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))
#plt.show()

if __name__ == '__main__':
    train_iter, test_iter = load_fashion_mnist(32, resize=64)
    #timer = Timer()
    #for X, y in train_iter:
    #    continue
    #print(f'{timer.stop():.2f} sec')
    for X,y in train_iter:
        print(X.shape,X.dtype,y.shape,y.dtype)#dtype即data type
        break
相关推荐
cici158741 小时前
卡尔曼滤波器实现RBF神经网络训练
人工智能·深度学习·神经网络
U盘失踪了4 小时前
【笔记】Flask 用 session 对象存储用户状态
笔记
QQ2422199794 小时前
基于python+微信小程序的家教管理系统_mh3j9
开发语言·python·微信小程序
Neolnfra4 小时前
拒绝数据“裸奔”!把顶级AI装进自己的硬盘,这款神仙开源工具我粉了
人工智能·开源·蓝耘maas
code_li4 小时前
只花了几分钟,用AI开发了一个微信小程序!(附教程)
人工智能·微信小程序·小程序
飞Link4 小时前
瑞萨联姻 Irida Labs:嵌入式开发者如何玩转“端侧视觉 AI”新范式?
人工智能
RSTJ_16255 小时前
PYTHON+AI LLM DAY THREETY-SEVEN
开发语言·人工智能·python
郝学胜-神的一滴5 小时前
深度学习优化核心:梯度下降与网络训练全解析
数据结构·人工智能·python·深度学习·算法·机器学习
Aision_5 小时前
Agent 为什么需要 Checkpoint?
人工智能·python·gpt·langchain·prompt·aigc·agi
清水白石0085 小时前
《Python性能深潜:从对象分配开销到“小对象风暴”的破解之道(含实战与最佳实践)》
开发语言·python