MLP实现fashion_mnist数据集分类(1)-模型构建、训练、保存与加载(tensorflow)

1、查看tensorflow版本

python 复制代码
import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

2、fashion_mnist数据集下载与展示

python 复制代码
(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
print(train_image.shape)
print(train_label.shape)
print(test_image.shape)
print(test_label.shape)
python 复制代码
import matplotlib.pyplot as plt
# plt.imshow(train_image[0])  # 此处为啥是彩色的?

def plot_images_lables(images,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
plot_images_lables(train_image,train_label,0,5)
# plot_images_lables(test_image,test_label,0,5)

3、数据预处理

python 复制代码
X_train,X_test = tf.cast(train_image/255.0,tf.float32),tf.cast(test_image/255.0,tf.float32) # 归一化
y_train,y_test = train_label,test_label # 此处对y没有做onehot处理,需要使用稀疏交叉损失函数

4、模型构建

python 复制代码
from keras import Sequential
from keras.layers import Flatten,Dense,Dropout
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

5、模型配置

python 复制代码
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])

6、模型训练

python 复制代码
H = model.fit(x=X_train,
              y=y_train,
              validation_split=0.2,
              # validation_data=(X_test,y_test),
              epochs=10,
              batch_size=128,
              verbose=1)
python 复制代码
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
python 复制代码
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()

7、模型评估

python 复制代码
model.evaluate(X_test,y_test)

8、模型预测

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

def pred_plot_images_lables(images,labels,start_idx,num=5):
    # 预测
    res = model.predict(images[start_idx:start_idx+num])
    res = np.argmax(res,axis=1)

    # 画图
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
pred_plot_images_lables(X_test,y_test,0,5)

9、模型保存与加载

python 复制代码
import numpy as np

tf.keras.models.save_model(model,"model.keras")
loaded_model = tf.keras.models.load_model("model.keras")
# assert np.allclose(model.predict(X_test[:5]), loaded_model.predict(X_test[:5]))
print(np.argmax(model.predict(X_test[:5]),axis=1))
print(np.argmax(loaded_model.predict(X_test[:5]),axis=1))
相关推荐
xiangzhihong839 分钟前
Amodal3R ,南洋理工推出的 3D 生成模型
人工智能·深度学习·计算机视觉
狂奔solar1 小时前
diffusion-vas 提升遮挡区域的分割精度
人工智能·深度学习
fantasy_arch5 小时前
深度学习--softmax回归
人工智能·深度学习·回归
Blossom.1185 小时前
量子计算与经典计算的融合与未来
人工智能·深度学习·机器学习·计算机视觉·量子计算
硅谷秋水5 小时前
MoLe-VLA:通过混合层实现的动态跳层视觉-语言-动作模型实现高效机器人操作
人工智能·深度学习·机器学习·计算机视觉·语言模型·机器人
2301_764441336 小时前
基于神经网络的肾脏疾病预测模型
人工智能·深度学习·神经网络
HABuo6 小时前
【YOLOv8】YOLOv8改进系列(12)----替换主干网络之StarNet
人工智能·深度学习·yolo·目标检测·计算机视觉
Dovis(誓平步青云)7 小时前
深挖 DeepSeek 隐藏玩法·智能炼金术2.0版本
人工智能·深度学习·机器学习·数据挖掘·服务发现·智慧城市
赵钰老师7 小时前
【Deepseek、ChatGPT】智能气候前沿:AI Agent结合机器学习与深度学习在全球气候变化驱动因素预测中的应用
人工智能·python·深度学习·机器学习·数据分析
Start_Present9 小时前
Pytorch 第十三回:神经网络编码器——自动编解码器
pytorch·python·深度学习·神经网络