tensorflow 图像分类 之二

tensorflow 图像分类 之二

图像分类代码

API 说明

tf.keras.utils.get_file

tf.keras.utils.image_dataset_from_directory

上接 《tensorflow 图像分类 之一》

tf.data.Dataset

tf.data.Dataset API 支持编写描述性强,且高效的输入管道。Dataset 的使用遵循以下常见模式:

  • 从输入数据,创建源数据集。
  • 使用数据集变换预处理数据。
  • 遍历数据集,并处理其中的元素。

遍历以流式方式进行,因此无需将整个数据集加载到内存中。

产生数据集

tf.data.Dataset.from_tensor_slices

使用Python list创建数据集

复制代码
dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])

产生,显示数据集的python代码

复制代码
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])
for element in dataset:
    print(element)

屏幕显示

复制代码
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(9, shape=(), dtype=int32)
tf.Tensor(10, shape=(), dtype=int32)
tf.data.TextLineDataset

使用 tf.data.TextLineDataset处理文本文件中的行

file1.csv文件内容

复制代码
stock,open price,close price,highest price,lowest price
A,100.11,101.25,110.21,99.98
B,20.1,19.09,25.91,18.01

运行下列python代码

复制代码
ds = tf.data.TextLineDataset("file1.csv")
for elem in ds:
	print(elem)

屏幕显示

复制代码
tf.Tensor(b'stock,open price,close price,highest price,lowest price', shape=(), dtype=string)
tf.Tensor(b'A,100.11,101.25,110.21,99.98', shape=(), dtype=string)
tf.Tensor(b'B,20.1,19.09,25.91,18.01', shape=(), dtype=string)
tf.data.TFRecordDataset

使用 TFRecordDataset处理TFRecord格式的记录

复制代码
dataset = tf.data.TFRecordDataset("file1.tfrecords")

下列python代码产生一个TFRecord格式文件

复制代码
import tensorflow as tf
import os
import numpy as np

cur_dir = os.getcwd()
example_path = os.path.join(cur_dir, "file1.tfrecords")
np.random.seed(0)

with tf.io.TFRecordWriter(example_path) as file_writer:
  for _ in range(4):
    x, y = np.random.random(), np.random.random()

    record_bytes = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
    })).SerializeToString()
    file_writer.write(record_bytes)
tf.data.Dataset.list_files

使用 tf.data.Dataset.list_files创建包含所有与某个模式匹配的文件的数据集

复制代码
dataset = tf.data.Dataset.list_files("/path/*.txt")

变换

apply
复制代码
apply(
    transformation_func
) -> 'DatasetV2'

对这个数据集中的每个元素,应用转换函数。

apply 功能支持自定义数据集转换的链式调用,这些转换表示为接受一个数据集参数,并返回转换后的数据集的函数。

复制代码
import tensorflow as tf

def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)

ds = tf.data.Dataset.range(10)
ds = ds.apply(dataset_fn)

l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=",")

屏幕显示

复制代码
0,1,2,3,4,
cache
复制代码
cache(
    filename='', name=None
) -> 'DatasetV2'

缓存此数据集中的元素。

首次遍历数据集时,其元素将被缓存到指定文件或内存中。后续遍历将使用缓存的数据。

复制代码
import tensorflow as tf

ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: x**2)
ds = ds.cache()
l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=" ")

屏幕·显示

复制代码
0 1 4 9 16
shuffle

TensorFlow中的 tf.data.Dataset.cache().shuffle() 模式是优化机器学习管道中数据加载和处理的常用有效方法。

shuffle(buffer_size)

  • 此操作会对数据集中的元素进行随机排序。
  • buffer_size:此参数定义了用于随机抽取元素的缓冲区的大小。buffer_size越大,排序越彻底,但需要更多内存。为了确保整个数据集的真正随机性,理想情况下,buffer_size 应等于数据集中的元素总数。

使用shuffle对·数据集中的元素进行随机排序,对训练鲁棒模型至关重要,因为它能防止模型学习与数据顺序相关的虚假模式,并确保每个批次都能提供数据集的多样化表示。

将 shuffle() 放在 cache() 之后,意味着排序操作发生在缓存的数据上。这一点很重要,因为如果将 shuffle() 放在 cache()之前,数据集只会排序一次,然后缓存该单一的排序结果,导致每个 epoch 的排序结果都相同。通过将其放在cache()之后,每个epoch都可以获得缓存数据的全新,且不同的排序结果。

prefetch

数据集预取是一种数据管道中使用的优化技术,尤其是在 TensorFlow 和 PyTorch 等机器学习框架中,它通过将数据处理与模型计算重叠来提高性能。

  • 预取通常涉及一个后台线程,该线程异步地将输入数据集中的数据元素加载到内部缓冲区中。
  • 当模型处理当前批次的数据时,预取线程已经在准备下一批数据,从而最大限度地减少处理单元(例如 CPU 或 GPU)的空闲时间。
  • 通过在需要时,提供下一批数据,预取降低了数据访问的延迟,并提高了数据管道的整体吞吐量。
  • 预取可以与缓存(例如,使用 cache() 函数)结合使用,将处理后的数据,存储在内存中,避免跨多个 epoch 的冗余操作,从而进一步提升性能。缓存通常应在预取之前执行。

TensorFlow 中的 tf.data.Dataset.prefetch(buffer_size=AUTOTUNE) 方法是输入管道的关键优化,旨在将数据预处理与模型执行重叠。

当使用 tf.data.AUTOTUNE作为 buffer_size时,TensorFlow 会在运行时,动态调整缓冲区大小。这使得框架能够根据您的计算环境和工作负载,自动确定要预取的最佳元素数量,从而最大限度地提高性能,而无需手动调整。

相关推荐
小程故事多_8018 小时前
基于LangGraph与Neo4j构建智能体级GraphRAG:打造下一代膳食规划助手
人工智能·aigc·neo4j
serve the people19 小时前
TensorFlow 中 “延迟变量创建(Deferred Variable Creation)” 机制
人工智能·python·tensorflow
serve the people20 小时前
TensorFlow 中定义模型和层
人工智能·tensorflow·neo4j
山土成旧客20 小时前
【Python学习打卡-Day17】从二分类到多分类:ROC曲线、三大平均指标与风控利器MCC/KS
python·学习·分类
技术支持者python,php20 小时前
训练分类识别器
人工智能·分类·数据挖掘
serve the people1 天前
tensorflow tf.function 的 多态性(Polymorphism)
人工智能·python·tensorflow
serve the people2 天前
tensorflow tf.function 的两种执行模式(计算图执行 vs Eager 执行)的关键差异
人工智能·python·tensorflow
serve the people2 天前
tensorflow中的计算图是什么
人工智能·python·tensorflow
serve the people2 天前
tensorflow计算图的底层原理
人工智能·tensorflow·neo4j
serve the people2 天前
TensorFlow 图执行(tf.function)的 “非严格执行(Non-strict Execution)” 特性
人工智能·python·tensorflow