CNN实现fashion_mnist数据集分类(tensorflow)

1、查看tensorflow版本

python 复制代码
import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

2、加载fashion_mnist数据与预处理

python 复制代码
import numpy as np
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# print(train_images.shape) # (60000, 28, 28)
# print(train_labels.shape) # (60000,)
# print(test_images.shape) # (10000, 28, 28)
# print(test_labels.shape) # (10000,)
train_images = np.expand_dims(train_images, -1)
# print(train_images.shape) # (个数, hight, width,channels)=(60000, 28, 28, 1)

3、CNN模型构建

python 复制代码
from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential()
model.add(Input(shape=(28,28,1)))  # train_images.shape[1:]
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same')) # 增加filter个数,增加模型拟合能力
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2. 池化层扩大视野
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=256,kernel_size=(3,3),activation='relu'))
model.add(GlobalAvgPool2D()) # 全局平均池化
model.add(Dense(10,activation='softmax'))
model.summary()

4、模型配置与训练

python 复制代码
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
              
H = model.fit(x=train_images,
              y=train_labels,
              validation_split=0.2,
              # validation_data=(X_test,y_test),
              epochs=10,
              batch_size=64,
              verbose=1)

5、损失函数和准确率分析

根据损失函数和准确率,判断模型是否过拟合或者欠拟合,不断调整网络结构,使得模型最优。

python 复制代码
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
相关推荐
β添砖java3 小时前
深度学习(20)深度卷积神经网络AlexNet
人工智能·深度学习·cnn
2zcode9 小时前
基于MATLAB深度学习的非酒精性脂肪性肝病超声图像分类研究( GUI界面+数据集+训练代码)
深度学习·matlab·分类
艾派森12 小时前
深度学习实战-基于EfficientNetB5的家禽鸡病图像分类识别模型
人工智能·python·深度学习·神经网络·分类
ZGi.ai1 天前
智能客服系统设计:从工单分类到自动派单的工程实现
大数据·人工智能·分类
江南十四行1 天前
CNN进阶:Batch Normalization与Layer Normalization对比 + 网络结构设计与PyTorch实现
pytorch·cnn·batch
2zcode1 天前
基于低光照增强与轻量型CNN道路实时识别算法研究(UI界面+数据集+训练代码)
人工智能·算法·cnn·低光照增强·自动驾驶技术
xrui581 天前
2026实测:让 Gemini 3.1镜像站抓取邮箱并智能分类,GTD 效率提升 3 倍
人工智能·分类·数据挖掘
2zcode1 天前
基于MATLAB的深度学习工业表面缺陷多分类检测系统设计与实现(GUI界面+数据集+训练代码)
深度学习·matlab·分类
2zcode1 天前
基于深度学习的糖尿病眼底图像分类识别系统(含UI界面+多模型对比+数据集+训练代码)
人工智能·深度学习·分类
katheta2 天前
时间序列模型总体分类
人工智能·分类·数据挖掘·时间序列·时序模型