动手学深度学习(pytorch土堆)-04torchvision中数据集的使用

CIFAR10

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每个类有 6000 张图像。有 50000 张训练图像和 10000 张测试图像。

数据集分为 5 个训练批次和 1 个测试批次,每个批次有 10000 张图像。测试批次包含每个类中随机选择的 1000 张图像。训练批次包含按随机顺序排列的剩余图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次包含来自每个类的 5000 张图像。

c 复制代码
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
print(test_set[0])

(<PIL.Image.Image image mode=RGB size=32x32 at 0x1F5B55DD5E0>, 3)

test_set[]存放两个数据,一个是图像本身,一个是标签

图片显示

c 复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]
                                                 )#将图片都转为tensor数据类型
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
print(test_set[0])
writer=SummaryWriter("p10")
for i in range(10):
    img,target=test_set[i]
    writer.add_image("test_set",img,i)
writer.close()

dataloader

参数

dataset (Dataset) -- 从中加载数据的数据集。

batch_size (int, optional) -- 每批要加载的样本数 (默认值:)。1

shuffle (bool, optional) -- 设置为重新洗牌数据 在每个 epoch (默认值: )。TrueFalse

sampler (Sampler 或 Iterable,可选) -- 定义绘制的策略 数据集中的样本。可以是任何已实施的。如果指定,则不得指定。Iterable__len__shuffle

batch_sampler (Sampler 或 Iterable,可选) -- 类似于 ,但 一次返回一批索引。与 、 、 互斥 和。batch_sizeshuffledrop_last

num_workers (int, optional) -- 用于数据的子进程数 装载。 表示数据将在主进程中加载。 (默认:00)

collate_fn (Callable, optional) -- 合并样本列表以形成 小批量的 Tensor 中。当使用 batch loading from 地图样式数据集。

pin_memory (bool, optional) -- 如果 ,数据加载器将复制 Tensor 放入 device/CUDA 固定内存中。如果您的数据元素 是自定义类型,或者您返回的批次是自定义类型, 请参阅下面的示例。Truecollate_fn

drop_last (bool, optional) -- 设置为 以删除最后一个未完成的批次, 如果数据集大小不能被批量大小整除。If 和 数据集的大小不能被批次大小整除,然后是最后一个批次 会更小。(默认:TrueFalseFalse)

timeout (numeric, optional) -- 如果为正数,则为收集批次的超时值 从工人。应始终为非负数。(默认:0)

worker_init_fn (Callable, optional) -- 如果不是 ,则将在每个 worker 子进程,其中 worker id ( int in ) 为 input、seeding 之后和 data loading 之前。(默认:None[0, num_workers - 1]None)

multiprocessing_context (str 或 multiprocessing.context.BaseContext,可选) -- 如果 ,则操作系统的默认多处理上下文将 被使用。(默认:NoneNone)

发电机(Torch.生成器,可选) -- 如果没有,将使用此 RNG 通过 RandomSampler 生成随机索引,并通过 multiprocessing 为 worker 生成。(默认:Nonebase_seedNone)

prefetch_factor (int, optional, keyword-only arg) -- 加载的批次数 由每个 worker 提前完成。 表示总共会有 2 * num_workers 个批次,在所有工作程序中预取。(默认值取决于 在 num_workers 的 Set 值上。如果值 num_workers=0,则默认值为 。 否则,如果 default 的值为 )。2Nonenum_workers > 02

persistent_workers (bool, optional) -- 如果 ,则数据加载器不会关闭 工作程序在 dataset 被使用一次后进行处理。这允许 保持 worker Dataset 实例处于活动状态。(默认:TrueFalse)

pin_memory_device (str, optional) -- 如果设备为 。pin_memorypin_memoryTrue

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
#测试数据集第一张图片
img,target=test_data[0]
#

writer=SummaryWriter("dataloader")
i=0
for data in test_loader:
    imgs,target=data
    writer.add_images("testdata",imgs,i)
    i=i+1
writer.close()
相关推荐
张人玉2 小时前
人工智能——猴子摘香蕉问题
人工智能
草莓屁屁我不吃2 小时前
Siri因ChatGPT-4o升级:我们的个人信息还安全吗?
人工智能·安全·chatgpt·chatgpt-4o
小言从不摸鱼2 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
AI科研视界2 小时前
ChatGPT+2:修订初始AI安全性和超级智能假设
人工智能·chatgpt
霍格沃兹测试开发学社测试人社区2 小时前
人工智能 | 基于ChatGPT开发人工智能服务平台
软件测试·人工智能·测试开发·chatgpt
小R资源3 小时前
3款免费的GPT类工具
人工智能·gpt·chatgpt·ai作画·ai模型·国内免费
artificiali5 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python
酱香编程,风雨兼程6 小时前
深度学习——基础知识
人工智能·深度学习
Lossya6 小时前
【机器学习】参数学习的基本概念以及贝叶斯网络的参数学习和马尔可夫随机场的参数学习
人工智能·学习·机器学习·贝叶斯网络·马尔科夫随机场·参数学习