卷积神经网络(CNN)模型 CIFAR-10 数据集 例子

使用 TensorFlow 构建一个简单的卷积神经网络(CNN)模型,完成对 CIFAR-10 数据集的图像分类任务。

使用自动编码器作为特征提取器,先通过自动编码器对图像数据进行降维,将图像从高维映射到低维特征空间,然后将提取的特征传入到 CNN 进行分类。

对比在不使用自动编码器特征提取的情况下,直接使用 CNN 进行分类的模型性能。

python 复制代码
# 导入必要的库
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten

# 加载CIFAR - 10数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# 预处理数据
x_train = x_train / 255.0
x_test = x_test / 255.0

# 构建自动编码器
input_img = Input(shape=(32, 32, 3))
# 编码器
encoded = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
encoded = MaxPooling2D((2, 2), padding='same')(encoded)
# 解码器
decoded = Conv2D(16, (3, 3), activation='relu', padding='same')(encoded)
decoded = tf.keras.layers.UpSampling2D((2, 2))(decoded)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(decoded)
autoencoder = Model(input_img, decoded)

# 编译自动编码器
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# 训练自动编码器
autoencoder.fit(x_train, x_train,
                epochs=10,
                batch_size=128,
                validation_data=(x_test, x_test))

# 获取编码器部分
encoder = Model(input_img, encoded)

# 使用编码器提取特征
x_train_encoded = encoder.predict(x_train)
x_test_encoded = encoder.predict(x_test)

# 构建CNN分类器
input_features = Input(shape=x_train_encoded.shape[1:])
flatten = Flatten()(input_features)
dense1 = Dense(128, activation='relu')(flatten)
output = Dense(10, activation='softmax')(dense1)
classifier = Model(input_features, output)

# 编译CNN分类器
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 训练CNN分类器
classifier.fit(x_train_encoded, y_train,
               epochs=10,
               batch_size=128,
               validation_data=(x_test_encoded, y_test))

# 预测第1000个数据的类别(假设x_test是测试数据)
prediction = classifier.predict(x_test_encoded[999:1000])
predicted_class = tf.argmax(prediction, axis=1).numpy()[0]

Epoch 1/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 16ms/step - loss: 0.6075 - val_loss: 0.5594

Epoch 2/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - loss: 0.5575 - val_loss: 0.5567

Epoch 3/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5555 - val_loss: 0.5559

Epoch 4/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5547 - val_loss: 0.5551

Epoch 5/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5544 - val_loss: 0.5547

Epoch 6/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5538 - val_loss: 0.5534

Epoch 7/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - loss: 0.5520 - val_loss: 0.5527

Epoch 8/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5519 - val_loss: 0.5525

Epoch 9/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5517 - val_loss: 0.5523

Epoch 10/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5506 - val_loss: 0.5522

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 821us/step

Epoch 1/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.2917 - loss: 1.9827 - val_accuracy: 0.4199 - val_loss: 1.6372

Epoch 2/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4255 - loss: 1.6262 - val_accuracy: 0.4405 - val_loss: 1.5675

Epoch 3/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4576 - loss: 1.5361 - val_accuracy: 0.4749 - val_loss: 1.5088

Epoch 4/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4828 - loss: 1.4765 - val_accuracy: 0.5027 - val_loss: 1.4356

Epoch 5/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4963 - loss: 1.4287 - val_accuracy: 0.5055 - val_loss: 1.4277

Epoch 6/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5055 - loss: 1.3981 - val_accuracy: 0.5067 - val_loss: 1.4073

Epoch 7/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5203 - loss: 1.3623 - val_accuracy: 0.5194 - val_loss: 1.3617

Epoch 8/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5338 - loss: 1.3217 - val_accuracy: 0.5246 - val_loss: 1.3555

Epoch 9/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5352 - loss: 1.3199 - val_accuracy: 0.5352 - val_loss: 1.3252

Epoch 10/10

[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5379 - loss: 1.3143 - val_accuracy: 0.5144 - val_loss: 1.3900

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step

python 复制代码
predicted_class

8

以下是使用 Python 语言结合 TensorFlow 库构建卷积神经网络(CNN)对 CIFAR-10 数据集进行图像分类,并获取第 1000 个数据预测类别的示例代码:

python 复制代码
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np

# 加载CIFAR-10数据集
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# 归一化像素值到0-1范围
train_images, test_images = train_images / 255.0, test_images / 255.0

# 构建简单的CNN模型
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))  # 输出层,对应10个类别(CIFAR-10有10类)

# 编译模型
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 训练模型
model.fit(train_images, train_labels, epochs=10, 
          validation_data=(test_images, test_labels))

# 对测试集进行预测
predictions = model.predict(test_images)
# 获取预测的类别(取概率最大的类别索引作为预测类别)
predicted_classes = np.argmax(predictions, axis=1)

# 获取第1000个数据的预测类别
print(predicted_classes[999])

Epoch 1/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 6ms/step - accuracy: 0.3424 - loss: 1.7848 - val_accuracy: 0.5325 - val_loss: 1.2917

Epoch 2/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.5668 - loss: 1.2166 - val_accuracy: 0.6098 - val_loss: 1.1129

Epoch 3/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 6ms/step - accuracy: 0.6391 - loss: 1.0275 - val_accuracy: 0.6399 - val_loss: 1.0059

Epoch 4/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 6ms/step - accuracy: 0.6676 - loss: 0.9352 - val_accuracy: 0.6690 - val_loss: 0.9239

Epoch 5/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7065 - loss: 0.8364 - val_accuracy: 0.6935 - val_loss: 0.8983

Epoch 6/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7272 - loss: 0.7781 - val_accuracy: 0.6914 - val_loss: 0.8845

Epoch 7/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7507 - loss: 0.7195 - val_accuracy: 0.6843 - val_loss: 0.9197

Epoch 8/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7609 - loss: 0.6772 - val_accuracy: 0.7024 - val_loss: 0.8741

Epoch 9/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7773 - loss: 0.6337 - val_accuracy: 0.7055 - val_loss: 0.8704

Epoch 10/10

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7888 - loss: 0.5993 - val_accuracy: 0.7074 - val_loss: 0.8714

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step

8

相关推荐
彬鸿科技1 分钟前
bhSDR Studio/Matlab入门指南(十一):AI数据集采集实验界面全解析
人工智能·matlab·软件定义无线电
云烟成雨TD15 分钟前
Spring AI Alibaba 1.x 系列【63】AI Agent 长期记忆
java·人工智能·spring
武雄(小星Ai)23 分钟前
2026年AI Agent框架选型指南:LangGraph vs CrewAI vs Claude SDK vs OpenAI SDK
人工智能·aigc·agent
狒狒热知识27 分钟前
2026年AI传播新闻软文营销发布当下178软文网领衔发展路径
大数据·人工智能
黑巧克力可减脂44 分钟前
以智录声,以技留韵:AI录音,解锁声音留存的古今新范式
人工智能
智慧景区与市集主理人1 小时前
巨有科技景区智能导览告别传统讲解,打造沉浸式智慧游览体验
人工智能·科技·语音识别
keyanbanyungong1 小时前
告别杂乱病历!临床科研AI工具实测
人工智能·深度学习
出海小龙1 小时前
B2B 跟 B2C 的联盟营销有何根本区别?以及分别如何真正推动增长?
大数据·人工智能
xcLeigh1 小时前
聚合AI工具KULAAI:GPT、Claude、Gemini、DeepSeek热门模型一键使用
人工智能·gpt·claude·gemini·deepseek·聚合ai·kulaai
EnCi Zheng1 小时前
09aaac-RMSNorm是什么?
人工智能