TensorDataset 是 PyTorch 中一个非常实用的工具类,它的核心作用是将多个张量 (Tensor)打包成一个数据集。
简单来说,它就像是一个"数据拉链",把特征 X 和标签 y 一一对应地组合在一起。
🎯 核心作用:打包与索引
当你有一堆特征数据 X 和对应的标签数据 y 时,TensorDataset 会按照以下方式工作:
- 打包 :将
X和y中相同位置的数据项组合成一个新的数据单元。 - 索引 :当你对这个数据集使用索引
dataset[i]时,它会返回一个元组(X[i], y[i])。
这种设计完美契合了 PyTorch 的 DataLoader,使得批量迭代数据变得非常简单。
💻 基本用法
下面的代码清晰地展示了 TensorDataset 的基本用法:
python
import torch
from torch.utils.data import TensorDataset
# 1. 准备数据 (特征和标签)
# 假设有4个样本,每个样本有3个特征
X = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float32)
# 假设有4个样本,每个样本对应一个标签
y = torch.tensor([0, 1, 0, 1], dtype=torch.float32)
# 2. 打包成 TensorDataset
dataset = TensorDataset(X, y)
# 3. 查看打包结果
print(f"数据集长度: {len(dataset)}") # 输出: 数据集长度: 4
print(f"索引0的数据: {dataset[0]}") # 输出: (tensor([1., 2., 3.]), tensor(0.))
print(f"索引1的数据: {dataset[1]}") # 输出: (tensor([4., 5., 6.]), tensor(1.))
# 可以看到,dataset[0] 返回了 (X[0], y[0])
🚀 与 DataLoader 结合使用(标准流程)
TensorDataset 最常见的搭档就是 DataLoader。DataLoader 负责批量加载和打乱数据,而 TensorDataset 负责管理数据项之间的映射关系。
python
from torch.utils.data import DataLoader
# 1. 使用上面的 dataset
dataset = TensorDataset(X, y)
# 2. 创建 DataLoader,设置批量大小为2,并打乱数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 3. 迭代使用
print("--- 使用 DataLoader 迭代 ---")
for batch_X, batch_y in dataloader:
print(f"批量特征 X:\n{batch_X}")
print(f"批量标签 y: {batch_y}")
print("-" * 20)
# 可能的输出 (因为 shuffle=True,顺序可能不同):
# --- 使用 DataLoader 迭代 ---
# 批量特征 X:
# tensor([[10., 11., 12.],
# [ 4., 5., 6.]])
# 批量标签 y: tensor([1., 1.])
# --------------------
# 批量特征 X:
# tensor([[7., 8., 9.],
# [1., 2., 3.]])
# 批量标签 y: tensor([0., 0.])
# --------------------
📝 关键注意事项
- 数据类型与形状 :所有传入
TensorDataset的张量,在第一维(维度0)的大小必须相同。这个大小就是数据集的总样本数。上例中所有张量的第一维大小都是4。 - 支持多参数 :你不止可以传入
(X, y),可以传入任意数量的张量,比如TensorDataset(X, y, w, z)。索引时也会返回对应数量的元素元组。 - 主要用途 :它是一个便捷的封装器,主要用于简化从张量到
DataLoader的过程。对于更复杂的场景(如从文件读取的图片、文本数据),通常需要自定义Dataset类。
💎 总结
- 是什么:一个将多个张量"拉链式"打包的工具。
- 解决什么问题 :让你无需手动编写索引逻辑,就能方便地获取
(X[i], y[i])这样的配对数据。 - 核心用法 :
TensorDataset(X, y)->DataLoader(dataset, batch_size=...)。