tensorflow 图像分类 之二

tensorflow 图像分类 之二

图像分类代码

API 说明

tf.keras.utils.get_file

tf.keras.utils.image_dataset_from_directory

上接 《tensorflow 图像分类 之一》

tf.data.Dataset

tf.data.Dataset API 支持编写描述性强,且高效的输入管道。Dataset 的使用遵循以下常见模式:

  • 从输入数据,创建源数据集。
  • 使用数据集变换预处理数据。
  • 遍历数据集,并处理其中的元素。

遍历以流式方式进行,因此无需将整个数据集加载到内存中。

产生数据集

tf.data.Dataset.from_tensor_slices

使用Python list创建数据集

复制代码
dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])

产生,显示数据集的python代码

复制代码
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1,  2,  3, 9,10])
for element in dataset:
    print(element)

屏幕显示

复制代码
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(9, shape=(), dtype=int32)
tf.Tensor(10, shape=(), dtype=int32)
tf.data.TextLineDataset

使用 tf.data.TextLineDataset处理文本文件中的行

file1.csv文件内容

复制代码
stock,open price,close price,highest price,lowest price
A,100.11,101.25,110.21,99.98
B,20.1,19.09,25.91,18.01

运行下列python代码

复制代码
ds = tf.data.TextLineDataset("file1.csv")
for elem in ds:
	print(elem)

屏幕显示

复制代码
tf.Tensor(b'stock,open price,close price,highest price,lowest price', shape=(), dtype=string)
tf.Tensor(b'A,100.11,101.25,110.21,99.98', shape=(), dtype=string)
tf.Tensor(b'B,20.1,19.09,25.91,18.01', shape=(), dtype=string)
tf.data.TFRecordDataset

使用 TFRecordDataset处理TFRecord格式的记录

复制代码
dataset = tf.data.TFRecordDataset("file1.tfrecords")

下列python代码产生一个TFRecord格式文件

复制代码
import tensorflow as tf
import os
import numpy as np

cur_dir = os.getcwd()
example_path = os.path.join(cur_dir, "file1.tfrecords")
np.random.seed(0)

with tf.io.TFRecordWriter(example_path) as file_writer:
  for _ in range(4):
    x, y = np.random.random(), np.random.random()

    record_bytes = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
    })).SerializeToString()
    file_writer.write(record_bytes)
tf.data.Dataset.list_files

使用 tf.data.Dataset.list_files创建包含所有与某个模式匹配的文件的数据集

复制代码
dataset = tf.data.Dataset.list_files("/path/*.txt")

变换

apply
复制代码
apply(
    transformation_func
) -> 'DatasetV2'

对这个数据集中的每个元素,应用转换函数。

apply 功能支持自定义数据集转换的链式调用,这些转换表示为接受一个数据集参数,并返回转换后的数据集的函数。

复制代码
import tensorflow as tf

def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)

ds = tf.data.Dataset.range(10)
ds = ds.apply(dataset_fn)

l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=",")

屏幕显示

复制代码
0,1,2,3,4,
cache
复制代码
cache(
    filename='', name=None
) -> 'DatasetV2'

缓存此数据集中的元素。

首次遍历数据集时,其元素将被缓存到指定文件或内存中。后续遍历将使用缓存的数据。

复制代码
import tensorflow as tf

ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: x**2)
ds = ds.cache()
l = list(ds.as_numpy_iterator())
for e in l:
	print(e, end=" ")

屏幕·显示

复制代码
0 1 4 9 16
shuffle

TensorFlow中的 tf.data.Dataset.cache().shuffle() 模式是优化机器学习管道中数据加载和处理的常用有效方法。

shuffle(buffer_size)

  • 此操作会对数据集中的元素进行随机排序。
  • buffer_size:此参数定义了用于随机抽取元素的缓冲区的大小。buffer_size越大,排序越彻底,但需要更多内存。为了确保整个数据集的真正随机性,理想情况下,buffer_size 应等于数据集中的元素总数。

使用shuffle对·数据集中的元素进行随机排序,对训练鲁棒模型至关重要,因为它能防止模型学习与数据顺序相关的虚假模式,并确保每个批次都能提供数据集的多样化表示。

将 shuffle() 放在 cache() 之后,意味着排序操作发生在缓存的数据上。这一点很重要,因为如果将 shuffle() 放在 cache()之前,数据集只会排序一次,然后缓存该单一的排序结果,导致每个 epoch 的排序结果都相同。通过将其放在cache()之后,每个epoch都可以获得缓存数据的全新,且不同的排序结果。

prefetch

数据集预取是一种数据管道中使用的优化技术,尤其是在 TensorFlow 和 PyTorch 等机器学习框架中,它通过将数据处理与模型计算重叠来提高性能。

  • 预取通常涉及一个后台线程,该线程异步地将输入数据集中的数据元素加载到内部缓冲区中。
  • 当模型处理当前批次的数据时,预取线程已经在准备下一批数据,从而最大限度地减少处理单元(例如 CPU 或 GPU)的空闲时间。
  • 通过在需要时,提供下一批数据,预取降低了数据访问的延迟,并提高了数据管道的整体吞吐量。
  • 预取可以与缓存(例如,使用 cache() 函数)结合使用,将处理后的数据,存储在内存中,避免跨多个 epoch 的冗余操作,从而进一步提升性能。缓存通常应在预取之前执行。

TensorFlow 中的 tf.data.Dataset.prefetch(buffer_size=AUTOTUNE) 方法是输入管道的关键优化,旨在将数据预处理与模型执行重叠。

当使用 tf.data.AUTOTUNE作为 buffer_size时,TensorFlow 会在运行时,动态调整缓冲区大小。这使得框架能够根据您的计算环境和工作负载,自动确定要预取的最佳元素数量,从而最大限度地提高性能,而无需手动调整。

相关推荐
大鹏的NLP博客19 小时前
大模型中为什么 CoT 对分类有效?
人工智能·分类·数据挖掘
后端小张20 小时前
【AI 学习】LangChain框架深度解析:从核心组件到企业级应用实战
java·人工智能·学习·langchain·tensorflow·gpt-3·ai编程
算法与编程之美20 小时前
不同的优化器对分类精度的影响以及损失函数对分类精度的影响.
人工智能·算法·机器学习·分类·数据挖掘
麦麦大数据20 小时前
F054-基于Vue+Flask+Neo4j构建的移民知识图谱可视化分析系统
vue.js·flask·知识图谱·neo4j·移民分析
sali-tec20 小时前
C# 基于halcon的视觉工具VisionTool Halcon发布
人工智能·深度学习·算法·计算机视觉·分类
MarkHD21 小时前
智能体在车联网中的应用:第14天 卷积神经网络(CNN)专精:从卷积原理到LeNet-5实战车辆图像分类
人工智能·分类·cnn
新鲜势力呀21 小时前
TensorFlow 中 tf.placeholder 适用版本解析|附 PHP 调用 TF 模型实战(兼容低版本)
tensorflow·php·neo4j
麦麦大数据21 小时前
F055 vue+neo4j船舶知识问答系统|知识图谱|问答系统
vue.js·flask·问答系统·知识图谱·neo4j·可视化
啊阿狸不会拉杆21 小时前
《数字图像处理》第 12 章 - 图像模式分类
图像处理·人工智能·算法·机器学习·计算机视觉·分类·数据挖掘
A尘埃1 天前
PyTorch的分布式训练策略:DDP + DeepSpeed + TensorFlow的分布式训练策略:MirroredStrategy
pytorch·分布式·tensorflow