TensorFlow2从磁盘读取图片数据集的示例(tf.data.Dataset.list_files)

python 复制代码
import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.resnet import ResNet50
from pathlib import Path
import numpy as np

#数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = Path(os.path.join(base_dir,'train'))
file_pattern = os.path.join(train_dir,'*/*.jpg')
image_count = len(list(train_dir.glob('*/*.jpg')))

list_ds = tf.data.Dataset.list_files(file_pattern,shuffle = False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
for f in list_ds.take(5):
  print(f.numpy())
  
class_names = np.array(sorted([item.name for item in train_dir.glob('*') ]))
print(class_names)

val_size = int(image_count * 0.2)
train_data = list_ds.skip(val_size)
validation_data = list_ds.take(val_size)
print(tf.data.experimental.cardinality(train_data).numpy())
print(tf.data.experimental.cardinality(validation_data).numpy())


def get_label(file_path):
  parts = tf.strings.split(file_path, os.path.sep)
  one_hot = parts[-2] == class_names
  return tf.argmax(one_hot)

def decode_img(img):
  img = tf.io.decode_jpeg(img, channels=3)
  return tf.image.resize(img, [64, 64])

def process_path(file_path):
  label = get_label(file_path)
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

train_data = train_data.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
validation_data = validation_data.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

for image, label in train_data.take(2):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(4)
  ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
  return ds

train_data = configure_for_performance(train_data)
validation_data = configure_for_performance(validation_data)


save_model_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model_resnet50_cats_and_dogs.h5', save_freq='epoch')

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(64, 64, 3))
base_model.trainable = True
    
model = tf.keras.models.Sequential([
    base_model,
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(l=0.01)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',optimizer = Adam(lr=1e-3),metrics = ['acc'])

history = model.fit(train_data.repeat(),steps_per_epoch=100,epochs=50,validation_data=validation_data.repeat(),validation_steps=50,verbose=1,callbacks = [save_model_cb])
相关推荐
木卫二号Coding15 小时前
第七十二篇-V100-32G+WebUI+Flux.1-Schnell+Lora+文生图
开发语言·人工智能·python
之歆15 小时前
Spring AI入门到实战到原理源码-笔记-(上)
java·人工智能·spring
墨笔之风15 小时前
基于python 实现的小游戏
开发语言·python·pygame
多米Domi01115 小时前
0x3f 第24天 黑马web (安了半天程序 )hot100普通数组
数据结构·python·算法·leetcode
BoBoZz1915 小时前
AnatomicalOrientation 3D人体模型及三个人体标准解剖学平面展示
python·vtk·图形渲染·图形处理
love530love15 小时前
EPGF 新手教程 11在 PyCharm(中文版 GUI)中创建 uv 环境,并把 uv 做到“项目自包含”(工具本地化为必做环节)
ide·人工智能·python·pycharm·conda·uv·epgf
jackylzh15 小时前
cmd或其它终端的dos命令 & events.out.tfevents文件怎么打开
python
gis_rc15 小时前
python下shp转3dtiles
python·3d·cesium·3dtiles·数字孪生模型
廖圣平15 小时前
直播间福袋脚本,研究json格式【一】
python
Fabarta技术团队15 小时前
响应北京人工智能行动计划,枫清科技共筑AI创新高地
人工智能·科技