使用TensorFlow实现简化版 GoogLeNet 模型进行 MNIST 图像分类

在本文中,我们将使用 TensorFlow 和 Keras 实现一个简化版的 GoogLeNet 模型来进行 MNIST 数据集的手写数字分类任务。GoogLeNet 采用了 Inception 模块,这使得它在处理图像数据时能更高效地提取特征。本教程将详细介绍如何在 MNIST 数据集上训练和测试这个模型。

项目结构

我们的代码将分为两个部分:

  1. 训练部分 (train.py): 包含模型定义、数据加载、模型训练等。
  2. 测试部分 (test.py): 用于加载训练好的模型,并在测试集上评估其性能。

训练部分:train.py

1. 数据加载与预处理

首先,我们需要加载 MNIST 数据集并进行预处理。预处理包括调整图像形状、归一化以及 One-Hot 编码标签。

python 复制代码
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

def load_and_preprocess_data():
    # 加载 MNIST 数据集
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # 数据预处理:将图像形状调整为 [28, 28, 1],并归一化到 [0, 1] 范围
    train_images = train_images.reshape((train_images.shape[0], 28, 28, 1)) / 255.0
    test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0

    # One-Hot 编码标签
    train_labels = to_categorical(train_labels, 10)
    test_labels = to_categorical(test_labels, 10)

    return train_images, train_labels, test_images, test_labels

2. 创建简化版 GoogLeNet 模型

接下来,我们定义一个简化版的 GoogLeNet 模型。该模型包括卷积层、Inception 模块和全连接层。

python 复制代码
from tensorflow.keras import layers, models

def googlenet(input_shape=(28, 28, 1), num_classes=10):
    inputs = layers.Input(shape=input_shape)

    # 第一卷积层 + 池化层
    x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)

    # 第二卷积层 + 池化层
    x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)

    # 第三卷积层 + 池化层
    x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)

    # Inception 模块
    inception1 = layers.Conv2D(64, (1, 1), activation='relu', padding='same')(x)
    inception2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    inception3 = layers.Conv2D(32, (5, 5), activation='relu', padding='same')(x)

    # 拼接 Inception 模块的输出
    x = layers.concatenate([inception1, inception2, inception3], axis=-1)

    # 全局平均池化层
    x = layers.GlobalAveragePooling2D()(x)

    # 全连接层
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dropout(0.5)(x)  # Dropout 层减少过拟合
    outputs = layers.Dense(num_classes, activation='softmax')(x)  # 输出层,使用 softmax 激活函数进行多分类

    model = models.Model(inputs=inputs, outputs=outputs)
    return model

3. 模型训练

定义好模型之后,我们使用 Adam 优化器和交叉熵损失函数来训练模型,并保存训练好的模型。

python 复制代码
def train_model(model, train_images, train_labels, epochs=5, batch_size=64):
    # 训练模型
    history = model.fit(train_images, train_labels,
                        epochs=epochs,
                        batch_size=batch_size)

    return history

def save_model(model, filename='googlenet_mnist.h5'):
    model.save(filename)
    print(f"Model saved to {filename}")

4. 主程序

最后,在主程序中,我们加载数据、创建模型并开始训练。

python 复制代码
def main():
    train_images, train_labels, test_images, test_labels = load_and_preprocess_data()

    model = googlenet(input_shape=(28, 28, 1), num_classes=10)

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    train_model(model, train_images, train_labels, epochs=5, batch_size=64)

    save_model(model)

if __name__ == '__main__':
    main()

测试部分:test.py

1. 加载训练好的模型

在测试部分,我们将加载训练好的模型,并在测试集上进行评估。

python 复制代码
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

def load_and_preprocess_data():
    (_, _), (test_images, test_labels) = mnist.load_data()
    test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0
    test_labels = to_categorical(test_labels, 10)
    return test_images, test_labels

def load_model(model_path='googlenet_mnist.h5'):
    model = tf.keras.models.load_model(model_path)
    return model

2. 评估模型

我们通过 evaluate 方法评估模型的损失和准确率。

python 复制代码
def evaluate_model(model, test_images, test_labels):
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"Test accuracy: {test_acc * 100:.2f}%")
    return test_loss, test_acc

3. 显示预测结果

使用 Matplotlib 可视化前几张图片的预测结果。

python 复制代码
import matplotlib.pyplot as plt

def display_predictions(model, test_images, test_labels, num_images=6):
    predictions = model.predict(test_images[:num_images])

    fig, axes = plt.subplots(2, 3, figsize=(10, 6))
    axes = axes.flatten()

    for i in range(num_images):
        ax = axes[i]
        ax.imshow(test_images[i].reshape(28, 28), cmap='gray')
        ax.set_title(f"Pred: {tf.argmax(predictions[i]).numpy()} \n True: {tf.argmax(test_labels[i]).numpy()}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

4. 主程序

在主程序中,我们加载模型,评估其性能,并显示预测结果。

python 复制代码
def main():
    test_images, test_labels = load_and_preprocess_data()
    model = load_model('googlenet_mnist.h5')

    evaluate_model(model, test_images, test_labels)
    display_predictions(model, test_images, test_labels)

if __name__ == '__main__':
    main()

总结

本文介绍了如何使用 TensorFlow 实现简化版 GoogLeNet,并在 MNIST 数据集上进行训练和测试。我们将代码分为训练和测试两部分,分别处理数据预处理、模型训练与评估、结果展示等工作。

通过使用 GoogLeNet 进行图像分类,我们不仅能够提高分类性能,还能了解 Inception 模块在图像处理中的强大能力。希望这篇博客能够帮助你更好地理解深度学习模型的训练与测试过程。

完整项目:GoogLeNet-TensorFlow: 使用TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类https://gitee.com/qxdlll/goog-le-net-tensor-flow

qxd-ljy/GoogLeNet-TensorFlow: 使用 TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类https://github.com/qxd-ljy/GoogLeNet-TensorFlow

相关推荐
盼小辉丶9 小时前
TensorFlow深度学习实战——情感分析模型
深度学习·神经网络·tensorflow
paradoxjun16 小时前
落地级分类模型训练框架搭建(1):resnet18/50和mobilenetv2在CIFAR10上测试结果
人工智能·深度学习·算法·计算机视觉·分类
Hugh&1 天前
(开源)基于Django+Yolov8+Tensorflow的智能鸟类识别平台
python·yolo·django·tensorflow
诸神缄默不语3 天前
用sklearn运行分类模型,选择AUC最高的模型保存模型权重并绘制AUCROC曲线(以逻辑回归、随机森林、梯度提升、MLP为例)
分类·逻辑回归·sklearn
jieshenai3 天前
企业分类相似度筛选实战:基于规则与向量方法的对比分析
人工智能·自然语言处理·分类
一只码代码的章鱼3 天前
分类问题(二元,多元逻辑回归,费歇尔判别分析)spss实操
大数据·数学建模·分类·数据挖掘·逻辑回归
丶21363 天前
【分类】【损失函数】处理类别不平衡:CEFL 和 CEFL2 损失函数的实现与应用
人工智能·分类·损失函数
机器学习之心3 天前
SCSSA-BiLSTM基于改进麻雀搜索算法优化双向长短期记忆网络多特征分类预测Matlab实现
matlab·分类·数据挖掘
Zda天天爱打卡4 天前
【机器学习实战入门】基于深度学习的乳腺癌分类
大数据·人工智能·深度学习·机器学习·分类·数据挖掘
听风吹等浪起5 天前
第9章:基于Vision Transformer(ViT)网络实现的迁移学习图像分类任务:早期秧苗图像识别
分类·transformer·迁移学习