torch.utils.data
torch.utils.data是PyTorch数据处理核心模块,包含Dataset、DataLoader等关键组件。
torch.utils.data模块为PyTorch提供了完整的数据处理流水线,主要包含以下核心组件
核心组件
Dataset - 数据容器抽象类
- 必须实现 len() 返回数据集大小
- 必须实现 getitem() 按索引获取样本
- 支持自定义数据源和预处理逻辑
DataLoader - 数据加载迭代器
- batch_size: 批次大小
- shuffle: 是否随机打乱
- num_workers: 并行加载进程数
- drop_last: 是否舍弃最后不足批次的数据
Sampler类 - 数据采样策略
SequentialSampler: 顺序采样
RandomSampler: 随机采样
WeightedRandomSampler: 加权随机采样
实用工具函数
随机分割 - random_split()
clike
# 将数据集按比例分割
train_set, val_set = random_split(dataset, [800, 200])
数据拼接 - ConcatDataset()
clike
# 合并多个数据集
combined = ConcatDataset([set1, set2])
📊 实战示例
clike
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
class CustomDataset(Dataset):
def __init__(self, csv_file, transform=None):
self.df = pd.read_csv(csv_file)
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
sample = self.df.iloc[idx].SalePrice
if self.transform:
sample = self.transform(sample)
return sample
多进程加速
clike
# 启用4个worker进程并行加载
dataloader = DataLoader(dataset, batch_size=32,
shuffle=True, num_workers=4)
内存优化
- 使用惰性加载处理大数据集
- 启用 pin_memory=True 加速GPU数据传输
- 合理设置 prefetch_factor 预取数据
数据预处理

如何实现自定义数据集的预处理?
PyTorch中自定义数据集的预处理主要通过继承Dataset类和组合transforms模块来实现。
核心实现框架
自定义数据集预处理需要实现三个关键部分:Dataset类定义、transforms预处理管道、DataLoader数据加载。
基础Dataset类实现
clike
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
class CustomDataset(Dataset):
def __init__(self, csv_file, feature_cols, label_col, transform=None):
self.data = pd.read_csv(csv_file)
self.features = self.data[feature_cols]
self.labels = self.data[label_col]
self.transform = transform
self._preprocess_data()
def _preprocess_data(self):
"""数据预处理:处理缺失值和类型转换"""
# 数值型特征:填充均值
numeric_features = self.features.select_dtypes(include=[np.number])
self.features[numeric_features.columns] = numeric_features.fillna(numeric_features.mean())
# 类别型特征:转为one-hot编码
categorical_features = self.features.select_dtypes(include=['object'])
if not categorical_features.empty:
self.features = pd.get_dummies(self.features, columns=categorical_features.columns)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = {
'features': torch.tensor(self.features.iloc[idx].values, dtype=torch.float32),
'label': torch.tensor(self.labels.iloc[idx], dtype=torch.long)
}
if self.transform:
sample = self.transform(sample)
return sample
自定义预处理转换类
clike
import torchvision.transforms as transforms
class Normalization:
"""数据归一化处理"""
def __call__(self, sample):
features = sample['features']
# 最大最小归一化
features = (features - features.min()) / (features.max() - features.min())
sample['features'] = features
return sample
class FeatureScaling:
"""特征标准化"""
def __call__(self, sample):
features = sample['features']
features = (features - features.mean()) / features.std()
sample['features'] = features
return sample
class ToTensor:
"""确保数据转为Tensor格式"""
def __call__(self, sample):
if not isinstance(sample['features'], torch.Tensor):
sample['features'] = torch.tensor(sample['features'], dtype=torch.float32)
return sample
完整的预处理管道
clike
# 构建预处理管道
from torchvision.transforms import Compose
# 训练集预处理(包含数据增强)
train_transform = Compose([
Normalization(),
FeatureScaling(),
ToTensor()
])
# 测试集预处理(不包含数据增强)
test_transform = Compose([
Normalization(),
ToTensor()
])
# 创建数据集实例
train_dataset = CustomDataset(
csv_file='train.csv',
feature_cols=['feature1', 'feature2', 'feature3'],
label_col='label',
transform=train_transform
)
# 创建DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=32,
shuffle=True,
num_workers=4
)
图像数据预处理示例
对于图像数据,PyTorch提供了更丰富的预处理工具:
clike
from torchvision import transforms
# 图像预处理管道
image_transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整尺寸
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
transforms.RandomRotation(10), # 随机旋转
transforms.ToTensor(), # 转为Tensor并归一化到[0,1]
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
])
常用的数据集
常用数据集分类与介绍
-
图像分类
- CIFAR-10:6万张32×32彩色图像,涵盖飞机、汽车等10类,每类6000张。
- ImageNet:超1400万标注图像,含2.2万类别,是计算机视觉基准数据集。
- COCO:30万+图像,含200万+对象实例,支持80类目标检测任务。
-
自然语言处理
- IMDB评论:5万条电影评论,用于情感分析。
- Wikipedia:开放文本数据,支持NLP研究。
-
其他领域
- COVID-19 X-Ray Dataset:6500张胸部X光片,含COVID-19病例标注。
- MovieLens:2000万条电影评分数据,用于推荐系统。
MNIST数据集介绍
MNIST是包含6万张训练图像和1万张测试图像的手写数字数据集,每张为28×28像素的灰度图,标签为0-9的整数。数据需归一化到[0,1]范围,标签转为one-hot向量格式。作为深度学习入门基准,其简单性使其成为验证模型效果的理想选择。