tensorflow 图像分类 之二

tensorflow 图像分类 之二

图像分类代码

API 说明

tf.keras.utils.get_file

tf.keras.utils.image_dataset_from_directory

上接 《tensorflow 图像分类 之一》

tf.data.Dataset

tf.data.Dataset API 支持编写描述性强,且高效的输入管道。Dataset 的使用遵循以下常见模式:

  • 从输入数据,创建源数据集。
  • 使用数据集变换预处理数据。
  • 遍历数据集,并处理其中的元素。

遍历以流式方式进行,因此无需将整个数据集加载到内存中。

产生数据集

tf.data.Dataset.from_tensor_slices

使用Python list创建数据集

复制代码
dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])

产生,显示数据集的python代码

复制代码
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])
for element in dataset:
    print(element)

屏幕显示

复制代码
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(9, shape=(), dtype=int32)
tf.Tensor(10, shape=(), dtype=int32)
tf.data.TextLineDataset

使用 tf.data.TextLineDataset处理文本文件中的行

file1.csv文件内容

复制代码
stock,open price,close price,highest price,lowest price
A,100.11,101.25,110.21,99.98
B,20.1,19.09,25.91,18.01

运行下列python代码

复制代码
ds = tf.data.TextLineDataset("file1.csv")
for elem in ds:
	print(elem)

屏幕显示

复制代码
tf.Tensor(b'stock,open price,close price,highest price,lowest price', shape=(), dtype=string)
tf.Tensor(b'A,100.11,101.25,110.21,99.98', shape=(), dtype=string)
tf.Tensor(b'B,20.1,19.09,25.91,18.01', shape=(), dtype=string)
tf.data.TFRecordDataset

使用 TFRecordDataset处理TFRecord格式的记录

复制代码
dataset = tf.data.TFRecordDataset("file1.tfrecords")

下列python代码产生一个TFRecord格式文件

复制代码
import tensorflow as tf
import os
import numpy as np

cur_dir = os.getcwd()
example_path = os.path.join(cur_dir, "file1.tfrecords")
np.random.seed(0)

with tf.io.TFRecordWriter(example_path) as file_writer:
  for _ in range(4):
    x, y = np.random.random(), np.random.random()

    record_bytes = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
    })).SerializeToString()
    file_writer.write(record_bytes)
tf.data.Dataset.list_files

使用 tf.data.Dataset.list_files创建包含所有与某个模式匹配的文件的数据集

复制代码
dataset = tf.data.Dataset.list_files("/path/*.txt")

变换

apply
复制代码
apply(
    transformation_func
) -> 'DatasetV2'

对这个数据集中的每个元素,应用转换函数。

apply 功能支持自定义数据集转换的链式调用,这些转换表示为接受一个数据集参数,并返回转换后的数据集的函数。

复制代码
import tensorflow as tf

def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)

ds = tf.data.Dataset.range(10)
ds = ds.apply(dataset_fn)

l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=",")

屏幕显示

复制代码
0,1,2,3,4,
cache
复制代码
cache(
    filename='', name=None
) -> 'DatasetV2'

缓存此数据集中的元素。

首次遍历数据集时,其元素将被缓存到指定文件或内存中。后续遍历将使用缓存的数据。

复制代码
import tensorflow as tf

ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: x**2)
ds = ds.cache()
l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=" ")

屏幕·显示

复制代码
0 1 4 9 16
shuffle

TensorFlow中的 tf.data.Dataset.cache().shuffle() 模式是优化机器学习管道中数据加载和处理的常用有效方法。

shuffle(buffer_size)

  • 此操作会对数据集中的元素进行随机排序。
  • buffer_size:此参数定义了用于随机抽取元素的缓冲区的大小。buffer_size越大,排序越彻底,但需要更多内存。为了确保整个数据集的真正随机性,理想情况下,buffer_size 应等于数据集中的元素总数。

使用shuffle对·数据集中的元素进行随机排序,对训练鲁棒模型至关重要,因为它能防止模型学习与数据顺序相关的虚假模式,并确保每个批次都能提供数据集的多样化表示。

将 shuffle() 放在 cache() 之后,意味着排序操作发生在缓存的数据上。这一点很重要,因为如果将 shuffle() 放在 cache()之前,数据集只会排序一次,然后缓存该单一的排序结果,导致每个 epoch 的排序结果都相同。通过将其放在cache()之后,每个epoch都可以获得缓存数据的全新,且不同的排序结果。

prefetch

数据集预取是一种数据管道中使用的优化技术,尤其是在 TensorFlow 和 PyTorch 等机器学习框架中,它通过将数据处理与模型计算重叠来提高性能。

  • 预取通常涉及一个后台线程,该线程异步地将输入数据集中的数据元素加载到内部缓冲区中。
  • 当模型处理当前批次的数据时,预取线程已经在准备下一批数据,从而最大限度地减少处理单元(例如 CPU 或 GPU)的空闲时间。
  • 通过在需要时,提供下一批数据,预取降低了数据访问的延迟,并提高了数据管道的整体吞吐量。
  • 预取可以与缓存(例如,使用 cache() 函数)结合使用,将处理后的数据,存储在内存中,避免跨多个 epoch 的冗余操作,从而进一步提升性能。缓存通常应在预取之前执行。

TensorFlow 中的 tf.data.Dataset.prefetch(buffer_size=AUTOTUNE) 方法是输入管道的关键优化,旨在将数据预处理与模型执行重叠。

当使用 tf.data.AUTOTUNE作为 buffer_size时,TensorFlow 会在运行时,动态调整缓冲区大小。这使得框架能够根据您的计算环境和工作负载,自动确定要预取的最佳元素数量,从而最大限度地提高性能,而无需手动调整。

相关推荐
是小蟹呀^2 分钟前
从稀疏到自适应:人脸识别中稀疏表示的核心演进
人工智能·分类
AAD5558889911 小时前
YOLO11-EfficientRepBiPAN载重汽车轮胎热成像检测与分类_3
人工智能·分类·数据挖掘
小徐xxx19 小时前
Softmax回归(分类问题)学习记录
深度学习·分类·回归·softmax·学习记录
AAD5558889919 小时前
YOLOv8-MAN-Faster电容器缺陷检测:七类组件识别与分类系统
yolo·分类·数据挖掘
JicasdC123asd19 小时前
【工业检测】基于YOLO13-C3k2-EIEM的铸造缺陷检测与分类系统_1
人工智能·算法·分类
子夜江寒21 小时前
基于 LSTM 的中文情感分类项目解析
人工智能·分类·lstm
是小蟹呀^1 天前
Focal Loss:解决长尾图像分类中“多数类太强势”的损失函数
人工智能·机器学习·分类
2501_941329721 天前
基于Centernet的甜菜幼苗生长状态识别与分类系统
人工智能·分类·数据挖掘
Daydream.V1 天前
决策树三中分类标准
算法·决策树·分类
ZCXZ12385296a1 天前
【实战案例】基于YOLOv8的亚洲107种鸟类图像分类与目标检测系统_2
yolo·目标检测·分类