TensorFlow学习:使用官方模型进行图像分类并对模型进行微调

本文是对文章 TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调

的补充说明。因为版本兼容的原因,原文有多处代码无法成功运行。这里把调整后的两处完整代码贴了出来,同时附上对应的模型文件(里面的文件或目录和Python文件放在相同目录下),以作对比。

运行环境:Mac 14.2、Python 3.12.2

使用成熟模型的完整代码:

python 复制代码
# 导入tensorflow 和科学计算库
import tensorflow as tf
import numpy as np
# tensorflow-hub是一个TensorFlow库的扩展,它提供了一个简单的接口,用于重用已经训练好的机器学习模型的部分
import tensorflow_hub as hub
# 字体属性
from matplotlib.font_manager import FontProperties
# matplotlib是用于绘制图表和可视化数据的库
import matplotlib.pylab as plt
# 用于加载json文件
import json
import tf_keras
import ssl
import certifi

# 导入模型
# 不能直接加载模型文件,需要加载器目录
# 加载mobilenet_v2模型,这里要加载文件夹不要直接加载pb文件
# 模型如何加载要看文档,原来使用tf.keras.models.load_model加载一直失败
model = tf_keras.Sequential([
    hub.KerasLayer('mobilenet-v2-classification')
])

# 假设输入为224x224 RGB图像
#input_shape = (224, 224, 3)
#input_layer = tf.keras.Input(shape=(input_shape))
#hub_layer = hub.KerasLayer("mobilenet_v2", trainable=True)
#x = hub_layer(input_layer)
#output_layer = tf.keras.layers.Dense(units, activation='activation_type')(x)
#
#model = tf.keras.Model(inputs=input_layer, outputs=output_layer)

#model = tf.keras.applications.mobilenet_v2.MobileNetV2()
print("模型信息:",model)


# 预处理输入数据
# 1、mobilenet需要的图片尺寸是 224 * 224
image = tf.keras.preprocessing.image.load_img('pics/dog.jpg',target_size=(224,224))

# 设置SSL上下文
#image =tf.keras.utils.get_file('bird.jpg','https://scpic.chinaz.net/files/default/imgs/2023-08-29/7dc085b6d3291303.jpg')

# 2、将图片转为数组,既是只有一张图片
image = tf.keras.preprocessing.image.img_to_array(image)
# 3、扩展数组维度,使其符合模型的输入
image = np.expand_dims(image, axis=0)
# 4、使用mobilenet_v2提供的预处理函数对图像处理,包括图像归一化、颜色通道顺序调整、像素值标准化等操作
image = tf.keras.applications.mobilenet_v2.preprocess_input(image)


# 预测
predictions = model.predict(image)
# 获取最高概率对应的类别索引
predicted_index = np.argmax(predictions)
# 概率值
confidence = np.max(predictions)
print("索引和概率值是:",predicted_index,confidence)

# 初始化一个空列表来存储文件的行
#labels_dict = []
# 加载映射文件
with open('mobilenet_v2/ImageNetLabels.txt','r') as f:
#    labels_dict  = json.load(f)
    labels_dict = f.readlines()
# 类别的索引是字符串,这里要简单处理一下,这里-1是因为官方提供的多了一个0(背景),我找到的标签没有这个,因此要-1
class_name = labels_dict[predicted_index]
print(class_name)

# 可视化显示
font = FontProperties()
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(image[0]) # 显示图片
plt.xlabel(class_name,fontproperties=font)
plt.show() # 显示图形窗口

对模型进行微调的完整代码

python 复制代码
# 导入tensorflow 和科学计算库
import tensorflow as tf
import numpy as np
# tensorflow-hub是一个TensorFlow库的扩展,它提供了一个简单的接口,用于重用已经训练好的机器学习模型的部分
import tensorflow_hub as hub
# 字体属性
from matplotlib.font_manager import FontProperties
# matplotlib是用于绘制图表和可视化数据的库
import matplotlib.pylab as plt

import datetime
import tf_keras

# 导入模型
# 不能直接加载模型文件,需要加载器目录
model = tf_keras.Sequential([
    hub.KerasLayer('mobilenet-v2-classification')
])


# 32张图片为一个批次,尺寸设置为224*224
batch_size = 32
img_height = 224
img_width = 224

# 加载图像数据集,并将其分割为训练集和验证集,验证集比例为20%
train_ds = tf.keras.utils.image_dataset_from_directory(
    'flower_photos',  # 目录
    validation_split=0.2, # 验证集占20%
    subset="training", # 将数据集划分为训练集
    seed= 123, # 随机种子,用于数据集随机划分
    image_size= (img_width,img_height) , # 调整图像大小
    batch_size= batch_size  # 每个批次中包含的图像数量
)
# 验证集
val_ds = tf.keras.utils.image_dataset_from_directory(
    'flower_photos',  # 目录
    validation_split=0.2, # 验证集占20%
    subset="validation", # 将数据集划分为验证集
    seed= 123, # 随机种子,用于数据集随机划分
    image_size= (img_width,img_height) , # 调整图像大小
    batch_size= batch_size  # 每个批次中包含的图像数量
)

