pytorch dataloader学习

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

只要是shuffle=True,每次epoch结果的顺序是不一样的,如果想每一次的结果是一样的

如果shuffle=False

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

结果如下

相关推荐
黎宇幻生25 分钟前
Java全栈学习笔记39
java·笔记·学习
ACP广源盛139246256731 小时前
(ACP广源盛)GSV1175---- MIPI/LVDS 转 Type-C/DisplayPort 1.2 转换器产品说明及功能分享
人工智能·音视频
胡耀超1 小时前
隐私计算技术全景:从联邦学习到可信执行环境的实战指南—数据安全——隐私计算 联邦学习 多方安全计算 可信执行环境 差分隐私
人工智能·安全·数据安全·tee·联邦学习·差分隐私·隐私计算
诸葛悠闲2 小时前
XCP协议在以太网上实现的配置
学习
停停的茶3 小时前
深度学习(目标检测)
人工智能·深度学习·目标检测
Y200309163 小时前
基于 CIFAR10 数据集的卷积神经网络(CNN)模型训练与集成学习
人工智能·cnn·集成学习
老兵发新帖3 小时前
主流神经网络快速应用指南
人工智能·深度学习·神经网络
AI量化投资实验室4 小时前
15年122倍,年化43.58%,回撤才20%,Optuna机器学习多目标调参backtrader,附python代码
人工智能·python·机器学习
java_logo4 小时前
vllm-openai Docker 部署手册
运维·人工智能·docker·ai·容器
倔强青铜三4 小时前
苦练Python第67天:光速读取任意行,linecache模块解锁文件处理新姿势
人工智能·python·面试