Pytorch自定义数据集的用法

答案来自 千问

在 PyTorch 中自定义数据集,核心是继承 torch.utils.data.Dataset 类,并实现必要的方法。这能让你灵活地加载任何格式的自有数据。

📦 核心组件

一个标准的自定义数据集类包含以下三个核心部分:

  1. __init__: 初始化方法。在这里加载数据的路径、标签,或读取包含这些信息的文件(如 txt 或 csv)。
  2. __len__: 返回数据集的总样本数。
  3. __getitem__ : 核心方法。根据索引 index 读取并返回一个样本(数据和标签)。

🚀 快速入门示例

这是一个最简单的例子,帮助你理解基本结构。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader

# 1. 定义自己的数据集类,继承 Dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # 初始化:接收数据和标签
        self.data = data
        self.labels = labels

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

    def __getitem__(self, index):
        # 根据索引获取单个样本
        sample = self.data[index]
        label = self.labels[index]
        return sample, label

# --- 使用示例 ---
# 模拟一些数据
data = torch.randn(100, 3)  # 100个样本, 每个3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签

# 实例化数据集
dataset = MyDataset(data, labels)

# 使用 DataLoader 包装,以便批量读取
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# 遍历数据
for batch_data, batch_labels in dataloader:
    print(f"Batch shape: {batch_data.shape}, Labels: {batch_labels}")
    break

🖼️ 图像数据集实战

处理图像时,通常需要读取文件路径。这里展示一个读取文件夹中图片的常用模式。

python 复制代码
import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        """
        参数:
            annotations_file (string): CSV文件路径,包含图片文件名和标签。
            img_dir (string): 图片存储的根目录。
            transform (callable, optional): 可选的预处理操作。
        """
        # 读取CSV文件,假设第一列是文件名,第二列是标签
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, index):
        # 1. 构建图片路径并读取
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[index, 0])
        image = Image.open(img_path).convert("RGB")  # 确保是三通道
        
        # 2. 获取标签
        label = self.img_labels.iloc[index, 1]

        # 3. 应用预处理(如调整大小、转为Tensor)
        if self.transform:
            image = self.transform(image)

        return image, label

# --- 使用示例 ---

# 定义图片预处理流程
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 假设你有一个 'data/' 文件夹和 'labels.csv' 文件
dataset = CustomImageDataset(
    annotations_file='labels.csv', 
    img_dir='data/', 
    transform=transform
)

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

📁 简易文件夹结构示例

如果你的数据按类别分在不同文件夹里(例如 train/cat/001.jpg, train/dog/002.jpg),也可以这样写:

python 复制代码
import os
from torch.utils.data import Dataset
from PIL import Image

class SimpleFolderDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_paths = []
        self.labels = []
        
        # 遍历文件夹
        for label, class_name in enumerate(sorted(os.listdir(root_dir))):
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                for img_name in os.listdir(class_dir):
                    self.img_paths.append(os.path.join(class_dir, img_name))
                    self.labels.append(label) # 文件夹名作为标签

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        image = Image.open(self.img_paths[index]).convert("RGB")
        label = self.labels[index]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

📌 总结

  • 继承 Dataset: 这是所有数据集的基类。
  • 实现 __getitem__: 这是最关键的一步,定义了如何读取单个样本。
  • 配合 DataLoader : 自定义好 Dataset 后,用 DataLoader 来实现批量加载、打乱顺序和多进程读取。
相关推荐
OnYoung1 小时前
设计模式在C++中的实现
开发语言·c++·算法
曹牧1 小时前
Java:代理转发配置Nginx
java·开发语言·nginx
foundbug9991 小时前
利用MATLAB计算梁单元刚度矩阵并组装成总体刚度矩阵
开发语言·matlab·矩阵
Aurora@Hui1 小时前
GSAP (GreenSock Animation Platform)
人工智能·python
码农水水2 小时前
小红书Java面试被问:mTLS(双向TLS)的证书验证和握手过程
java·开发语言·数据库·redis·python·面试·开源
zmzb01032 小时前
C++课后习题训练记录Day85
开发语言·c++·算法
梵刹古音2 小时前
【C语言】 整型变量
c语言·开发语言
工程师老罗2 小时前
Python中__call__和__init__的区别
开发语言·pytorch·python
dyyx1112 小时前
Python GUI开发:Tkinter入门教程
jvm·数据库·python