用 TensorFlow和Keras 搭建CNN的经典案例解析~

使用 TensorFlow 和 Keras 搭建 卷积神经网络(CNN) 进行 MNIST 手写数字分类 的示例

一、代码示例

python 复制代码
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 1. 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 2. 预处理数据:归一化 & 维度扩展 (28x28 -> 28x28x1)
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train[..., tf.newaxis]  # 添加通道维度
x_test = x_test[..., tf.newaxis]

# 3. 构建 CNN 模型
model = keras.Sequential([
    layers.Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(28, 28, 1)),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax")  # 10 个类别(0-9)
])

# 4. 编译模型
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# 5. 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.1)

# 6. 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试集准确率: {test_acc:.4f}")

二、代码解析

步骤 作用
1. 加载数据集 keras.datasets.mnist 下载 MNIST 数据,分成训练集和测试集
2. 预处理数据 归一化(0-255 -> 0-1),并调整形状,使其符合 CNN 输入格式(28x28x1)
3. 构建 CNN 模型 - 第一层: 32 个 3x3 卷积核 + ReLU 激活 + 最大池化(2x2) - 第二层: 64 个 3x3 卷积核 + ReLU + 最大池化(2x2) - 展平(Flatten):将特征图变为一维向量 - 全连接层(Dense):128 个神经元 + ReLU - 输出层:10 个类别(0-9),使用 Softmax 归一化概率
4. 编译模型 使用 Adam 优化器,sparse_categorical_crossentropy 作为损失函数
5. 训练模型 用 90% 训练数据训练 5 轮,每次 64 个样本
6. 评估模型 在测试集上计算准确率

三、运行结果

如果一切顺利,你会看到类似这样的输出:

python 复制代码
Epoch 1/5
844/844 [==============================] - 5s 6ms/step - loss: 0.1403 - accuracy: 0.9583 - val_loss: 0.0478 - val_accuracy: 0.9853
Epoch 2/5
844/844 [==============================] - 4s 5ms/step - loss: 0.0474 - accuracy: 0.9853 - val_loss: 0.0371 - val_accuracy: 0.9883
...
测试集准确率: 0.9871

四、你可以尝试改进...

  1. 增加卷积层数:提高模型复杂度,可能提升识别率。

  2. 加入 Dropout 层 :防止过拟合,例如:

    python 复制代码
    layers.Dropout(0.5)
  3. 使用更多 Epoch 训练epochs=10 可能会提高最终表现。

  4. 换更强的优化器 :尝试 SGDAdamW

五、在这段代码中tensorflow起到了什么作用?

在这段代码中,TensorFlow 主要提供了 深度学习计算框架 ,负责神经网络的 构建、训练和推理。具体来说,它在以下几个方面发挥了作用:


1. 提供深度学习基础组件

TensorFlow 作为 深度学习框架 ,提供了 张量(Tensor)计算自动求导优化器神经网络层 等基本工具,使得搭建和训练神经网络变得简单。例如:

  • tensorflow.keras.datasets.mnist.load_data() 👉 加载数据
  • tensorflow.keras.layers.Conv2D() 👉 卷积层
  • tensorflow.keras.layers.Dense() 👉 全连接层
  • tensorflow.keras.optimizers.Adam() 👉 优化器
  • model.fit() 👉 训练
  • model.evaluate() 👉 测试

2. 负责数据的处理和转换

TensorFlow 允许使用 张量(Tensor) 进行高效的 矩阵计算,加速数据预处理:

python 复制代码
x_train = x_train.astype("float32") / 255.0  # 归一化
x_train = x_train[..., tf.newaxis]  # 添加通道维度 (28, 28) -> (28, 28, 1)

👉 这里用 TensorFlow 处理数据,使其适应 CNN 训练。


3. 提供神经网络构建工具

TensorFlow 的 Keras API 提供了丰富的神经网络层,例如:

python 复制代码
layers.Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(28, 28, 1))

👉 这行代码使用 TensorFlow 定义一个 CNN 卷积层 ,会自动进行 权重初始化梯度计算


4. 训练和优化神经网络

TensorFlow 负责:

  1. 前向传播(Forward Pass):输入数据经过 CNN,计算损失值。

  2. 反向传播(Backpropagation):自动计算梯度,更新参数。

  3. 优化器(Optimizer) :比如 Adam 负责调整网络参数,使其更快收敛:

    python 复制代码
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
  4. 批量训练(Mini-batch Training)

    python 复制代码
    model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.1)

    这里 TensorFlow 负责 自动并行计算,加速训练过程。


5. 评估和推理

训练完成后,TensorFlow 还能 自动计算测试集的准确率

python 复制代码
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试集准确率: {test_acc:.4f}")

👉 这部分利用 TensorFlow 的计算图(Computational Graph)执行高效的矩阵运算。


总结

TensorFlow 主要起到 计算引擎 的作用,它提供:

  1. 数据处理能力(张量计算、归一化)
  2. 神经网络构建能力(卷积层、全连接层等)
  3. 自动求导和优化(反向传播、Adam 优化器)
  4. 高效计算能力(GPU 加速、大规模并行训练)
  5. 推理与评估(在测试集上运行模型)

没有 TensorFlow,手写这些梯度计算和矩阵运算会非常复杂......

相关推荐
一个会的不多的人7 分钟前
人工智能基础篇:概念性名词浅谈(第二十六讲)
人工智能·制造·量子计算·数字化转型
liu****8 分钟前
能源之星案例
人工智能·python·算法·机器学习·能源
摆烂咸鱼~12 分钟前
机器学习(13-2)
人工智能·机器学习
人工智能AI技术12 分钟前
从零复现马斯克开源X推荐算法
人工智能
Mixtral22 分钟前
4款录音转文字工具深度评测:钉钉闪记、Otter、随身鹿、讯飞听见...AI后处理能力谁更强?
人工智能
智驱力人工智能26 分钟前
守护矿山动脉 矿山皮带跑偏AI识别系统的工程化实践与价值 皮带偏离检测 皮带状态异常检测 多模态皮带偏离监测系统
大数据·人工智能·opencv·算法·安全·yolo·边缘计算
大模型真好玩26 分钟前
大模型训练全流程实战指南基础篇(二)——大模型文件结构解读与原理解析
人工智能·pytorch·langchain
周博洋K28 分钟前
Deepseek的新论文Engram
人工智能
e***985730 分钟前
2024技术趋势:AI领跑,云端边缘共舞
人工智能
智驱力人工智能35 分钟前
构筑安全红线 发电站旋转设备停机合规监测的视觉分析技术与应用 旋转设备停机检测 旋转设备异常检测 设备停机AI行为建模
人工智能·opencv·算法·安全·目标检测·计算机视觉·边缘计算