DataLoader的使用

官方网站进行查看DataLoader

batch_size 的含义

python 复制代码
import torchvision
from torch.utils.data import DataLoader

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10('D:\Pytorch\pythonProject\Transform\dataset', train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=False, num_workers=0, drop_last=False)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape) # torch.Size([3, 32, 32])
print(target) # 3

for data in test_loader:
    imgs, targets = data
    print(imgs.shape) # torch.Size([4, 3, 32, 32]); 4就是batch_size, 3是通道, 32×32是图片大小
    print(targets) # tensor([3, 8, 8, 0]); 4张图片的target
python 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10('D:\Pytorch\pythonProject\Transform\dataset', train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape) # torch.Size([3, 32, 32])
print(target) # 3

writer = SummaryWriter('dataloader')
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape) # torch.Size([4, 3, 32, 32]); 4就是batch_size, 3是通道, 32×32是图片大小
        # print(targets) # tensor([3, 8, 8, 0]); 4张图片的target
        writer.add_images('Epoch: {}'.format(epoch), imgs, step)
        step += 1

writer.close()

shuffle=True 的话,会随机成batch

相关推荐
Dxy12393102163 分钟前
Python如何使用DrissionPage做自动化:简单入门指南
开发语言·python·自动化
珂朵莉MM5 分钟前
2025年睿抗机器人开发者大赛CAIP-编程技能赛-高职组(国赛)解题报告 | 珂学家
java·开发语言·人工智能·算法·机器人
石去皿5 分钟前
从本地知识库到“活”知识——RAG 落地全景指南
c++·python·大模型·rag
hui函数9 分钟前
Python系列Bug修复PyCharm控制台pip install报错:如何解决 pip install 网络报错 企业网关拦截 User-Agent 问题
python·pycharm·bug
猫头虎9 分钟前
Claude Code 永动机:ralph-loop 无限循环迭代插件详解(安装 / 原理 / 最佳实践 / 避坑)
ide·人工智能·langchain·开源·编辑器·aigc·编程技术
a努力。12 分钟前
虾皮Java面试被问:JVM Native Memory Tracking追踪堆外内存泄漏
java·开发语言·jvm·后端·python·面试
Kratzdisteln12 分钟前
【Python】Flask
开发语言·python·flask
aigcapi13 分钟前
如何让AI推广我的品牌?成长期企业GEO优化的“降本增效”实战指南
人工智能
百***243720 分钟前
GPT-5.2国内调用+API中转+成本管控
大数据·人工智能·深度学习
min18112345627 分钟前
金融风控中的实时行为建模
大数据·人工智能