pytorch dataloader学习

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

只要是shuffle=True,每次epoch结果的顺序是不一样的,如果想每一次的结果是一样的

如果shuffle=False

复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

结果如下

相关推荐
木头左7 分钟前
遗忘门参数对LSTM长期记忆保留的影响分析
人工智能·rnn·lstm
啄缘之间8 分钟前
11. UVM Test [uvm_test]
经验分享·笔记·学习·uvm·总结
serve the people12 分钟前
tensorflow 零基础吃透:RaggedTensor 的索引与切片(规则 + 示例 + 限制)
人工智能·tensorflow·neo4j
玄微云13 分钟前
选 AI 智能体开发公司?合肥玄微子科技有限公司的思路可参考
大数据·人工智能·科技·软件需求·门店管理
幂律智能14 分钟前
幂律智能CTO张惟师受邀参加山南投融汇:AI正从「工具」进化为「虚拟专家」
大数据·人工智能
javastart17 分钟前
教育行业AI落地应用:DeepSeek+智能体搭建作文批改助手
人工智能·aigc
爱笑的眼睛1118 分钟前
FastAPI 路由系统深度探索:超越基础 CRUD 的高级模式与架构实践
java·人工智能·python·ai
The Straggling Crow22 分钟前
RAGFlow 2
人工智能
RisunJan22 分钟前
【行测】类比推理-自称他称全同
学习
工藤学编程23 分钟前
零基础学AI大模型之RunnablePassthrough
人工智能