TensorDataset简介

TensorDataset 是 PyTorch 中一个非常实用的工具类,它的核心作用是将多个张量 (Tensor)打包成一个数据集。

简单来说,它就像是一个"数据拉链",把特征 X 和标签 y 一一对应地组合在一起。

🎯 核心作用:打包与索引

当你有一堆特征数据 X 和对应的标签数据 y 时,TensorDataset 会按照以下方式工作:

  1. 打包 :将 Xy 中相同位置的数据项组合成一个新的数据单元。
  2. 索引 :当你对这个数据集使用索引 dataset[i] 时,它会返回一个元组 (X[i], y[i])

这种设计完美契合了 PyTorch 的 DataLoader,使得批量迭代数据变得非常简单。

💻 基本用法

下面的代码清晰地展示了 TensorDataset 的基本用法:

python 复制代码
import torch
from torch.utils.data import TensorDataset

# 1. 准备数据 (特征和标签)
# 假设有4个样本,每个样本有3个特征
X = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float32)
# 假设有4个样本,每个样本对应一个标签
y = torch.tensor([0, 1, 0, 1], dtype=torch.float32)

# 2. 打包成 TensorDataset
dataset = TensorDataset(X, y)

# 3. 查看打包结果
print(f"数据集长度: {len(dataset)}")  # 输出: 数据集长度: 4
print(f"索引0的数据: {dataset[0]}")   # 输出: (tensor([1., 2., 3.]), tensor(0.))
print(f"索引1的数据: {dataset[1]}")   # 输出: (tensor([4., 5., 6.]), tensor(1.))

# 可以看到,dataset[0] 返回了 (X[0], y[0])

🚀 与 DataLoader 结合使用(标准流程)

TensorDataset 最常见的搭档就是 DataLoaderDataLoader 负责批量加载和打乱数据,而 TensorDataset 负责管理数据项之间的映射关系。

python 复制代码
from torch.utils.data import DataLoader

# 1. 使用上面的 dataset
dataset = TensorDataset(X, y)

# 2. 创建 DataLoader,设置批量大小为2,并打乱数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 3. 迭代使用
print("--- 使用 DataLoader 迭代 ---")
for batch_X, batch_y in dataloader:
    print(f"批量特征 X:\n{batch_X}")
    print(f"批量标签 y: {batch_y}")
    print("-" * 20)

# 可能的输出 (因为 shuffle=True,顺序可能不同):
# --- 使用 DataLoader 迭代 ---
# 批量特征 X:
# tensor([[10., 11., 12.],
#         [ 4.,  5.,  6.]])
# 批量标签 y: tensor([1., 1.])
# --------------------
# 批量特征 X:
# tensor([[7., 8., 9.],
#         [1., 2., 3.]])
# 批量标签 y: tensor([0., 0.])
# --------------------

📝 关键注意事项

  1. 数据类型与形状 :所有传入 TensorDataset 的张量,在第一维(维度0)的大小必须相同。这个大小就是数据集的总样本数。上例中所有张量的第一维大小都是4。
  2. 支持多参数 :你不止可以传入 (X, y),可以传入任意数量的张量,比如 TensorDataset(X, y, w, z)。索引时也会返回对应数量的元素元组。
  3. 主要用途 :它是一个便捷的封装器,主要用于简化从张量到 DataLoader 的过程。对于更复杂的场景(如从文件读取的图片、文本数据),通常需要自定义 Dataset 类。

💎 总结

  • 是什么:一个将多个张量"拉链式"打包的工具。
  • 解决什么问题 :让你无需手动编写索引逻辑,就能方便地获取 (X[i], y[i]) 这样的配对数据。
  • 核心用法TensorDataset(X, y) -> DataLoader(dataset, batch_size=...)
相关推荐
青春不败 177-3266-052015 分钟前
MATLAB 2024b深度学习新特性全面解析与DeepSeek大模型集成开发
人工智能·深度学习·机器学习·matlab·卷积神经网络·自编码器·deepseek
初心未改HD1 小时前
深度学习之激活函数详解
人工智能·深度学习
爱写代码的小朋友1 小时前
人工智能背景下深度学习在高中信息技术教育中的应用研究
人工智能·深度学习
knight_9___1 小时前
大模型project面试5
人工智能·python·深度学习·面试·agent·rag·mcp
东方佑1 小时前
OpenASH 85M 参数,不用 Softmax,也能通过中文考试
人工智能·深度学习
大江东去浪淘尽千古风流人物2 小时前
【SANA-WM】分钟级世界模型:混合线性扩散Transformer与双分支相机控制深度解析
人工智能·深度学习·架构·spark·机器人·transformer·wm
cskywit4 小时前
【BIBM2025】 MedMamba-YOLO:医疗目标检测,当 YOLO 遇见轻量级 Mamba
深度学习·yolo·目标检测
AI技术控4 小时前
Prompt Engineering 在企业大模型应用中的实践:从提示词模板到可控输出
人工智能·python·深度学习·语言模型·自然语言处理·prompt
手写码匠4 小时前
手写 AI Prompt Injection 防护系统:从零实现 LLM 安全边界
人工智能·深度学习·算法·aigc
小何code5 小时前
人工智能【第31篇】生成对抗网络GAN入门:AI的创造力之源
深度学习·生成对抗网络·gan·图像生成