用Python做有趣的AI项目1:用 TensorFlow 实现图像分类(识别猫、狗、汽车等)

项目目标

通过构建卷积神经网络(CNN),让模型学会识别图片中是什么物体。我们将使用 CIFAR-10 数据集,它包含 10 类:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。

🛠️ 开发环境与依赖

安装依赖(用命令行运行):

bash 复制代码
pip install tensorflow matplotlib numpy

推荐使用 Jupyter Notebook,方便边学边运行,也可以用 VS Code、PyCharm 等编辑器。

第一步:导入库

bash 复制代码
#python
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import numpy as np

这些库的作用:

tensorflow:用于构建和训练神经网络。

matplotlib:用于可视化图片和训练过程。

numpy:用于处理数组和数据操作。

第二步:加载和预处理数据

bash 复制代码
#python
#加载 CIFAR-10 数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()


#归一化处理:将像素值从 0~255 映射到 0~1,提高模型训练效果
x_train = x_train / 255.0
x_test = x_test / 255.0

# CIFAR-10 类别名
class_names = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']

第三步:查看数据

bash 复制代码
#python

plt.figure(figsize=(10, 2))
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.xticks([]); plt.yticks([])
    plt.imshow(x_train[i])
    plt.xlabel(class_names[y_train[i][0]])
plt.show()

这一部分可以帮你初步理解数据的样子和类别。

第四步:构建 CNN 模型

bash 复制代码
#python

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu'),
    
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)  # 输出层:10个类
])

model.summary()  # 查看模型结构

📌 注解:

Conv2D 是卷积层,能提取图像的边缘、角点等特征。

MaxPooling2D 是池化层,用于降维。

Flatten 把多维数据展平成一维。

Dense 是全连接层,用于分类决策。

第五步:编译和训练模型

bash 复制代码
#python

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=10, 
                    validation_data=(x_test, y_test))

💡 小贴士:

adam 是一种优化器,适合初学者使用。

SparseCategoricalCrossentropy 适合标签是整数而不是 one-hot 的分类任务。

第六步:训练过程可视化

bash 复制代码
#python

plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch'); plt.ylabel('Accuracy')
plt.legend(); plt.grid()
plt.show()

这个图能直观看到模型是否在过拟合或欠拟合。

第七步:评估模型

bash 复制代码
#python
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'测试准确率:{test_acc:.2f}')

第八步:预测和展示结果

bash 复制代码
#python

probability_model = models.Sequential([model, layers.Softmax()])
predictions = probability_model.predict(x_test)


#展示前5张图片及其预测结果
for i in range(5):
    plt.imshow(x_test[i])
    plt.title(f"预测:{class_names[np.argmax(predictions[i])]} / 实际:{class_names[y_test[i][0]]}")
    plt.axis('off')
    plt.show()

第九步:保存与加载模型

bash 复制代码
#python
#保存模型
model.save('cifar10_cnn_model.h5')

#加载模型
new_model = tf.keras.models.load_model('cifar10_cnn_model.h5')

🔄 扩展建议

训练猫狗二分类模型(用 Kaggle 的数据集)。

加 BatchNormalization、Dropout 提升泛化能力。

使用更强的预训练模型如 MobileNet、ResNet。

相关推荐
Codebee6 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º7 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys7 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56787 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子7 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder7 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能8 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144878 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile8 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算
人工不智能5778 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert