Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

复制代码
import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

相关推荐
四谎真好看9 小时前
Java 黑马程序员学习笔记(进阶篇19)
java·笔记·学习·学习笔记
'需尽欢'9 小时前
基于 Flask+Vue+MySQL的研学网站
python·mysql·flask
szxinmai主板定制专家10 小时前
【NI测试方案】基于ARM+FPGA的整车仿真与电池标定
arm开发·人工智能·yolo·fpga开发
新子y10 小时前
【小白笔记】最大交换 (Maximum Swap)问题
笔记·python
ygyqinghuan11 小时前
读懂目标检测
人工智能·目标检测·目标跟踪
华东数交11 小时前
企业与国有数据资产:入表全流程管理及资产化闭环理论解析
大数据·人工智能
程序员爱钓鱼11 小时前
Python编程实战 · 基础入门篇 | Python的缩进与代码块
后端·python
pr_note12 小时前
python|if判断语法对比
python
newxtc13 小时前
【昆明市不动产登记中心-注册安全分析报告】
人工智能·安全
techdashen13 小时前
圆桌讨论:Coding Agent or AI IDE 的现状和未来发展
ide·人工智能