任务2.1.1 深度学习实战 - 数据加载

文章目录

  • [1. 实战概述](#1. 实战概述)
  • [2. 实战步骤](#2. 实战步骤)
    • [2.1 加载自带数据](#2.1 加载自带数据)
      • [2.1.1 查看MNIST数据集](#2.1.1 查看MNIST数据集)
      • [2.1.2 加载MNIST数据集](#2.1.2 加载MNIST数据集)
      • [2.1.3 转换成Dataset对象](#2.1.3 转换成Dataset对象)
    • [2.2 加载外部文件](#2.2 加载外部文件)
      • [2.2.1 加载CSV文件](#2.2.1 加载CSV文件)
      • [2.2.2 加载TFRecord文件](#2.2.2 加载TFRecord文件)
      • [2.2.3 加载文本文件](#2.2.3 加载文本文件)
      • [2.2.4 加载文件集](#2.2.4 加载文件集)
  • [3. 实战总结](#3. 实战总结)

1. 实战概述

  • 本次实战系统演示了 TensorFlow 中多种数据集的加载与处理方法。内容涵盖从 Keras 内置的 MNIST 数据集,到本地的 CSV、TFRecord、文本文件,以及复杂的图片文件夹结构。重点讲解了如何利用 tf.data.Dataset API 将原始数据(如路径、标签)转化为高效的流式输入管道,为模型训练奠定了数据基础。

2. 实战步骤

2.1 加载自带数据

  • MNIST数据集:MNIST(改进版美国国家标准与技术研究院数据库)是机器学习领域的"Hello World"级基准数据集,包含70,000张28×28像素的手写数字灰度图像(0-9)。其中60,000张为训练集,10,000张为测试集。它广泛用于验证算法性能及深度学习入门教学。

2.1.1 查看MNIST数据集

  • C:\Users\Administrator\.keras\datasets\mnist.npz
  • 展示了Keras框架的本地数据集存储目录,路径位于 C:\Users\Administrator\.keras\datasets。文件夹中包含一个名为 mnist.npz 的文件(约11MB),这是著名的MNIST手写数字数据集的压缩格式。当代码调用 keras.datasets.mnist.load_data() 时,系统会优先在此处查找,若文件存在则直接加载,无需重新从网络下载,从而提升效率。

2.1.2 加载MNIST数据集

  • 执行代码

    python 复制代码
    import tensorflow as tf
    import pandas as pd
    from tensorflow.keras import datasets # 导入经典数据集
    		
    # 加载MNIST数据集
    (x, y), (x_test, y_test) = datasets.mnist.load_data()
    print(f'x: {x.shape}, y: {y.shape}, x_test: {x_test.shape}, y_test: {y_test.shape}')
  • 代码说明 :代码首先导入 TensorFlow、Pandas 及数据集模块,随即调用 datasets.mnist.load_data() 加载 MNIST 数据集并自动划分为训练集与测试集。打印结果显示,成功获取了 60,000 张训练图像与 10,000 张测试图像(尺寸均为 28x28),完成了数据准备。

2.1.3 转换成Dataset对象

  • 执行代码

    python 复制代码
    train = tf.data.Dataset.from_tensor_slices((x, y))
    test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  • 代码说明 :这两行代码利用 tf.data.Dataset.from_tensor_slices 将 NumPy 数组形式的训练数据 (x, y) 和测试数据 (x_test, y_test) 封装成了 TensorFlow 的 Dataset 对象。这一步构建了高效的输入管道,将数据切片并转化为模型可直接迭代读取的格式,为后续的批量训练与评估做好了准备,而下方的日志仅为 CPU 指令集优化的提示信息。

2.2 加载外部文件

2.2.1 加载CSV文件

  • 泰坦尼克号乘客数据集[titanic_file.csv]:泰坦尼克号乘客数据集记录了1912年沉船事故中乘客的生存情况及相关信息,包含乘客ID、船舱等级、姓名、性别、年龄、同行家属数、船票价格、登船港口等特征。该数据集常用于机器学习分类任务,预测乘客是否生还,是数据科学入门的经典案例。
  1. 使用Pandas库中的read_csv函数加载数据

    • 执行代码

      python 复制代码
      # 1. 使用 pandas 读取 CSV 文件
      # 注意:请确保文件路径 'e:/csvfiles/titanic_file.csv' 在你的电脑上真实存在
      titanic_file = pd.read_csv('e:/csvfiles/titanic_file.csv')
      
      # 2. 将 DataFrame 转换为 TensorFlow Dataset
      # dict(titanic_file) 将 pandas DataFrame 转换为字典形式,键为列名,值为列数据
      titanic_slices = tf.data.Dataset.from_tensor_slices(dict(titanic_file))
      
      # 3. 遍历数据集
      # take(1) 表示只取出一个批次(即第一行数据)进行演示
      for feature_batch in titanic_slices.take(1):
          # feature_batch 是一个字典,包含当前行所有特征的键值对
          for key, value in feature_batch.items():
              # {!r:20s} 表示将键转换为 repr 字符串并占用 20 个字符宽度
              # {} 表示直接打印值
              print('{!r:20s}: {}'.format(key, value))
    • 代码说明 :这段代码演示了如何用 Pandas 读取 Titanic CSV 数据,并通过 tf.data.Dataset.from_tensor_slices 转换为 TensorFlow 数据集。随后,利用 take(1) 取出首个样本,遍历并打印其所有特征(如性别、年龄、票价等)及生存标签,验证了数据从 Pandas DataFrame 到 TF Dataset 的转换与单样本访问流程。

  2. 使用make_csv_dataset函数加载数据

    • 执行代码

      python 复制代码
      # 1. 构建 CSV 数据集
      # file_pattern: CSV 文件路径
      # batch_size: 每个批次包含的样本数量 (这里设为 4)
      # label_name: 指定哪一列作为标签 (这里指定 'survived' 列)
      # 该函数会自动推断列的数据类型,并返回一个包含 (特征, 标签) 的数据集
      titanic_batches = tf.data.experimental.make_csv_dataset(
          'e:/csvfiles/titanic_file.csv', 
          batch_size=4, 
          label_name='survived'
      )
      
      # 2. 遍历数据集
      # take(1) 表示只从数据集中取出一个批次(batch)进行演示
      for feature_batch, label_batch in titanic_batches.take(1):    
          # 打印标签数据 (即 survived 列的值,形状为 [4])
          print('survived: {}'.format(label_batch))    
          print('features:')    
          # 3. 遍历特征字典
          # feature_batch 是一个 OrderedDict,键是列名,值是该列的数据张量
          for key, value in feature_batch.items():  # 修正了原代码中的拼写错误 feature_batches
              # 打印列名(键)和对应的数据(值)
              # {!r:20s} 格式化输出:使用 repr 形式,占位 20 字符宽度
              print('{!r:20s}: {}'.format(key, value))
    • 代码说明 :这段代码演示了如何使用 TensorFlow 的 make_csv_dataset 从 CSV 文件构建数据集。它成功读取了一个批次(4 个样本)的泰坦尼克号数据,将 survived 列作为标签,其余列为特征。输出显示了每个特征列的名称及其对应的张量值,包括字符串和数值类型,验证了数据加载和预处理流程的正确性。

2.2.2 加载TFRecord文件

  • FSNS概述:FSNS(Google Street View 法语签名数据集)是用于文本识别的公开数据集,包含 100 万张法国街道名称图像。每张图像通常包含四个视角,标注了对应的路名文本(如"Rue Perreyon")。该数据集广泛用于训练 OCR 模型,解决街景图片中的序列识别问题。
  1. 加载数据

    • 执行代码

      python 复制代码
      dataset = tf.data.TFRecordDataset(filenames =['d:/data/fsns.tfrec'])
      print(dataset)
    • 代码说明 :这段代码使用 tf.data.TFRecordDataset 从指定路径(d:/data/fsns.tfrec)加载 TFRecord 文件,创建了一个 TensorFlow 数据集对象。打印结果显示该数据集的元素类型为标量字符串(TensorSpec(shape=(), dtype=tf.string)),表明数据以序列化的字符串形式存储,后续需要解析才能获取具体特征。

  2. 解码数据

    • 执行代码

      python 复制代码
      raw_example = next(iter(dataset))
      parsed = tf.train.Example.FromString(raw_example.numpy())
      print(parsed.features.feature['image/text'])
    • 代码说明 :这段代码从 dataset 中提取了一条原始 TFRecord 样本,并将其从序列化字符串反序列化为 Example 对象。随后,它访问并打印了其中名为 'image/text' 的特征。输出结果显示该样本的文本标签为 "Rue Perreyon",这通常对应于 FSNS 数据集中街景图像上的路名文本。

2.2.3 加载文本文件

  1. 查看cowper.txt文件

  2. 加载cowper.txt文件

    • 执行代码

      python 复制代码
      cowper = tf.data.TextLineDataset('d:/data/cowper.txt')
      for line in cowper.take(5):
          print(line.numpy())
    • 代码说明 :这段代码利用 TensorFlow 的 TextLineDataset 接口读取指定路径下的文本文件(cowper.txt),构建了一个按行存储的数据集。随后,代码调用 take(5) 方法从中截取前 5 行数据,并通过循环将字节流解码打印。输出结果展示了《伊利亚特》译本的前五行诗句,最后的日志提示数据序列已读取完毕。

2.2.4 加载文件集

  1. 花卉数据集

    • flower_photos 是由 Google 发布的经典花卉图像数据集,包含约 3,670 张高分辨率图片,分为雏菊、蒲公英、玫瑰、向日葵和郁金香 5 个类别。该数据集结构规范,广泛用于图像分类、迁移学习及深度学习入门教学。
    • 641张玫瑰图像
    • 699张向日葵图像
  2. 获取图片路径并查看

    • 执行代码

      python 复制代码
      import random
      import pathlib
      
      data_path = pathlib.Path('d:/data/flower_photos')
      all_image_paths = list(data_path.glob('*/*'))
      all_image_paths = [str(path) for path in all_image_paths]
      random.shuffle(all_image_paths) # 打散数据
      
      image_count = len(all_image_paths)
      print(f'图片总数:{image_count}')
      print('10张图片路径:')
      for path in all_image_paths[:10]:
          print(path)
    • 代码说明 :这段代码利用 pathlib 库读取 d:/data/flower_photos 目录下的所有图像文件路径,并将其转换为字符串列表。随后,代码使用 random.shuffle 对数据进行随机打乱,以消除原始数据的顺序偏差。最后,它统计并打印了数据集的图片总数(3670张),并展示了打乱后的前10张图片的具体路径,验证了数据加载与洗牌的效果。

  3. 图片与标签对应

    • 执行代码

      python 复制代码
      # 提取分类名
      label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
      print(f'分类名:{label_names}')
      
      # 创建标签
      label_to_index = dict((name, index) for index, name in enumerate(label_names))
      print(f'标签:{label_to_index}')
      
      # 将图片与标签对应
      all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
      for image, label in zip(all_image_paths[:10], all_image_labels[:10]):
          label_name = label_names[label]
          print(f'{image} ---> {label} : {label_name}')
    • 代码说明 :这段代码主要完成了图像数据的标签构建与验证。首先从数据目录中提取出5个花卉分类名称(如 daisy、roses 等),并按字母顺序排序后建立名称到数字索引的映射字典。接着遍历前10张图片路径,通过读取图片所在的父文件夹名称,从字典中查找对应的数字标签。最终打印图片路径、数字ID及分类名称,验证了数据与标签的准确对应关系。

  4. 将加载后的图片转换为Dataset对象

    • 执行代码:ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
    • 代码说明 :这行代码使用 tf.data.Dataset.from_tensor_slices 创建了一个 TensorFlow 数据集对象 ds。它将图片路径列表 all_image_paths 和对应的标签列表 all_image_labels 打包在一起,构建出"(图片路径, 标签)"的数据对。这一步将原始数据转换为 TensorFlow 可高效处理的流式数据集格式,为后续的批量读取、打乱顺序及模型训练做好了准备。

3. 实战总结

  • 本次实战深入剖析了 TensorFlow 的数据处理全流程。首先,我们掌握了加载内置数据(MNIST)的便捷方式;其次,针对结构化数据(CSV),学会了使用 Pandas 预处理或直接构建批次;对于非结构化数据,我们实践了读取文本行(TextLineDataset)和解析二进制记录(TFRecord)的技术。最后,实战的核心在于处理图像分类任务。我们通过 pathlib 遍历文件夹获取路径,构建了路径与标签(数字索引)的映射关系,并最终利用 tf.data.Dataset.from_tensor_slices 将离散的路径和标签整合为统一的 Dataset 对象。这一过程不仅实现了数据的规范化,更构建了高效、可迭代的输入流水线,完美适配了后续的深度学习模型训练需求。
相关推荐
Narrastory2 小时前
Note:强化学习(五)
人工智能·深度学习·强化学习
测绘第一深情2 小时前
租用GPU云服务器进行深度学习(AutoDL,超保姆级,适用新手)
数据结构·人工智能·经验分享·python·深度学习·算法·计算机视觉
蔡俊锋2 小时前
AI代理落地指南:从Demo到生产级的实战攻略
人工智能·深度学习·hermes·ai团队知识沉淀
MF_AI2 小时前
YOLO无人机视角检测识别数据集汇总
人工智能·深度学习·yolo·目标检测·视觉检测·无人机
白云千载尽3 小时前
神经网络的闭环运行链条
人工智能·深度学习·神经网络
逻辑驱动的ken3 小时前
Java高频面试考点场景题15
java·开发语言·深度学习·面试·职场和发展·高效学习
学弟3 小时前
【内涵】transformer之位置编码
人工智能·深度学习·transformer
数智工坊3 小时前
【DAB-DETR论文阅读】:动态锚框作为更优查询,彻底解决DETR训练收敛慢难题
网络·论文阅读·人工智能·深度学习·cnn