TensorDataset简介

TensorDataset 是 PyTorch 中一个非常实用的工具类,它的核心作用是将多个张量 (Tensor)打包成一个数据集。

简单来说,它就像是一个"数据拉链",把特征 X 和标签 y 一一对应地组合在一起。

🎯 核心作用:打包与索引

当你有一堆特征数据 X 和对应的标签数据 y 时,TensorDataset 会按照以下方式工作:

  1. 打包 :将 Xy 中相同位置的数据项组合成一个新的数据单元。
  2. 索引 :当你对这个数据集使用索引 dataset[i] 时,它会返回一个元组 (X[i], y[i])

这种设计完美契合了 PyTorch 的 DataLoader,使得批量迭代数据变得非常简单。

💻 基本用法

下面的代码清晰地展示了 TensorDataset 的基本用法:

python 复制代码
import torch
from torch.utils.data import TensorDataset

# 1. 准备数据 (特征和标签)
# 假设有4个样本,每个样本有3个特征
X = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float32)
# 假设有4个样本,每个样本对应一个标签
y = torch.tensor([0, 1, 0, 1], dtype=torch.float32)

# 2. 打包成 TensorDataset
dataset = TensorDataset(X, y)

# 3. 查看打包结果
print(f"数据集长度: {len(dataset)}")  # 输出: 数据集长度: 4
print(f"索引0的数据: {dataset[0]}")   # 输出: (tensor([1., 2., 3.]), tensor(0.))
print(f"索引1的数据: {dataset[1]}")   # 输出: (tensor([4., 5., 6.]), tensor(1.))

# 可以看到,dataset[0] 返回了 (X[0], y[0])

🚀 与 DataLoader 结合使用(标准流程)

TensorDataset 最常见的搭档就是 DataLoaderDataLoader 负责批量加载和打乱数据,而 TensorDataset 负责管理数据项之间的映射关系。

python 复制代码
from torch.utils.data import DataLoader

# 1. 使用上面的 dataset
dataset = TensorDataset(X, y)

# 2. 创建 DataLoader,设置批量大小为2,并打乱数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 3. 迭代使用
print("--- 使用 DataLoader 迭代 ---")
for batch_X, batch_y in dataloader:
    print(f"批量特征 X:\n{batch_X}")
    print(f"批量标签 y: {batch_y}")
    print("-" * 20)

# 可能的输出 (因为 shuffle=True,顺序可能不同):
# --- 使用 DataLoader 迭代 ---
# 批量特征 X:
# tensor([[10., 11., 12.],
#         [ 4.,  5.,  6.]])
# 批量标签 y: tensor([1., 1.])
# --------------------
# 批量特征 X:
# tensor([[7., 8., 9.],
#         [1., 2., 3.]])
# 批量标签 y: tensor([0., 0.])
# --------------------

📝 关键注意事项

  1. 数据类型与形状 :所有传入 TensorDataset 的张量,在第一维(维度0)的大小必须相同。这个大小就是数据集的总样本数。上例中所有张量的第一维大小都是4。
  2. 支持多参数 :你不止可以传入 (X, y),可以传入任意数量的张量,比如 TensorDataset(X, y, w, z)。索引时也会返回对应数量的元素元组。
  3. 主要用途 :它是一个便捷的封装器,主要用于简化从张量到 DataLoader 的过程。对于更复杂的场景(如从文件读取的图片、文本数据),通常需要自定义 Dataset 类。

💎 总结

  • 是什么:一个将多个张量"拉链式"打包的工具。
  • 解决什么问题 :让你无需手动编写索引逻辑,就能方便地获取 (X[i], y[i]) 这样的配对数据。
  • 核心用法TensorDataset(X, y) -> DataLoader(dataset, batch_size=...)
相关推荐
毕胜客源码3 小时前
卷积神经网络的农作物识别系统(有技术文档)深度学习 图像识别 卷积神经网络 Django python 人工智能
人工智能·python·深度学习·cnn·django
小鱼~~4 小时前
GRU模型简介
人工智能·深度学习
小鱼~~4 小时前
DataLoader简介
人工智能·深度学习
多年小白4 小时前
谷歌第八代 TPU 来了:性能提升 124%
网络·人工智能·科技·深度学习·ai
AI木马人5 小时前
1.【AI系统架构设计】如何设计一个高效、安全的人性化AI工具系统?(从0到1完整方案)
人工智能·深度学习·神经网络·计算机视觉·自然语言处理
多年小白7 小时前
AI 日报 - 2026年4月25日(周六)
网络·人工智能·科技·深度学习·ai
深度学习lover7 小时前
<数据集>yolo 垃圾识别<目标检测>
人工智能·深度学习·yolo·目标检测·计算机视觉·垃圾识别
LaughingZhu7 小时前
Product Hunt 每日热榜 | 2026-04-25
人工智能·经验分享·深度学习·神经网络·产品运营
AI木马人8 小时前
2.【多模型接入架构】如何同时接入GPT、Gemini、Claude并统一管理?(完整实现方案)
人工智能·gpt·深度学习·神经网络·自然语言处理