# 花卉种类
class_names = np.array(train_ds.class_names)
print("花卉种类:",class_names)

# 归一化
normalization_layer = tf.keras.layers.Rescaling(1./255) # 创建了一个Rescaling层,将像素值缩放到0到1之间 。 1./255是 1/255保留小数,差点没看懂
train_ds = train_ds.map(lambda x,y:(normalization_layer(x),y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))

# 使用缓冲预取,避免产生I/O阻塞
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# 验证数据是否成功加载和处理
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

# 对一批图片运行分类器,进行预测
result_batch = model.predict(train_ds)
# 加载映射文件,这里我将其下载到了本地
imagenet_labels = np.array(open('mobilenet-v2-feature-vector/ImageNetLabels.txt').read().splitlines())
# 在给定的张量中找到沿指定轴的最大值的索引
predict_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]
print("预测类别:",predict_class_names)

# 绘制出预测与图片
# plt.figure(figsize=(10,9))
# plt.subplots_adjust(hspace=0.5)
# for n in range(30):
#   plt.subplot(6,5,n+1)
#   plt.imshow(image_batch[n])
#   plt.title(predict_class_names[n])
#   plt.axis('off')
# _ = plt.suptitle("ImageNet predictions")
# plt.show()

# 加载特征提取器
feature_extractor_layer = hub.KerasLayer(
  'mobilenet-v2-feature-vector', # 预训练模型
  input_shape=(224,224,3), # 指定图像输入的高度、宽度和通道数
  trainable=False  #训练过程中不更新特征提取器的权重
)
# 特征提取器为每个图像返回一个 1280 长的向量(在此示例中,图像批大小仍为 32)
feature_batch = feature_extractor_layer(image_batch)
print("特征批次形状:",feature_batch.shape)

# 附加分类头
new_model = tf_keras.Sequential([
  feature_extractor_layer,
  tf_keras.layers.Dense(len(class_names),activation='softmax') # 指定输出分类,这里的花是5类
])


# 训练模型
new_model.compile(
  optimizer=tf_keras.optimizers.Adam(),  # 使用Adam优化器作为优化算法
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 使用SparseCategoricalCrossentropy作为损失函数
  metrics=['acc'] # 使用准确率作为评估指标
)
# 训练日志
log_dir = "logs/fit/" + datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
# 用于在训练过程中收集模型指标和摘要数据,并将其写入TensorBoard日志文件中
tensorboard_callback = tf_keras.callbacks.TensorBoard(
  log_dir= log_dir,
  histogram_freq=1
)

# 开始训练,暂时只训练10轮。history记录了训练过程中的各项指标,便于后续分析和可视化
history = new_model.fit(
  train_ds, # 训练数据集
  validation_data=val_ds, # 验证数据集,用于在训练过程中监控模型的性能
  epochs=10, # 训练的总轮次
  callbacks=tensorboard_callback # 回调函数,用于在训练过程中执行特定操作,比如记录日志
)

# 简单预测
# predicted_batch = new_model.predict(image_batch)
# predicted_id = tf.math.argmax(predicted_batch, axis=-1)
# predicted_label_batch = class_names[predicted_id]
# print("花卉种类:",predicted_label_batch)

# plt.figure(figsize=(10,9))
# plt.subplots_adjust(hspace=0.5)

# for n in range(30):
#   plt.subplot(6,5,n+1)
#   plt.imshow(image_batch[n])
#   plt.title(predicted_label_batch[n].title())
#   plt.axis('off')
# _ = plt.suptitle("Model predictions")
# plt.show()

# 导出训练好的模型
export_path = 'tmp/saved_models/flower_model'
new_model.save(export_path)
相关推荐
进击的六角龙1 小时前
Python中处理Excel的基本概念(如工作簿、工作表等)
开发语言·python·excel
一只爱好编程的程序猿1 小时前
Java后台生成指定路径下创建指定名称的文件
java·python·数据下载
Aniay_ivy1 小时前
深入探索 Java 8 Stream 流:高效操作与应用场景
java·开发语言·python
gonghw4031 小时前
DearPyGui学习
python·gui
向阳12182 小时前
Bert快速入门
人工智能·python·自然语言处理·bert
engchina2 小时前
Neo4j 和 Python 初学者指南:如何使用可选关系匹配优化 Cypher 查询
数据库·python·neo4j
兆。2 小时前
掌握 PyQt5:从零开始的桌面应用开发
开发语言·爬虫·python·qt
南宫理的日知录2 小时前
99、Python并发编程:多线程的问题、临界资源以及同步机制
开发语言·python·学习·编程学习
coberup2 小时前
django Forbidden (403)错误解决方法
python·django·403错误
龙哥说跨境3 小时前
如何利用指纹浏览器爬虫绕过Cloudflare的防护?
服务器·网络·python·网络爬虫