pytorch dataloader学习

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

只要是shuffle=True,每次epoch结果的顺序是不一样的,如果想每一次的结果是一样的

如果shuffle=False

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

结果如下

相关推荐
天天扭码21 分钟前
从图片到语音:我是如何用两大模型API打造沉浸式英语学习工具的
前端·人工智能·github
SH11HF22 分钟前
小菜狗的云计算之旅,学习了解rsync+sersync实现数据实时同步(详细操作步骤)
学习·云计算
Frank学习路上37 分钟前
【IOS】XCode创建firstapp并运行(成为IOS开发者)
开发语言·学习·ios·cocoa·xcode
张彦峰ZYF1 小时前
从检索到生成:RAG 如何重构大模型的知识边界?
人工智能·ai·aigc
刘海东刘海东1 小时前
结构型智能科技的关键可行性——信息型智能向结构型智能的转变(修改提纲)
人工智能·算法·机器学习
**梯度已爆炸**1 小时前
NLP文本预处理
人工智能·深度学习·nlp
uncle_ll1 小时前
李宏毅NLP-8-语音模型
人工智能·自然语言处理·语音识别·语音模型·lm
Liudef061 小时前
FLUX.1-Kontext 高效训练 LoRA:释放大语言模型定制化潜能的完整指南
人工智能·语言模型·自然语言处理·ai作画·aigc
静心问道1 小时前
大型语言模型中的自动化思维链提示
人工智能·语言模型·大模型
众链网络2 小时前
你的Prompt还有很大提升
人工智能·prompt·ai写作·ai工具·ai智能体