CNN实现fashion_mnist数据集分类(tensorflow)

1、查看tensorflow版本

python 复制代码
import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

2、加载fashion_mnist数据与预处理

python 复制代码
import numpy as np
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# print(train_images.shape) # (60000, 28, 28)
# print(train_labels.shape) # (60000,)
# print(test_images.shape) # (10000, 28, 28)
# print(test_labels.shape) # (10000,)
train_images = np.expand_dims(train_images, -1)
# print(train_images.shape) # (个数, hight, width,channels)=(60000, 28, 28, 1)

3、CNN模型构建

python 复制代码
from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential()
model.add(Input(shape=(28,28,1)))  # train_images.shape[1:]
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same')) # 增加filter个数,增加模型拟合能力
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2. 池化层扩大视野
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'))
model.add(MaxPool2D())  # 默认2*2
model.add(Dropout(0.2)) # 防止过拟合
model.add(Conv2D(filters=256,kernel_size=(3,3),activation='relu'))
model.add(GlobalAvgPool2D()) # 全局平均池化
model.add(Dense(10,activation='softmax'))
model.summary()

4、模型配置与训练

python 复制代码
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
              
H = model.fit(x=train_images,
              y=train_labels,
              validation_split=0.2,
              # validation_data=(X_test,y_test),
              epochs=10,
              batch_size=64,
              verbose=1)

5、损失函数和准确率分析

根据损失函数和准确率,判断模型是否过拟合或者欠拟合,不断调整网络结构,使得模型最优。

python 复制代码
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
相关推荐
Elastic 中国社区官方博客17 小时前
Elasticsearch:使用机器学习生成筛选器和分类标签
大数据·人工智能·elasticsearch·机器学习·搜索引擎·ai·分类
IT猿手2 天前
基于CNN-LSTM的深度Q网络(Deep Q-Network,DQN)求解移动机器人路径规划,MATLAB代码
网络·cnn·lstm
MPCTHU2 天前
预测分析(三):基于机器学习的分类预测
人工智能·机器学习·分类
荷包蛋蛋怪2 天前
【北京化工大学】 神经网络与深度学习 实验6 MATAR图像分类
人工智能·深度学习·神经网络·opencv·机器学习·计算机视觉·分类
卧式纯绿3 天前
每日文献(八)——Part one
人工智能·yolo·目标检测·计算机视觉·目标跟踪·cnn
带娃的IT创业者3 天前
《Python实战进阶》No39:模型部署——TensorFlow Serving 与 ONNX
pytorch·python·tensorflow·持续部署
浊酒南街3 天前
TensorFlow实现逻辑回归
人工智能·tensorflow·逻辑回归
西柚小萌新3 天前
【深度学习:进阶篇】--2.1.多分类与TensorFlow
分类·数据挖掘·tensorflow
简简单单做算法3 天前
基于mediapipe深度学习和限定半径最近邻分类树算法的人体摔倒检测系统python源码
人工智能·python·深度学习·算法·分类·mediapipe·限定半径最近邻分类树
小白的高手之路3 天前
torch.nn.Conv2d介绍——Pytorch中的二维卷积层
人工智能·pytorch·python·深度学习·神经网络·机器学习·cnn