动手学深度学习(pytorch土堆)-04torchvision中数据集的使用

CIFAR10

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每个类有 6000 张图像。有 50000 张训练图像和 10000 张测试图像。

数据集分为 5 个训练批次和 1 个测试批次,每个批次有 10000 张图像。测试批次包含每个类中随机选择的 1000 张图像。训练批次包含按随机顺序排列的剩余图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次包含来自每个类的 5000 张图像。

c 复制代码
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
print(test_set[0])

(<PIL.Image.Image image mode=RGB size=32x32 at 0x1F5B55DD5E0>, 3)

test_set[]存放两个数据,一个是图像本身,一个是标签

图片显示

c 复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]
                                                 )#将图片都转为tensor数据类型
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
print(test_set[0])
writer=SummaryWriter("p10")
for i in range(10):
    img,target=test_set[i]
    writer.add_image("test_set",img,i)
writer.close()

dataloader

参数

dataset (Dataset) -- 从中加载数据的数据集。

batch_size (int, optional) -- 每批要加载的样本数 (默认值:)。1

shuffle (bool, optional) -- 设置为重新洗牌数据 在每个 epoch (默认值: )。TrueFalse

sampler (Sampler 或 Iterable,可选) -- 定义绘制的策略 数据集中的样本。可以是任何已实施的。如果指定,则不得指定。Iterable__len__shuffle

batch_sampler (Sampler 或 Iterable,可选) -- 类似于 ,但 一次返回一批索引。与 、 、 互斥 和。batch_sizeshuffledrop_last

num_workers (int, optional) -- 用于数据的子进程数 装载。 表示数据将在主进程中加载。 (默认:00)

collate_fn (Callable, optional) -- 合并样本列表以形成 小批量的 Tensor 中。当使用 batch loading from 地图样式数据集。

pin_memory (bool, optional) -- 如果 ,数据加载器将复制 Tensor 放入 device/CUDA 固定内存中。如果您的数据元素 是自定义类型,或者您返回的批次是自定义类型, 请参阅下面的示例。Truecollate_fn

drop_last (bool, optional) -- 设置为 以删除最后一个未完成的批次, 如果数据集大小不能被批量大小整除。If 和 数据集的大小不能被批次大小整除,然后是最后一个批次 会更小。(默认:TrueFalseFalse)

timeout (numeric, optional) -- 如果为正数,则为收集批次的超时值 从工人。应始终为非负数。(默认:0)

worker_init_fn (Callable, optional) -- 如果不是 ,则将在每个 worker 子进程,其中 worker id ( int in ) 为 input、seeding 之后和 data loading 之前。(默认:None[0, num_workers - 1]None)

multiprocessing_context (str 或 multiprocessing.context.BaseContext,可选) -- 如果 ,则操作系统的默认多处理上下文将 被使用。(默认:NoneNone)

发电机(Torch.生成器,可选) -- 如果没有,将使用此 RNG 通过 RandomSampler 生成随机索引,并通过 multiprocessing 为 worker 生成。(默认:Nonebase_seedNone)

prefetch_factor (int, optional, keyword-only arg) -- 加载的批次数 由每个 worker 提前完成。 表示总共会有 2 * num_workers 个批次,在所有工作程序中预取。(默认值取决于 在 num_workers 的 Set 值上。如果值 num_workers=0,则默认值为 。 否则,如果 default 的值为 )。2Nonenum_workers > 02

persistent_workers (bool, optional) -- 如果 ,则数据加载器不会关闭 工作程序在 dataset 被使用一次后进行处理。这允许 保持 worker Dataset 实例处于活动状态。(默认:TrueFalse)

pin_memory_device (str, optional) -- 如果设备为 。pin_memorypin_memoryTrue

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
#测试数据集第一张图片
img,target=test_data[0]
#

writer=SummaryWriter("dataloader")
i=0
for data in test_loader:
    imgs,target=data
    writer.add_images("testdata",imgs,i)
    i=i+1
writer.close()
相关推荐
qq_283720052 分钟前
本地大模型部署全教程:Python 低成本调用开源 AI 模型
人工智能·python·开源
金融小师妹2 分钟前
AI多因子定价模型:美元强化与能源约束下 黄金反弹受限弹性解析
深度学习·svn·逻辑回归·能源
胡利光2 分钟前
AI Agent 实战避坑 05|AI 版 TDD:Eval-Driven Development 完全指南
人工智能
米奇妙啊妙11 分钟前
agent 学习 -模拟AI调用工具
人工智能·学习
试剂界的爱马仕19 分钟前
AI学习实现:如何给基金实时估值?
大数据·人工智能·科技·学习·机器学习
笑不语23 分钟前
从共病网络到可解释 AI:同济医院 10 分 SCI 全流程复现(R 语言)
开发语言·人工智能·r语言
xiangzhihong827 分钟前
Claude Code系列教程之Claude Code 基础用法基础用法
人工智能
deephub29 分钟前
2026年的 ReAct Agent架构解析:原生 Tool Calling 与 LangGraph 状态机
人工智能·大语言模型·agent·langgraph
淡海水42 分钟前
【AI模型】概念-Token
人工智能·算法
数智化精益手记局1 小时前
什么是安全生产?解读安全生产的基本方针与核心要求
大数据·运维·人工智能·安全·信息可视化·自动化·精益工程