Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

复制代码
import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

相关推荐
Tezign_space7 分钟前
技术方案|构建品牌KOS内容中台:三种架构模式与AI赋能实践
人工智能·架构·数字化转型·小红书·kos·内容营销·内容科技
嵌入式-老费18 分钟前
自己动手写深度学习框架(pytorch训练第一个网络)
人工智能·pytorch·深度学习
程序员小远21 分钟前
如何搭建Appium环境?
自动化测试·软件测试·python·测试工具·职场和发展·appium·测试用例
小刘摸鱼中23 分钟前
高频电子电路-振荡器的频率稳定度
网络·人工智能
用户3255491305625 分钟前
AI辅助神器Cursor –从0到1实战《仿小红书小程序》(已完结)
人工智能
烟袅27 分钟前
使用 OpenAI SDK 调用 Tools 实现外部工具集成
python·openai·agent
青瓷程序设计40 分钟前
果蔬识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
川石课堂软件测试1 小时前
自动化过程中验证码的解决思路
数据库·python·功能测试·测试工具·单元测试·tomcat·自动化
沫儿笙1 小时前
镀锌板焊接中库卡机器人是如何省气的
网络·人工智能·机器人
2301_764441331 小时前
新能源汽车电磁辐射高级预测
python·算法·数学建模·汽车