Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

复制代码
import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

相关推荐
聚客AI32 分钟前
Embedding进化论:从Word2Vec到OpenAI三代模型技术跃迁
人工智能·llm·掘金·日新计划
碎叶城李白34 分钟前
若依学习笔记1-validated
java·笔记·学习·validated
weixin_387545641 小时前
深入解析 AI Gateway:新一代智能流量控制中枢
人工智能·gateway
im_AMBER1 小时前
学习日志05 python
python·学习
大虫小呓1 小时前
Python 处理 Excel 数据 pandas 和 openpyxl 哪家强?
python·pandas
聽雨2371 小时前
03每日简报20250705
人工智能·社交电子·娱乐·传媒·媒体
哪 吒1 小时前
2025B卷 - 华为OD机试七日集训第5期 - 按算法分类,由易到难,循序渐进,玩转OD(Python/JS/C/C++)
python·算法·华为od·华为od机试·2025b卷
二川bro1 小时前
飞算智造JavaAI:智能编程革命——AI重构Java开发新范式
java·人工智能·重构
acstdm2 小时前
DAY 48 CBAM注意力
人工智能·深度学习·机器学习
澪-sl2 小时前
基于CNN的人脸关键点检测
人工智能·深度学习·神经网络·计算机视觉·cnn·视觉检测·卷积神经网络