Pytorch自定义数据集的用法

答案来自 千问

在 PyTorch 中自定义数据集,核心是继承 torch.utils.data.Dataset 类,并实现必要的方法。这能让你灵活地加载任何格式的自有数据。

📦 核心组件

一个标准的自定义数据集类包含以下三个核心部分:

  1. __init__: 初始化方法。在这里加载数据的路径、标签,或读取包含这些信息的文件(如 txt 或 csv)。
  2. __len__: 返回数据集的总样本数。
  3. __getitem__ : 核心方法。根据索引 index 读取并返回一个样本(数据和标签)。

🚀 快速入门示例

这是一个最简单的例子,帮助你理解基本结构。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader

# 1. 定义自己的数据集类,继承 Dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # 初始化:接收数据和标签
        self.data = data
        self.labels = labels

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

    def __getitem__(self, index):
        # 根据索引获取单个样本
        sample = self.data[index]
        label = self.labels[index]
        return sample, label

# --- 使用示例 ---
# 模拟一些数据
data = torch.randn(100, 3)  # 100个样本, 每个3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签

# 实例化数据集
dataset = MyDataset(data, labels)

# 使用 DataLoader 包装,以便批量读取
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# 遍历数据
for batch_data, batch_labels in dataloader:
    print(f"Batch shape: {batch_data.shape}, Labels: {batch_labels}")
    break

🖼️ 图像数据集实战

处理图像时,通常需要读取文件路径。这里展示一个读取文件夹中图片的常用模式。

python 复制代码
import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        """
        参数:
            annotations_file (string): CSV文件路径,包含图片文件名和标签。
            img_dir (string): 图片存储的根目录。
            transform (callable, optional): 可选的预处理操作。
        """
        # 读取CSV文件,假设第一列是文件名,第二列是标签
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, index):
        # 1. 构建图片路径并读取
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[index, 0])
        image = Image.open(img_path).convert("RGB")  # 确保是三通道
        
        # 2. 获取标签
        label = self.img_labels.iloc[index, 1]

        # 3. 应用预处理(如调整大小、转为Tensor)
        if self.transform:
            image = self.transform(image)

        return image, label

# --- 使用示例 ---

# 定义图片预处理流程
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 假设你有一个 'data/' 文件夹和 'labels.csv' 文件
dataset = CustomImageDataset(
    annotations_file='labels.csv', 
    img_dir='data/', 
    transform=transform
)

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

📁 简易文件夹结构示例

如果你的数据按类别分在不同文件夹里(例如 train/cat/001.jpg, train/dog/002.jpg),也可以这样写:

python 复制代码
import os
from torch.utils.data import Dataset
from PIL import Image

class SimpleFolderDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_paths = []
        self.labels = []
        
        # 遍历文件夹
        for label, class_name in enumerate(sorted(os.listdir(root_dir))):
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                for img_name in os.listdir(class_dir):
                    self.img_paths.append(os.path.join(class_dir, img_name))
                    self.labels.append(label) # 文件夹名作为标签

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        image = Image.open(self.img_paths[index]).convert("RGB")
        label = self.labels[index]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

📌 总结

  • 继承 Dataset: 这是所有数据集的基类。
  • 实现 __getitem__: 这是最关键的一步,定义了如何读取单个样本。
  • 配合 DataLoader : 自定义好 Dataset 后,用 DataLoader 来实现批量加载、打乱顺序和多进程读取。
相关推荐
疯狂成瘾者10 分钟前
语义分块提升RAG检索精度
python
551只玄猫12 分钟前
【数学建模 matlab 实验报告12】聚类分析和判别分析
开发语言·数学建模·matlab·课程设计·聚类·实验报告
小陈工2 小时前
Python Web开发入门(十七):Vue.js与Python后端集成——让前后端真正“握手言和“
开发语言·前端·javascript·数据库·vue.js·人工智能·python
H Journey2 小时前
C++之 CMake、CMakeLists.txt、Makefile
开发语言·c++·makefile·cmake
A__tao6 小时前
Elasticsearch Mapping 一键生成 Java 实体类(支持嵌套 + 自动过滤注释)
java·python·elasticsearch
研究点啥好呢6 小时前
Github热门项目推荐 | 创建你的像素风格!
c++·python·node.js·github·开源软件
lly2024067 小时前
C 标准库 - `<stdio.h>`
开发语言
沫璃染墨7 小时前
C++ string 从入门到精通:构造、迭代器、容量接口全解析
c语言·开发语言·c++
jwn9997 小时前
Laravel6.x核心特性全解析
开发语言·php·laravel
迷藏4947 小时前
**发散创新:基于Rust实现的开源合规权限管理框架设计与实践**在现代软件架构中,**权限控制(RBAC)** 已成为保障
java·开发语言·python·rust·开源