CNN实现fashion_mnist数据集分类(tensorflow)

1、查看tensorflow版本

python 复制代码
import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

2、加载fashion_mnist数据与预处理

python 复制代码
import numpy as np
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# print(train_images.shape) # (60000, 28, 28)
# print(train_labels.shape) # (60000,)
# print(test_images.shape) # (10000, 28, 28)
# print(test_labels.shape) # (10000,)
train_images = np.expand_dims(train_images, -1)
# print(train_images.shape) # (个数, hight, width,channels)=(60000, 28, 28, 1)

3、CNN模型构建

python 复制代码
from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential()
model.add(Input(shape=(28,28,1)))  # train_images.shape[1:]
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same')) # 增加filter个数,增加模型拟合能力
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2. 池化层扩大视野
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=256,kernel_size=(3,3),activation='relu'))
model.add(GlobalAvgPool2D()) # 全局平均池化
model.add(Dense(10,activation='softmax'))
model.summary()

4、模型配置与训练

python 复制代码
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
              
H = model.fit(x=train_images,
              y=train_labels,
              validation_split=0.2,
              # validation_data=(X_test,y_test),
              epochs=10,
              batch_size=64,
              verbose=1)

5、损失函数和准确率分析

根据损失函数和准确率,判断模型是否过拟合或者欠拟合,不断调整网络结构,使得模型最优。

python 复制代码
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
相关推荐
三之又三12 小时前
卷积神经网络CNN-part5-NiN
人工智能·神经网络·cnn
rit843249917 小时前
人工鱼群算法AFSA优化支持向量机SVM,提高故障分类精度
算法·支持向量机·分类
小龙18 小时前
图卷积神经网络(GCN)学习笔记
笔记·学习·cnn·gcn·图卷积神经网络·理论知识
先做个垃圾出来………19 小时前
传统模型RNN与CNN介绍
人工智能·rnn·cnn
guidovans21 小时前
Crawl4AI精准提取结构化数据
人工智能·python·tensorflow
虚拟现实旅人2 天前
【机器学习】通过tensorflow实现猫狗识别的深度学习进阶之路
深度学习·机器学习·tensorflow
colus_SEU2 天前
【卷积神经网络详解与实例】4——感受野
人工智能·深度学习·计算机视觉·cnn
索迪迈科技2 天前
机器学习投票分类
人工智能·机器学习·分类
盼小辉丶2 天前
Transformer实战(17)——微调Transformer语言模型进行多标签文本分类
深度学习·分类·transformer
君名余曰正则2 天前
机器学习实操项目03——Scikit-learn介绍及简单分类案例
机器学习·分类·scikit-learn