显微镜图像处理(分类)-训练一个图像分类器

前言

本篇文章转载于:www.tensorflow.org/tutorials/k...

想了解更多细节可以前往tensorflow官网。

存在一个图像分类任务,我们会预先准备好四批数据:1. 训练图像数据,2. 训练图像标签数据(按照下标和训练数据对应),3. 测试图像数据,4. 预测图像标签数据(按照下标和测试数据对应)。

首先,根据上面展示出来的链接,我们知道google官网为我们准备了一批衣物数据,让我们用作图像分类,并且给我们准备好了上述说的四类数据,现在我们根据现有的四类数据,训练一个能够预测衣物种类的预测器

1. 准备开发环境

我们使用python+tensorflow进行模型训练,python版本为3.11,在这个python版本中,我们不用再去安装tensorflow-gpu版本,我们只需要安装tensorflow,并且tensorflow内部自带了keras

python 复制代码
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

在上面的代码案例中,我们能够从keras自带的数据集合中提取出:

  • train_images(训练数据):用于训练模型的数据
  • train_labels(训练数据标签): 用来标注训练数据类别的标签
  • test_images(测试数据): 用于测试的数据
  • test_labels(对照标签): 训练出来的模型准不准,就使用对照数据进行对照

2. 数据预处理

在训练模型之前,必须对数据进行预处理。如果您检查训练集中的第一个图像,您会看到像素值处于 0 到 255 之间,这会我们就需要将这些图像的像素值规格化为0~1之间的浮点数

python 复制代码
train_images = train_images / 255.0

test_images = test_images / 255.0

并且keras给到我们的标签数据也都是0~9之间的整数,所以我们要对这些整数进行处理,让他变成人类可以理解的文字:

python 复制代码
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

3. 训练模型

以下是tensorflow中keras文档的原话:

构建神经网络需要先配置模型的层,然后再编译模型。

神经网络的基本组成部分是。层会从向其馈送的数据中提取表示形式。希望这些表示形式有助于解决手头上的问题。

大多数深度学习都包括将简单的层链接在一起。大多数层(如 tf.keras.layers.Dense)都具有在训练期间才会学习的参数。

python 复制代码
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

该网络的第一层 tf.keras.layers.Flatten 将图像格式从二维数组(28 x 28 像素)转换成一维数组(28 x 28 = 784 像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。

展平像素后,网络会包括两个 tf.keras.layers.Dense 层的序列。它们是密集连接或全连接神经层。第一个 Dense 层有 128 个节点(或神经元)。第二个(也是最后一个)层会返回一个长度为 10 的 logits 数组。每个节点都包含一个得分,用来表示当前图像属于 10 个类中的哪一类。

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数 - 测量模型在训练期间的准确程度。你希望最小化此函数,以便将模型"引导"到正确的方向上。
  • 优化器 - 决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标 - 用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
python 复制代码
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

上面两行代码加入之后,我们就自定义好了模型的层以及模型的编译设置。下面我们就可以传入训练数据和训练标签来训练模型了

python 复制代码
model.fit(train_images, train_labels, epochs=10)

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

在模型训练期间,会显示损失和准确率指标。此模型在训练数据上的准确率达到了 0.91(或 91%)左右。

4. 模型预测

4.1 批量预测

将test_images里面的全部图片传入predict函数,可以得到predict函数对于所有图片的预测结果

python 复制代码
probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])

predictions = probability_model.predict(test_images)

print(np.argmax(predictions[0]))

4.2 单个图像预测

python 复制代码
img = (np.expand_dims(img,0))

predictions_single = probability_model.predict(img)

结果

完整代码

python 复制代码
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

train_images = train_images / 255.0

test_images = test_images / 255.0

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()


model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=10)

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

print('\nTest accuracy:', test_acc)

probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])

predictions = probability_model.predict(test_images)

print(np.argmax(predictions[0]))

print(test_labels[0])

def plot_image(i, predictions_array, true_label, img):
  true_label, img = true_label[i], img[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])
  plt.imshow(img, cmap=plt.cm.binary)
  predicted_label = np.argmax(predictions_array)
  if predicted_label == true_label:
    color = 'blue'
  else:
    color = 'red'
  plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                100*np.max(predictions_array),
                                class_names[true_label]),
                                color=color)


def plot_value_array(i, predictions_array, true_label):
  true_label = true_label[i]
  plt.grid(False)
  plt.xticks(range(10))
  plt.yticks([])
  thisplot = plt.bar(range(10), predictions_array, color="#777777")
  plt.ylim([0, 1])
  predicted_label = np.argmax(predictions_array)
  thisplot[predicted_label].set_color('red')
  thisplot[true_label].set_color('blue')

i = 0
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

i = 12
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
  plt.subplot(num_rows, 2*num_cols, 2*i+1)
  plot_image(i, predictions[i], test_labels, test_images)
  plt.subplot(num_rows, 2*num_cols, 2*i+2)
  plot_value_array(i, predictions[i], test_labels)
plt.tight_layout()
plt.show()
相关推荐
数据智能老司机42 分钟前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机42 分钟前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机42 分钟前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i1 小时前
drf初步梳理
python·django
每日AI新事件1 小时前
python的异步函数
python
这里有鱼汤2 小时前
miniQMT下载历史行情数据太慢怎么办?一招提速10倍!
前端·python
databook11 小时前
Manim实现脉冲闪烁特效
后端·python·动效
程序设计实验室12 小时前
2025年了,在 Django 之外,Python Web 框架还能怎么选?
python
倔强青铜三13 小时前
苦练Python第46天:文件写入与上下文管理器
人工智能·python·面试
用户25191624271116 小时前
Python之语言特点
python