pytorch dataloader学习

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

只要是shuffle=True,每次epoch结果的顺序是不一样的,如果想每一次的结果是一样的

如果shuffle=False

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

结果如下

相关推荐
我们从未走散1 小时前
JVM学习笔记-----StringTable
jvm·笔记·学习
The Open Group2 小时前
英特尔公司Darren Pulsipher 博士:以架构之力推动政府数字化转型
大数据·人工智能·架构
胡萝卜3.02 小时前
数据结构初阶:排序算法(一)插入排序、选择排序
数据结构·笔记·学习·算法·排序算法·学习方法
Ronin-Lotus2 小时前
深度学习篇---卷积核的权重
人工智能·深度学习
.银河系.2 小时前
8.18 机器学习-决策树(1)
人工智能·决策树·机器学习
敬往事一杯酒哈2 小时前
第7节 神经网络
人工智能·深度学习·神经网络
三掌柜6662 小时前
NVIDIA 技术沙龙探秘:聚焦 Physical AI 专场前沿技术
大数据·人工智能
2502_927161282 小时前
DAY 42 Grad-CAM与Hook函数
人工智能
Hello123网站3 小时前
Flowith-节点式GPT-4 驱动的AI生产力工具
人工智能·ai工具
xinzheng新政3 小时前
纸板制造制胶工艺学习记录4
学习·制造