tensorflow 图像分类 之二
- 图像分类代码
- [API 说明](#API 说明)
图像分类代码
API 说明
tf.keras.utils.get_file
tf.keras.utils.image_dataset_from_directory
上接 《tensorflow 图像分类 之一》
tf.data.Dataset
tf.data.Dataset API 支持编写描述性强,且高效的输入管道。Dataset 的使用遵循以下常见模式:
- 从输入数据,创建源数据集。
- 使用数据集变换预处理数据。
- 遍历数据集,并处理其中的元素。
遍历以流式方式进行,因此无需将整个数据集加载到内存中。
产生数据集
tf.data.Dataset.from_tensor_slices
使用Python list创建数据集
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 9,10])
产生,显示数据集的python代码
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 9,10])
for element in dataset:
print(element)
屏幕显示
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(9, shape=(), dtype=int32)
tf.Tensor(10, shape=(), dtype=int32)
tf.data.TextLineDataset
使用 tf.data.TextLineDataset处理文本文件中的行
file1.csv文件内容
stock,open price,close price,highest price,lowest price
A,100.11,101.25,110.21,99.98
B,20.1,19.09,25.91,18.01
运行下列python代码
ds = tf.data.TextLineDataset("file1.csv")
for elem in ds:
print(elem)
屏幕显示
tf.Tensor(b'stock,open price,close price,highest price,lowest price', shape=(), dtype=string)
tf.Tensor(b'A,100.11,101.25,110.21,99.98', shape=(), dtype=string)
tf.Tensor(b'B,20.1,19.09,25.91,18.01', shape=(), dtype=string)
tf.data.TFRecordDataset
使用 TFRecordDataset处理TFRecord格式的记录
dataset = tf.data.TFRecordDataset("file1.tfrecords")
下列python代码产生一个TFRecord格式文件
import tensorflow as tf
import os
import numpy as np
cur_dir = os.getcwd()
example_path = os.path.join(cur_dir, "file1.tfrecords")
np.random.seed(0)
with tf.io.TFRecordWriter(example_path) as file_writer:
for _ in range(4):
x, y = np.random.random(), np.random.random()
record_bytes = tf.train.Example(features=tf.train.Features(feature={
"x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
"y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
})).SerializeToString()
file_writer.write(record_bytes)
tf.data.Dataset.list_files
使用 tf.data.Dataset.list_files创建包含所有与某个模式匹配的文件的数据集
dataset = tf.data.Dataset.list_files("/path/*.txt")
变换
apply
apply(
transformation_func
) -> 'DatasetV2'
对这个数据集中的每个元素,应用转换函数。
apply 功能支持自定义数据集转换的链式调用,这些转换表示为接受一个数据集参数,并返回转换后的数据集的函数。
import tensorflow as tf
def dataset_fn(ds):
return ds.filter(lambda x: x < 5)
ds = tf.data.Dataset.range(10)
ds = ds.apply(dataset_fn)
l = list(ds.as_numpy_iterator())
for e in l:
print(e, end=",")
屏幕显示
0,1,2,3,4,
cache
cache(
filename='', name=None
) -> 'DatasetV2'
缓存此数据集中的元素。
首次遍历数据集时,其元素将被缓存到指定文件或内存中。后续遍历将使用缓存的数据。
import tensorflow as tf
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: x**2)
ds = ds.cache()
l = list(ds.as_numpy_iterator())
for e in l:
print(e, end=" ")
屏幕·显示
0 1 4 9 16
shuffle
TensorFlow中的 tf.data.Dataset.cache().shuffle() 模式是优化机器学习管道中数据加载和处理的常用有效方法。
shuffle(buffer_size)
- 此操作会对数据集中的元素进行随机排序。
- buffer_size:此参数定义了用于随机抽取元素的缓冲区的大小。buffer_size越大,排序越彻底,但需要更多内存。为了确保整个数据集的真正随机性,理想情况下,buffer_size 应等于数据集中的元素总数。
使用shuffle对·数据集中的元素进行随机排序,对训练鲁棒模型至关重要,因为它能防止模型学习与数据顺序相关的虚假模式,并确保每个批次都能提供数据集的多样化表示。
将 shuffle() 放在 cache() 之后,意味着排序操作发生在缓存的数据上。这一点很重要,因为如果将 shuffle() 放在 cache()之前,数据集只会排序一次,然后缓存该单一的排序结果,导致每个 epoch 的排序结果都相同。通过将其放在cache()之后,每个epoch都可以获得缓存数据的全新,且不同的排序结果。
prefetch
数据集预取是一种数据管道中使用的优化技术,尤其是在 TensorFlow 和 PyTorch 等机器学习框架中,它通过将数据处理与模型计算重叠来提高性能。
- 预取通常涉及一个后台线程,该线程异步地将输入数据集中的数据元素加载到内部缓冲区中。
- 当模型处理当前批次的数据时,预取线程已经在准备下一批数据,从而最大限度地减少处理单元(例如 CPU 或 GPU)的空闲时间。
- 通过在需要时,提供下一批数据,预取降低了数据访问的延迟,并提高了数据管道的整体吞吐量。
- 预取可以与缓存(例如,使用 cache() 函数)结合使用,将处理后的数据,存储在内存中,避免跨多个 epoch 的冗余操作,从而进一步提升性能。缓存通常应在预取之前执行。
TensorFlow 中的 tf.data.Dataset.prefetch(buffer_size=AUTOTUNE) 方法是输入管道的关键优化,旨在将数据预处理与模型执行重叠。
当使用 tf.data.AUTOTUNE作为 buffer_size时,TensorFlow 会在运行时,动态调整缓冲区大小。这使得框架能够根据您的计算环境和工作负载,自动确定要预取的最佳元素数量,从而最大限度地提高性能,而无需手动调整。