DataLoader的使用

官方网站进行查看DataLoader

batch_size 的含义

python 复制代码
import torchvision
from torch.utils.data import DataLoader

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10('D:\Pytorch\pythonProject\Transform\dataset', train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=False, num_workers=0, drop_last=False)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape) # torch.Size([3, 32, 32])
print(target) # 3

for data in test_loader:
    imgs, targets = data
    print(imgs.shape) # torch.Size([4, 3, 32, 32]); 4就是batch_size, 3是通道, 32×32是图片大小
    print(targets) # tensor([3, 8, 8, 0]); 4张图片的target
python 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10('D:\Pytorch\pythonProject\Transform\dataset', train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape) # torch.Size([3, 32, 32])
print(target) # 3

writer = SummaryWriter('dataloader')
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape) # torch.Size([4, 3, 32, 32]); 4就是batch_size, 3是通道, 32×32是图片大小
        # print(targets) # tensor([3, 8, 8, 0]); 4张图片的target
        writer.add_images('Epoch: {}'.format(epoch), imgs, step)
        step += 1

writer.close()

shuffle=True 的话,会随机成batch

相关推荐
新缸中之脑1 分钟前
PufferLib高性能强化学习库
人工智能
FS_Marking2 分钟前
短距离网络10G SFP+光模块选型指南
网络·人工智能
行走的小派2 分钟前
本地跑模型+原生开源鸿蒙:拆解香橙派AI手机的12TOPS端侧硬核玩法
人工智能·开源·harmonyos
2501_948114242 分钟前
从 Claude Code 源码泄露看 2026 年 Agent 架构演进与工程化实践
大数据·人工智能·架构
小悟空2 分钟前
[AI生成]Iceberg 更新操作技术调研报告
人工智能
hughnz5 分钟前
断钻具的原因与预防
人工智能·钻井
疯狂成瘾者6 分钟前
增强型大模型代理
python
小李云雾7 分钟前
FastAPI 后端开发:文件上传 + 表单提交
开发语言·python·lua·postman·fastapi
Legend NO248 分钟前
数据资产评估风险识别、分析与管控体系建设
大数据·人工智能·python
llm大模型算法工程师weng13 分钟前
Python敏感词检测方案详解
开发语言·python·c#