pytorch dataloader学习

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

只要是shuffle=True,每次epoch结果的顺序是不一样的,如果想每一次的结果是一样的

如果shuffle=False

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

结果如下

相关推荐
o_insist4 分钟前
LangChain1.0 实现 PDF 文档向量检索全流程
人工智能·python·langchain
victory04316 分钟前
大模型学习阶段总结和下一阶段展望
深度学习·学习·大模型
OpenMiniServer7 分钟前
AI + GitLab + VSCode:下一代开发工作流的革命性集成
人工智能·vscode·gitlab
脑洞AI食验员9 分钟前
智能体来了:用异常与文件处理守住代码底线
人工智能·python
程序猿零零漆10 分钟前
Spring之旅 - 记录学习 Spring 框架的过程和经验(十三)SpringMVC快速入门、请求处理
java·学习·spring
摘星观月13 分钟前
【三维重建2】TCPFormer以及NeRF相关SOTA方法
人工智能·深度学习
shangjian00713 分钟前
AI大模型-机器学习-分类
人工智能·机器学习·分类
Tiny_React15 分钟前
使用 Claude Code Skills 模拟的视频生成流程
人工智能·音视频开发·vibecoding
人工小情绪18 分钟前
深度学习模型部署
人工智能·深度学习
曾浩轩19 分钟前
跟着江协科技学STM32之4-5OLED模块教程OLED显示原理
科技·stm32·单片机·嵌入式硬件·学习