TensorFlow Learning Series 02 | Color Image Classification

I. Prerequisites

1. A Quick CNN Primer

1.1. Convolution

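  • How it works: a small matrix called a kernel (or filter), e.g. 3×3, slides across the image.
  • At each step, it multiplies the pixels underneath element-wise and sums the products into one new value; these values together form a feature map.
  • Different kernels act as different "filters", for example:
    • one kernel detects vertical edges
    • one kernel detects 45-degree diagonal lines
    • one kernel detects bright spots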
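The sliding "multiply then sum" step can be sketched in plain NumPy. This is an illustrative sketch, not TensorFlow's implementation: the helper name `conv2d_valid`, the toy 5×5 image and the Sobel-like vertical-edge kernel are all assumptions chosen for the demo. It uses stride 1 and no padding, which is the Keras `Conv2D` default.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (stride 1, no padding) and, at each position,
    take the element-wise product with the pixels underneath and sum it.
    Like Keras Conv2D, this is cross-correlation (the kernel is not flipped)."""
    kh, kw = kernel.shape
    oh = image.shape[0] - kh + 1
    ow = image.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Toy 5x5 grayscale image: dark on the left, bright on the right (a vertical edge).
image = np.array([[0, 0, 1, 1, 1]] * 5, dtype=float)

# Hypothetical hand-made vertical-edge kernel (Sobel-like), for illustration only.
vertical_edge = np.array([[-1, 0, 1],
                          [-2, 0, 2],
                          [-1, 0, 1]], dtype=float)

feature_map = conv2d_valid(image, vertical_edge)
print(feature_map)  # 3x3 map: 4 where the window straddles the edge, 0 in the flat bright area
```

A trained CNN learns the kernel values instead of having them written by hand, but the sliding computation is the same.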

🌟 Key advantages

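  • Local connectivity: each neuron looks at only a small patch of the image (its receptive field), similar to biological vision.
  • Parameter sharing: the same kernel is reused across the whole image, so the parameter count drops dramatically (a fully connected layer on raw pixels can need millions of parameters, while a convolutional layer often needs only hundreds or thousands).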
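To make the parameter-sharing claim concrete, here is a small arithmetic sketch. The helper names are hypothetical, and the 1000-unit dense layer is an arbitrary size picked only for comparison; the convolution matches the first layer of the model built later (32 filters of size 3×3 on a 32×32×3 input).

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # Each filter has kernel_h * kernel_w * in_channels weights plus one bias,
    # and the same filter is reused at every spatial position.
    return (kernel_h * kernel_w * in_channels + 1) * filters

def dense_params(in_features, units):
    # Every input connects to every unit, plus one bias per unit.
    return (in_features + 1) * units

# A 32x32 RGB image (CIFAR-10 size), flattened for a fully connected layer.
flat_inputs = 32 * 32 * 3  # 3072

print(conv2d_params(3, 3, 3, 32))       # → 896
print(dense_params(flat_inputs, 1000))  # → 3073000 (hypothetical 1000-unit layer)
```

The 896 figure is the same number Keras reports for the first `Conv2D` layer in the model summary further down.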

1.2. Pooling

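  • Purpose: shrink the feature maps to reduce computation, and make the network less sensitive to small shifts (robustness).
  • Most common: max pooling
    • split the feature map into 2×2 blocks
    • keep only the maximum of each block
    • result: height and width are halved (rounded down for odd sizes), but the most prominent features (such as the strongest edges) are kept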

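🎯 Example: even if the cat's head shifts slightly left or right, the "ear region" still lights up after max pooling, so recognition is not affected.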
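A minimal NumPy sketch of 2×2 max pooling (stride 2, odd edges dropped, like the Keras `MaxPooling2D` defaults). The function name and the 4×4 feature map are made up for illustration. The second call moves the strongest value by one pixel inside its 2×2 block and shows the pooled output does not change; shifts that cross a block boundary can still change it.

```python
import numpy as np

def max_pool_2x2(feature_map):
    """2x2 max pooling with stride 2; an odd trailing row/column is dropped,
    matching the default padding='valid' of Keras MaxPooling2D."""
    h, w = feature_map.shape
    h2, w2 = h // 2, w // 2
    trimmed = feature_map[:h2 * 2, :w2 * 2]
    # View as (h2, 2, w2, 2) blocks and take the max inside each 2x2 block.
    return trimmed.reshape(h2, 2, w2, 2).max(axis=(1, 3))

fm = np.array([[1, 3, 2, 0],
               [4, 2, 1, 1],
               [0, 1, 5, 6],
               [2, 2, 7, 3]], dtype=float)
print(max_pool_2x2(fm))       # [[4. 2.] [2. 7.]]

# Move the 4 up one row; it stays inside the same top-left 2x2 block.
shifted = fm.copy()
shifted[0, 0], shifted[1, 0] = 4, 1
print(max_pool_2x2(shifted))  # unchanged: [[4. 2.] [2. 7.]]
```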

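1.3. Stacking and Fully Connected Layers (Stacking + FC)

  • Shallow conv layers: detect simple patterns (edges, corners)
  • Middle conv layers: combine simple patterns into complex ones (eyes, ears)
  • Deep conv layers: combine parts into whole objects (cats, dogs, cars)
  • Finally, fully connected (FC) layers: "flatten" the high-level features and feed them into a classifier that outputs class probabilities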

🔁 The whole pipeline: raw image → convolution → activation (ReLU) → pooling → convolution → ... → fully connected → classification result
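As a rough sketch of this pipeline, the NumPy code below pushes a toy 10×10 single-channel input through one convolution, ReLU, 2×2 max pooling, flattening, and a fully connected layer followed by softmax. The kernel and dense weights are random (untrained), so the resulting probabilities mean nothing; only the shapes and the order of operations matter. All names here are illustrative assumptions, not TensorFlow APIs.

```python
import numpy as np

rng = np.random.default_rng(0)

def conv2d_valid(x, k):
    # Stride-1, no-padding sliding window: element-wise product, then sum.
    kh, kw = k.shape
    out = np.empty((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

def relu(x):
    return np.maximum(x, 0.0)

def max_pool_2x2(x):
    h, w = x.shape[0] // 2, x.shape[1] // 2
    return x[:h * 2, :w * 2].reshape(h, 2, w, 2).max(axis=(1, 3))

def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

image = rng.random((10, 10))           # toy single-channel "image"
kernel = rng.standard_normal((3, 3))   # random, untrained kernel
features = max_pool_2x2(relu(conv2d_valid(image, kernel)))  # 10x10 -> 8x8 -> 4x4
flat = features.ravel()                # flatten: (4, 4) -> (16,)
W = rng.standard_normal((flat.size, 10))  # untrained fully connected weights
b = np.zeros(10)
probs = softmax(flat @ W + b)          # 10 class probabilities
print(features.shape, flat.shape, probs.shape, probs.sum())
```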

2. Image Classification Workflow

II. Code Implementation

1. Preparation

1.1. Configure the GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # if there are several GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory on demand instead of all at once
    tf.config.set_visible_devices([gpu0], "GPU")
    
print(gpus)

2026-01-14 13:03:21.672196: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

1.2. Load the Data

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 947s 6us/step

1.3. Normalization

# Scale pixel values into the range 0 to 1. Each channel of an 8-bit image (grayscale or RGB) ranges from 0 to 255, so dividing by 255 is enough to normalize.
train_images, test_images = train_images / 255.0, test_images / 255.0
# Check the array shapes
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

((50000, 32, 32, 3), (10000, 32, 32, 3), (50000, 1), (10000, 1))
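One detail worth knowing about `/ 255.0`: dividing a `uint8` array by a Python float makes NumPy produce `float64`. For the real training set (50000 × 32 × 32 × 3 = 153.6 million values) that is about 1.23 GB, versus about 0.61 GB as `float32`, which is the dtype the Keras model computes in by default. The sketch below uses small synthetic data in place of CIFAR-10; casting to `float32` first is an optional tweak, not something the original code requires.

```python
import numpy as np

# Synthetic stand-in for CIFAR-10 pixels: uint8 values in [0, 255]
# with the same (N, 32, 32, 3) layout; only 8 fake images to keep it small.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)

scaled64 = images / 255.0                     # NumPy promotes uint8 / float to float64
scaled32 = images.astype(np.float32) / 255.0  # stays float32: half the memory

print(scaled64.dtype, scaled32.dtype)
print(scaled32.min() >= 0.0, scaled32.max() <= 1.0)
print(scaled64.nbytes // scaled32.nbytes)     # float64 uses twice the bytes
```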

1.4. Visualize Images

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']

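plt.figure(figsize=(20,10))
for i in range(20):
    plt.subplot(5,10,i+1)  # 5x10 grid; the 20 images fill the first two rows
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])                  # RGB images ignore cmap, so it is omitted
    plt.xlabel(class_names[train_labels[i][0]])  # labels have shape (N, 1), hence the [0]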
plt.show()
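The `train_labels[i][0]` indexing is needed because CIFAR-10 labels come as an `(N, 1)` column rather than a flat vector. A small sketch with four illustrative label values (not taken from the dataset) shows the two usual ways to get from the column to class names:

```python
import numpy as np

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Illustrative labels in the same (N, 1) layout as train_labels.
labels = np.array([[6], [9], [9], [4]])

print(labels.shape)               # (4, 1)
print(class_names[labels[0][0]])  # [0] unwraps the single-element row -> 'frog'
flat_labels = labels.ravel()      # 1-D view of shape (4,), often easier to work with
print([class_names[k] for k in flat_labels])  # ['frog', 'truck', 'truck', 'deer']
```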

2. Train the Model

2.1. Build the CNN

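model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),  # conv layer 1, 3x3 kernels
    layers.MaxPooling2D((2, 2)),                   # pooling layer 1, 2x2 downsampling
    layers.Conv2D(64, (3, 3), activation='relu'),  # conv layer 2, 3x3 kernels
    layers.MaxPooling2D((2, 2)),                   # pooling layer 2, 2x2 downsampling
    layers.Conv2D(64, (3, 3), activation='relu'),  # conv layer 3, 3x3 kernels

    layers.Flatten(),                      # Flatten layer, connects the conv layers to the dense layers
    layers.Dense(64, activation='relu'),   # fully connected layer, further feature extraction
    layers.Dense(10)                       # output layer: 10 raw scores (logits), one per class, no softmax
])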

model.summary()  # print the network structure

2026-01-14 13:25:51.622783: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-14 13:25:52.805568: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10099 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3080 Ti, pci bus id: 0000:5e:00.0, compute capability: 8.6


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 30, 30, 32)        896       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 15, 15, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 13, 13, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 6, 6, 64)         0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 4, 4, 64)          36928     
                                                                 
 flatten (Flatten)           (None, 1024)              0         
                                                                 
 dense (Dense)               (None, 64)                65600     
                                                                 
 dense_1 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
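The output shapes and parameter counts in this summary can be reproduced by hand. The sketch below assumes the Keras defaults used above (stride 1 and `padding='valid'` for `Conv2D`; stride equal to the pool size and `padding='valid'` for `MaxPooling2D`); the helper names are hypothetical.

```python
def conv_out(size, kernel=3):
    # Conv2D default: stride 1, padding='valid' -> size - kernel + 1
    return size - kernel + 1

def pool_out(size, pool=2):
    # MaxPooling2D default: stride = pool size, padding='valid' -> floor division
    return size // pool

def conv_params(kernel, c_in, c_out):
    # kernel * kernel * c_in weights per filter, plus one bias per filter
    return (kernel * kernel * c_in + 1) * c_out

def dense_params(n_in, n_out):
    # one weight per input-output pair, plus one bias per output unit
    return (n_in + 1) * n_out

size = 32
size = conv_out(size)    # conv2d          -> 30 x 30 x 32
size = pool_out(size)    # max_pooling2d   -> 15 x 15 x 32
size = conv_out(size)    # conv2d_1        -> 13 x 13 x 64
size = pool_out(size)    # max_pooling2d_1 -> 6 x 6 x 64 (13 // 2 = 6)
size = conv_out(size)    # conv2d_2        -> 4 x 4 x 64
flat = size * size * 64  # flatten         -> 1024

params = [
    conv_params(3, 3, 32),   # conv2d
    conv_params(3, 32, 64),  # conv2d_1
    conv_params(3, 64, 64),  # conv2d_2
    dense_params(flat, 64),  # dense
    dense_params(64, 10),    # dense_1
]
print(size, flat, params, sum(params))  # 4 1024 [896, 18496, 36928, 65600, 650] 122570
```

Note that the dense layer after `Flatten` holds more than half of all parameters, even though it is only one layer.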

2.2. Compile the Model

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
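`from_logits=True` matters because the last `Dense(10)` layer has no softmax: the loss has to turn the raw scores into probabilities itself. The NumPy function below is a hand-written sketch of the math behind sparse categorical cross-entropy (integer labels, mean negative log-softmax of the true class). It is not Keras's actual implementation, and the function name and sample logits are assumptions.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean cross-entropy for integer class labels, computed from raw logits.
    Uses log-softmax with the log-sum-exp trick for numerical stability."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels).ravel()  # accept (N,) or (N, 1) labels
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class in each row.
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([[0], [2]])  # (N, 1), the same layout as the CIFAR-10 labels
print(sparse_ce_from_logits(logits, labels))
```

With all-equal logits over 10 classes this gives ln(10) ≈ 2.30, a rough reference point for the loss of an untrained 10-class model.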

2.3. Train the Model

history = model.fit(train_images, train_labels, epochs=10, 
                    validation_data=(test_images, test_labels))

Epoch 1/10
2026-01-14 13:26:28.595628: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8101
2026-01-14 13:26:31.971672: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
1563/1563 [==============================] - 15s 6ms/step - loss: 1.5071 - accuracy: 0.4519 - val_loss: 1.2221 - val_accuracy: 0.5586
Epoch 2/10
1563/1563 [==============================] - 9s 6ms/step - loss: 1.1403 - accuracy: 0.5940 - val_loss: 1.0596 - val_accuracy: 0.6186
Epoch 3/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.9880 - accuracy: 0.6537 - val_loss: 0.9577 - val_accuracy: 0.6648
Epoch 4/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.8843 - accuracy: 0.6883 - val_loss: 0.9154 - val_accuracy: 0.6849
Epoch 5/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.8077 - accuracy: 0.7143 - val_loss: 0.8642 - val_accuracy: 0.6990
Epoch 6/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.7459 - accuracy: 0.7385 - val_loss: 0.9030 - val_accuracy: 0.6912
Epoch 7/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.6906 - accuracy: 0.7584 - val_loss: 0.8798 - val_accuracy: 0.7117
Epoch 8/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.6560 - accuracy: 0.7673 - val_loss: 0.8750 - val_accuracy: 0.7042
Epoch 9/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.6054 - accuracy: 0.7874 - val_loss: 0.9065 - val_accuracy: 0.7002
Epoch 10/10
1563/1563 [==============================] - 9s 6ms/step - loss: 0.5748 - accuracy: 0.7987 - val_loss: 0.9055 - val_accuracy: 0.7049
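The `1563/1563` counter in this log (and `313/313` in the prediction and evaluation output later) is the number of batches per pass. No `batch_size` is passed to `fit`, `predict` or `evaluate`, so Keras uses its default of 32, and the final partial batch counts as one step. A quick check (the helper name is hypothetical):

```python
import math

def steps_per_epoch(num_samples, batch_size=32):
    # Keras fit/evaluate/predict default to batch_size=32 when it is not given;
    # the last, partially filled batch still counts as one step.
    return math.ceil(num_samples / batch_size)

print(steps_per_epoch(50000))  # training set -> 1563
print(steps_per_epoch(10000))  # test set -> 313
```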

3. Prediction

plt.imshow(test_images[1])

<matplotlib.image.AxesImage at 0x7f3e28442b80>
import numpy as np

pre = model.predict(test_images)
print(class_names[np.argmax(pre[1])])

313/313 [==============================] - 1s 2ms/step
ship
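Because the output layer has no softmax, `pre` holds logits, not probabilities. `np.argmax` still picks the right class, since softmax preserves the ordering of scores; converting to probabilities is only needed if you want a confidence value (in TensorFlow, `tf.nn.softmax(pre)` would do it). The NumPy sketch below uses made-up logits standing in for one row of `pre`.

```python
import numpy as np

def softmax(logits):
    # Subtract the row max before exponentiating to avoid overflow.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Made-up logits standing in for one row of model.predict(...) output.
logits = np.array([0.3, 1.1, -0.7, 0.0, -1.2, -0.5, -0.9, -1.0, 4.2, 1.5])
probs = softmax(logits)

print(class_names[np.argmax(logits)], class_names[np.argmax(probs)])  # same class either way
print(f"confidence: {probs.max():.3f}, sum: {probs.sum():.3f}")
```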

4. Model Evaluation

import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.4, 1])  # the epoch-1 training accuracy (0.4519) would be cut off at 0.5
plt.legend(loc='lower right')
plt.show()
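The curves can also be read numerically. Using the values from the training log above, validation loss is lowest at epoch 5 while training loss keeps falling, and by epoch 10 training accuracy is about 9 points above validation accuracy, a common sign of overfitting. Note also that `validation_data` here is the test set, so these validation numbers are the test numbers (the `evaluate` result below matches epoch 10 exactly). This is a plain-Python sketch over the logged values only.

```python
# Values copied from the training log above (epochs 1-10).
train_acc = [0.4519, 0.5940, 0.6537, 0.6883, 0.7143,
             0.7385, 0.7584, 0.7673, 0.7874, 0.7987]
val_acc = [0.5586, 0.6186, 0.6648, 0.6849, 0.6990,
           0.6912, 0.7117, 0.7042, 0.7002, 0.7049]
val_loss = [1.2221, 1.0596, 0.9577, 0.9154, 0.8642,
            0.9030, 0.8798, 0.8750, 0.9065, 0.9055]

# Index of the smallest validation loss, converted to a 1-based epoch number.
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
gap = train_acc[-1] - val_acc[-1]

print(f"lowest val_loss at epoch {best_epoch}: {val_loss[best_epoch - 1]}")
print(f"final train/val accuracy gap: {gap:.4f}")
```

If you want training to stop near that point automatically, Keras's `tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True)` is one option, ideally paired with a validation split held out from the training data rather than the test set.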

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
313/313 - 1s - loss: 0.9055 - accuracy: 0.7049 - 874ms/epoch - 3ms/step

print(test_acc)

0.7049000263214111
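The accuracy reported by `evaluate` can be cross-checked from `pre`: with the real arrays, `np.mean(np.argmax(pre, axis=1) == test_labels.ravel())` should agree closely with `test_acc`. The `.ravel()` matters: comparing a `(N,)` prediction vector with the `(N, 1)` label column silently broadcasts to an `(N, N)` matrix. The small made-up arrays below demonstrate the pitfall.

```python
import numpy as np

# Made-up example: 4 rows of logits and (N, 1) labels, like test_labels.
pre = np.array([[2.0, 0.1, 0.3],
                [0.2, 1.5, 0.1],
                [0.1, 0.2, 0.9],
                [1.2, 0.4, 0.3]])
labels = np.array([[0], [1], [1], [0]])

pred_classes = np.argmax(pre, axis=1)  # shape (4,): [0, 1, 2, 0]

# Pitfall: (4,) == (4, 1) broadcasts to a (4, 4) matrix, giving a meaningless number.
wrong = np.mean(pred_classes == labels)
# Flatten the labels first so the comparison is element-wise.
right = np.mean(pred_classes == labels.ravel())

print(wrong, right)  # 0.375 0.75
```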