显微镜图像处理(分类)-训练一个图像分类器

前言

本篇文章转载于:www.tensorflow.org/tutorials/k...

想了解更多细节可以前往tensorflow官网。

存在一个图像分类任务,我们会预先准备好四批数据:1. 训练图像数据,2. 训练图像标签数据(按照下标和训练数据对应),3. 测试图像数据,4. 预测图像标签数据(按照下标和测试数据对应)。

首先,根据上面展示出来的链接,我们知道google官网为我们准备了一批衣物数据,让我们用作图像分类,并且给我们准备好了上述说的四类数据,现在我们根据现有的四类数据,训练一个能够预测衣物种类的预测器

1. 准备开发环境

我们使用python+tensorflow进行模型训练,python版本为3.11,在这个python版本中,我们不用再去安装tensorflow-gpu版本,我们只需要安装tensorflow,并且tensorflow内部自带了keras

python 复制代码
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

在上面的代码案例中,我们能够从keras自带的数据集合中提取出:

  • train_images(训练数据):用于训练模型的数据
  • train_labels(训练数据标签): 用来标注训练数据类别的标签
  • test_images(测试数据): 用于测试的数据
  • test_labels(对照标签): 训练出来的模型准不准,就使用对照数据进行对照

2. 数据预处理

在训练模型之前,必须对数据进行预处理。如果您检查训练集中的第一个图像,您会看到像素值处于 0 到 255 之间,这会我们就需要将这些图像的像素值规格化为0~1之间的浮点数

python 复制代码
train_images = train_images / 255.0

test_images = test_images / 255.0

并且keras给到我们的标签数据也都是0~9之间的整数,所以我们要对这些整数进行处理,让他变成人类可以理解的文字:

python 复制代码
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

3. 训练模型

以下是tensorflow中keras文档的原话:

构建神经网络需要先配置模型的层,然后再编译模型。

神经网络的基本组成部分是。层会从向其馈送的数据中提取表示形式。希望这些表示形式有助于解决手头上的问题。

大多数深度学习都包括将简单的层链接在一起。大多数层(如 tf.keras.layers.Dense)都具有在训练期间才会学习的参数。

python 复制代码
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

该网络的第一层 tf.keras.layers.Flatten 将图像格式从二维数组(28 x 28 像素)转换成一维数组(28 x 28 = 784 像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。

展平像素后,网络会包括两个 tf.keras.layers.Dense 层的序列。它们是密集连接或全连接神经层。第一个 Dense 层有 128 个节点(或神经元)。第二个(也是最后一个)层会返回一个长度为 10 的 logits 数组。每个节点都包含一个得分,用来表示当前图像属于 10 个类中的哪一类。

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数 - 测量模型在训练期间的准确程度。你希望最小化此函数,以便将模型"引导"到正确的方向上。
  • 优化器 - 决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标 - 用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
python 复制代码
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

上面两行代码加入之后,我们就自定义好了模型的层以及模型的编译设置。下面我们就可以传入训练数据和训练标签来训练模型了

python 复制代码
model.fit(train_images, train_labels, epochs=10)

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

在模型训练期间,会显示损失和准确率指标。此模型在训练数据上的准确率达到了 0.91(或 91%)左右。

4. 模型预测

4.1 批量预测

将test_images里面的全部图片传入predict函数,可以得到predict函数对于所有图片的预测结果

python 复制代码
probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])

predictions = probability_model.predict(test_images)

print(np.argmax(predictions[0]))

4.2 单个图像预测

python 复制代码
img = (np.expand_dims(img,0))

predictions_single = probability_model.predict(img)

结果

完整代码

python 复制代码
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

train_images = train_images / 255.0

test_images = test_images / 255.0

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()


model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=10)

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

print('\nTest accuracy:', test_acc)

probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])

predictions = probability_model.predict(test_images)

print(np.argmax(predictions[0]))

print(test_labels[0])

def plot_image(i, predictions_array, true_label, img):
  true_label, img = true_label[i], img[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])
  plt.imshow(img, cmap=plt.cm.binary)
  predicted_label = np.argmax(predictions_array)
  if predicted_label == true_label:
    color = 'blue'
  else:
    color = 'red'
  plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                100*np.max(predictions_array),
                                class_names[true_label]),
                                color=color)


def plot_value_array(i, predictions_array, true_label):
  true_label = true_label[i]
  plt.grid(False)
  plt.xticks(range(10))
  plt.yticks([])
  thisplot = plt.bar(range(10), predictions_array, color="#777777")
  plt.ylim([0, 1])
  predicted_label = np.argmax(predictions_array)
  thisplot[predicted_label].set_color('red')
  thisplot[true_label].set_color('blue')

i = 0
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

i = 12
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
  plt.subplot(num_rows, 2*num_cols, 2*i+1)
  plot_image(i, predictions[i], test_labels, test_images)
  plt.subplot(num_rows, 2*num_cols, 2*i+2)
  plot_value_array(i, predictions[i], test_labels)
plt.tight_layout()
plt.show()
相关推荐
液态不合群几秒前
Vscode 远程切换Python虚拟环境
ide·vscode·python
JavaPub-rodert2 分钟前
# Python IDE的介绍和选择 --- 《跟着小王学Python》
开发语言·ide·python·编程·开发
冷小鱼28 分钟前
Pycharm 配置 Poetry
ide·python·pycharm·poetry
m0_5945263032 分钟前
探索 TraceBoard:统计你的键盘按键使用情况
python·计算机外设·开源软件·fastapi
bo_hai39 分钟前
决策树基本 CART Python手写实现
python·算法·决策树
Y_Hungry1 小时前
ZYX地图瓦片转mbtiles文件(Python)
开发语言·jvm·python
eaglelau1 小时前
CSV 文件
数据库·python
赛丽曼1 小时前
Python中的HTTP协议
python
神奇夜光杯1 小时前
Python酷库之旅-第三方库Pandas(218)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
深度学习lover2 小时前
<项目代码>YOLOv8 瞳孔识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·瞳孔识别