答案来自 千问
在 PyTorch 中自定义数据集,核心是继承 torch.utils.data.Dataset 类,并实现必要的方法。这能让你灵活地加载任何格式的自有数据。
📦 核心组件
一个标准的自定义数据集类包含以下三个核心部分:
__init__: 初始化方法。在这里加载数据的路径、标签,或读取包含这些信息的文件(如 txt 或 csv)。__len__: 返回数据集的总样本数。__getitem__: 核心方法。根据索引index读取并返回一个样本(数据和标签)。
🚀 快速入门示例
这是一个最简单的例子,帮助你理解基本结构。
python
import torch
from torch.utils.data import Dataset, DataLoader
# 1. 定义自己的数据集类,继承 Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
# 初始化:接收数据和标签
self.data = data
self.labels = labels
def __len__(self):
# 返回数据集大小
return len(self.data)
def __getitem__(self, index):
# 根据索引获取单个样本
sample = self.data[index]
label = self.labels[index]
return sample, label
# --- 使用示例 ---
# 模拟一些数据
data = torch.randn(100, 3) # 100个样本, 每个3个特征
labels = torch.randint(0, 2, (100,)) # 100个标签
# 实例化数据集
dataset = MyDataset(data, labels)
# 使用 DataLoader 包装,以便批量读取
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# 遍历数据
for batch_data, batch_labels in dataloader:
print(f"Batch shape: {batch_data.shape}, Labels: {batch_labels}")
break
🖼️ 图像数据集实战
处理图像时,通常需要读取文件路径。这里展示一个读取文件夹中图片的常用模式。
python
import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None):
"""
参数:
annotations_file (string): CSV文件路径,包含图片文件名和标签。
img_dir (string): 图片存储的根目录。
transform (callable, optional): 可选的预处理操作。
"""
# 读取CSV文件,假设第一列是文件名,第二列是标签
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, index):
# 1. 构建图片路径并读取
img_path = os.path.join(self.img_dir, self.img_labels.iloc[index, 0])
image = Image.open(img_path).convert("RGB") # 确保是三通道
# 2. 获取标签
label = self.img_labels.iloc[index, 1]
# 3. 应用预处理(如调整大小、转为Tensor)
if self.transform:
image = self.transform(image)
return image, label
# --- 使用示例 ---
# 定义图片预处理流程
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# 假设你有一个 'data/' 文件夹和 'labels.csv' 文件
dataset = CustomImageDataset(
annotations_file='labels.csv',
img_dir='data/',
transform=transform
)
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
📁 简易文件夹结构示例
如果你的数据按类别分在不同文件夹里(例如 train/cat/001.jpg, train/dog/002.jpg),也可以这样写:
python
import os
from torch.utils.data import Dataset
from PIL import Image
class SimpleFolderDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.img_paths = []
self.labels = []
# 遍历文件夹
for label, class_name in enumerate(sorted(os.listdir(root_dir))):
class_dir = os.path.join(root_dir, class_name)
if os.path.isdir(class_dir):
for img_name in os.listdir(class_dir):
self.img_paths.append(os.path.join(class_dir, img_name))
self.labels.append(label) # 文件夹名作为标签
def __len__(self):
return len(self.img_paths)
def __getitem__(self, index):
image = Image.open(self.img_paths[index]).convert("RGB")
label = self.labels[index]
if self.transform:
image = self.transform(image)
return image, label
📌 总结
- 继承
Dataset: 这是所有数据集的基类。 - 实现
__getitem__: 这是最关键的一步,定义了如何读取单个样本。 - 配合
DataLoader: 自定义好Dataset后,用DataLoader来实现批量加载、打乱顺序和多进程读取。