CNN实现fashion_mnist数据集分类(tensorflow)

1、查看tensorflow版本

python 复制代码
import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

2、加载fashion_mnist数据与预处理

python 复制代码
import numpy as np
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# print(train_images.shape) # (60000, 28, 28)
# print(train_labels.shape) # (60000,)
# print(test_images.shape) # (10000, 28, 28)
# print(test_labels.shape) # (10000,)
train_images = np.expand_dims(train_images, -1)
# print(train_images.shape) # (个数, hight, width,channels)=(60000, 28, 28, 1)

3、CNN模型构建

python 复制代码
from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential()
model.add(Input(shape=(28,28,1)))  # train_images.shape[1:]
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same')) # 增加filter个数,增加模型拟合能力
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2. 池化层扩大视野
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=256,kernel_size=(3,3),activation='relu'))
model.add(GlobalAvgPool2D()) # 全局平均池化
model.add(Dense(10,activation='softmax'))
model.summary()

4、模型配置与训练

python 复制代码
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
              
H = model.fit(x=train_images,
              y=train_labels,
              validation_split=0.2,
              # validation_data=(X_test,y_test),
              epochs=10,
              batch_size=64,
              verbose=1)

5、损失函数和准确率分析

根据损失函数和准确率,判断模型是否过拟合或者欠拟合,不断调整网络结构,使得模型最优。

python 复制代码
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
相关推荐
大数据魔法师3 小时前
分类与回归算法(三)- 逻辑回归
分类·回归·逻辑回归
qq_2546744112 小时前
回归、分类、聚类
分类·回归·聚类
Dfreedom.13 小时前
卷积神经网络(CNN)全面解析
人工智能·神经网络·cnn·卷积神经网络
B站_计算机毕业设计之家14 小时前
深度血虚:Django水果检测识别系统 CNN卷积神经网络算法 python语言 计算机 大数据✅
python·深度学习·计算机视觉·信息可视化·分类·cnn·django
元直数字电路验证21 小时前
感知机:乳腺癌分类实现 & K 均值聚类:从零实现
均值算法·分类·聚类
油泼辣子多加1 天前
【实战】自然语言处理--长文本分类(3)HAN算法
算法·自然语言处理·分类
大大dxy大大1 天前
机器学习实现逻辑回归-癌症分类预测
机器学习·分类·逻辑回归
听风吹等浪起1 天前
基于改进TransUNet的港口船只图像分割系统研究
人工智能·深度学习·cnn·transformer
qzhqbb1 天前
神经网络 - 卷积神经网络
神经网络·计算机视觉·cnn
~~李木子~~2 天前
Windows软件自动扫描与分类工具 - 技术文档
windows·分类·数据挖掘