tensorflow 图像分类 之二

tensorflow 图像分类 之二

图像分类代码

API 说明

tf.keras.utils.get_file

tf.keras.utils.image_dataset_from_directory

上接 《tensorflow 图像分类 之一》

tf.data.Dataset

tf.data.Dataset API 支持编写描述性强,且高效的输入管道。Dataset 的使用遵循以下常见模式:

  • 从输入数据,创建源数据集。
  • 使用数据集变换预处理数据。
  • 遍历数据集,并处理其中的元素。

遍历以流式方式进行,因此无需将整个数据集加载到内存中。

产生数据集

tf.data.Dataset.from_tensor_slices

使用Python list创建数据集

复制代码
dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])

产生,显示数据集的python代码

复制代码
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])
for element in dataset:
    print(element)

屏幕显示

复制代码
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(9, shape=(), dtype=int32)
tf.Tensor(10, shape=(), dtype=int32)
tf.data.TextLineDataset

使用 tf.data.TextLineDataset处理文本文件中的行

file1.csv文件内容

复制代码
stock,open price,close price,highest price,lowest price
A,100.11,101.25,110.21,99.98
B,20.1,19.09,25.91,18.01

运行下列python代码

复制代码
ds = tf.data.TextLineDataset("file1.csv")
for elem in ds:
	print(elem)

屏幕显示

复制代码
tf.Tensor(b'stock,open price,close price,highest price,lowest price', shape=(), dtype=string)
tf.Tensor(b'A,100.11,101.25,110.21,99.98', shape=(), dtype=string)
tf.Tensor(b'B,20.1,19.09,25.91,18.01', shape=(), dtype=string)
tf.data.TFRecordDataset

使用 TFRecordDataset处理TFRecord格式的记录

复制代码
dataset = tf.data.TFRecordDataset("file1.tfrecords")

下列python代码产生一个TFRecord格式文件

复制代码
import tensorflow as tf
import os
import numpy as np

cur_dir = os.getcwd()
example_path = os.path.join(cur_dir, "file1.tfrecords")
np.random.seed(0)

with tf.io.TFRecordWriter(example_path) as file_writer:
  for _ in range(4):
    x, y = np.random.random(), np.random.random()

    record_bytes = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
    })).SerializeToString()
    file_writer.write(record_bytes)
tf.data.Dataset.list_files

使用 tf.data.Dataset.list_files创建包含所有与某个模式匹配的文件的数据集

复制代码
dataset = tf.data.Dataset.list_files("/path/*.txt")

变换

apply
复制代码
apply(
    transformation_func
) -> 'DatasetV2'

对这个数据集中的每个元素,应用转换函数。

apply 功能支持自定义数据集转换的链式调用,这些转换表示为接受一个数据集参数,并返回转换后的数据集的函数。

复制代码
import tensorflow as tf

def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)

ds = tf.data.Dataset.range(10)
ds = ds.apply(dataset_fn)

l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=",")

屏幕显示

复制代码
0,1,2,3,4,
cache
复制代码
cache(
    filename='', name=None
) -> 'DatasetV2'

缓存此数据集中的元素。

首次遍历数据集时,其元素将被缓存到指定文件或内存中。后续遍历将使用缓存的数据。

复制代码
import tensorflow as tf

ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: x**2)
ds = ds.cache()
l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=" ")

屏幕·显示

复制代码
0 1 4 9 16
shuffle

TensorFlow中的 tf.data.Dataset.cache().shuffle() 模式是优化机器学习管道中数据加载和处理的常用有效方法。

shuffle(buffer_size)

  • 此操作会对数据集中的元素进行随机排序。
  • buffer_size:此参数定义了用于随机抽取元素的缓冲区的大小。buffer_size越大,排序越彻底,但需要更多内存。为了确保整个数据集的真正随机性,理想情况下,buffer_size 应等于数据集中的元素总数。

使用shuffle对·数据集中的元素进行随机排序,对训练鲁棒模型至关重要,因为它能防止模型学习与数据顺序相关的虚假模式,并确保每个批次都能提供数据集的多样化表示。

将 shuffle() 放在 cache() 之后,意味着排序操作发生在缓存的数据上。这一点很重要,因为如果将 shuffle() 放在 cache()之前,数据集只会排序一次,然后缓存该单一的排序结果,导致每个 epoch 的排序结果都相同。通过将其放在cache()之后,每个epoch都可以获得缓存数据的全新,且不同的排序结果。

prefetch

数据集预取是一种数据管道中使用的优化技术,尤其是在 TensorFlow 和 PyTorch 等机器学习框架中,它通过将数据处理与模型计算重叠来提高性能。

  • 预取通常涉及一个后台线程,该线程异步地将输入数据集中的数据元素加载到内部缓冲区中。
  • 当模型处理当前批次的数据时,预取线程已经在准备下一批数据,从而最大限度地减少处理单元(例如 CPU 或 GPU)的空闲时间。
  • 通过在需要时,提供下一批数据,预取降低了数据访问的延迟,并提高了数据管道的整体吞吐量。
  • 预取可以与缓存(例如,使用 cache() 函数)结合使用,将处理后的数据,存储在内存中,避免跨多个 epoch 的冗余操作,从而进一步提升性能。缓存通常应在预取之前执行。

TensorFlow 中的 tf.data.Dataset.prefetch(buffer_size=AUTOTUNE) 方法是输入管道的关键优化,旨在将数据预处理与模型执行重叠。

当使用 tf.data.AUTOTUNE作为 buffer_size时,TensorFlow 会在运行时,动态调整缓冲区大小。这使得框架能够根据您的计算环境和工作负载,自动确定要预取的最佳元素数量,从而最大限度地提高性能,而无需手动调整。

相关推荐
IT_Beijing_BIT8 小时前
tensorflow 图像分类 之四
人工智能·分类·tensorflow
武子康21 小时前
Java-167 Neo4j CQL 实战:CREATE/MATCH 与关系建模速通 案例实测
java·开发语言·数据库·python·sql·nosql·neo4j
Learn Beyond Limits1 天前
Clustering vs Classification|聚类vs分类
人工智能·算法·机器学习·ai·分类·数据挖掘·聚类
chao1898441 天前
遗传算法与粒子群算法优化BP提高分类效果
算法·分类·数据挖掘
诸葛务农1 天前
光电对抗分类及外场静爆试验操作规程
人工智能·嵌入式硬件·分类·数据挖掘
ScilogyHunter1 天前
卫星姿态控制模式全解析:从基准到任务的体系化分类
算法·分类
盼小辉丶1 天前
TensorFlow深度学习实战——胶囊网络
深度学习·tensorflow·keras
傻啦嘿哟2 天前
Python高效实现Word转HTML:从基础到进阶的全流程方案
人工智能·python·tensorflow
AI_56782 天前
AI开发革命:PyCharm科学计算模式重塑TensorFlow调试体验
人工智能·ai·neo4j