Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

复制代码
import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

相关推荐
一个java开发10 分钟前
distributed.client.Client 用户可调用函数分析
大数据·python
eqwaak021 分钟前
Matplotlib 动态显示详解:技术深度与创新思考
网络·python·网络协议·tcp/ip·语言模型·matplotlib
六月的可乐25 分钟前
【干货推荐】AI助理前端UI组件-悬浮球组件
前端·人工智能·ui
蔡俊锋31 分钟前
【无标题】
人工智能·chatgpt
007php00733 分钟前
某大厂MySQL面试之SQL注入触点发现与SQLMap测试
数据库·python·sql·mysql·面试·职场和发展·golang
CodeCraft Studio35 分钟前
Excel处理控件Aspose.Cells教程:使用 Python 将 Pandas DataFrame 转换为 Excel
python·json·excel·pandas·csv·aspose·dataframe
说私域40 分钟前
基于开源AI大模型AI智能名片S2B2C商城小程序的参与感构建研究
人工智能·小程序·开源
flashlight_hi1 小时前
LeetCode 分类刷题:2563. 统计公平数对的数目
python·算法·leetcode
java1234_小锋1 小时前
Scikit-learn Python机器学习 - 特征预处理 - 归一化 (Normalization):MinMaxScaler
python·机器学习·scikit-learn
码蛊仙尊1 小时前
2025计算机视觉新技术
人工智能·计算机视觉