Python中的卷积神经网络(CNN)入门

卷积神经网络(Convolutional Neural Networks, CNN)是一类特别适用于处理图像数据的深度学习模型。在Python中,我们可以使用流行的深度学习库TensorFlow和Keras来创建和训练一个CNN模型。在本文中,我们将介绍如何使用Keras创建一个简单的CNN模型,并用它对手写数字进行分类。

1. 准备数据集

我们将使用MNIST数据集,这是一个常用的手写数字数据集。Keras库提供了一个方便的函数来加载MNIST数据集。数据集包含60000个训练样本和10000个测试样本,每个样本是一个28x28的灰度图像。

复制代码
python
复制代码
from tensorflow.keras.datasets import mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

接下来,我们需要对数据进行预处理。我们将图像数据归一化到0-1之间,并将标签数据进行one-hot编码:

复制代码
python
复制代码
train_images = train_images.reshape((60000, 28, 28, 1))
train_images = train_images.astype("float32") / 255

test_images = test_images.reshape((10000, 28, 28, 1))
test_images = test_images.astype("float32") / 255

from tensorflow.keras.utils import to_categorical

train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
2. 创建CNN模型

我们将使用Keras创建一个简单的CNN模型,包括卷积层、池化层、全连接层等。模型的结构如下:

  • 卷积层:使用32个3x3的卷积核,激活函数为ReLU;

  • 池化层:使用2x2的最大池化;

  • 卷积层:使用64个3x3的卷积核,激活函数为ReLU;

  • 池化层:使用2x2的最大池化;

  • 全连接层:包含128个神经元,激活函数为ReLU;

  • 输出层:包含10个神经元,激活函数为softmax。

    python
    复制代码
    from tensorflow.keras import layers
    from tensorflow.keras import models

    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation="relu", input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation="relu"))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation="relu"))
    model.add(layers.Dense(10, activation="softmax"))

3. 训练CNN模型

我们将使用训练数据集训练CNN模型,并在测试数据集上评估模型性能。我们将使用交叉熵损失函数和Adam优化器,训练10个epoch。

复制代码
python
复制代码
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

model.fit(train_images, train_labels, epochs=10, batch_size=64)

test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
4. 使用CNN模型进行预测

训练好CNN模型后,我们可以用它对新的图像数据进行预测。下面我们将随机选择一个测试图像,并使用模型进行预测。

复制代码
python
复制代码
import numpy as np
import matplotlib.pyplot as plt

index = np.random.randint(0, len(test_images))
image = test_images[index]

plt.imshow(image.reshape(28, 28), cmap="gray")
plt.show()

predictions = model.predict(np.expand_dims(image, axis=0))
predicted_label = np.argmax(predictions)

print("Predicted label:", predicted_label)

上述代码将展示一个随机选择的手写数字图像,并输出模型预测的结果。

这就是如何在Python中使用Keras创建和训练一个简单的CNN模型进行手写数字分类。在实际应用中,可以根据需求调整CNN模型的结构和参数以优化性能。

相关推荐
程序员敲代码吗31 分钟前
用Python生成艺术:分形与算法绘图
jvm·数据库·python
琅琊榜首202033 分钟前
AI生成脑洞付费短篇小说:从灵感触发到内容落地
大数据·人工智能
EndingCoder34 分钟前
案例研究:从 JavaScript 迁移到 TypeScript
开发语言·前端·javascript·性能优化·typescript
Yyyyy123jsjs34 分钟前
如何通过免费的外汇API轻松获取实时汇率数据
开发语言·python
imbackneverdie41 分钟前
近年来,我一直在用的科研工具
人工智能·自然语言处理·aigc·论文·ai写作·学术·ai工具
白露与泡影42 分钟前
2026版Java架构师面试题及答案整理汇总
java·开发语言
喵手43 分钟前
Python爬虫实战:GovDataMiner —— 开放数据门户数据集元数据采集器(附 CSV 导出)!
爬虫·python·爬虫实战·python爬虫工程化实战·零基础python爬虫教学·open data·开放数据门户数据集列表
历程里程碑1 小时前
滑动窗口---- 无重复字符的最长子串
java·数据结构·c++·python·算法·leetcode·django
roman_日积跬步-终至千里1 小时前
【计算机视觉-作业1】从图像到向量:kNN数据预处理完整流程
人工智能·计算机视觉