Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

复制代码
import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

相关推荐
垂葛酒肝汤3 分钟前
Unity Sprite Rect 越界问题笔记
笔记·unity·游戏引擎
云青黛5 分钟前
ReAct(推理与行动)框架
python·算法
朗迹 - 张伟7 分钟前
UE5 UMG学习笔记
笔记·学习·ue5
jinanwuhuaguo12 分钟前
AI应用开发与自动化工具全景解析:Coze、Dify、FastGPT、n8n、MCP、Manus、Claude Code、OpenClaw
人工智能·学习·重构·新人首发·openclaw
人工智能AI技术12 分钟前
Claude 3.7 企业版私有化部署技术验证:与 .NET 实战方案
人工智能·c#
布局呆星13 分钟前
Python 文件操作教程
开发语言·python
数字护盾(和中)15 分钟前
AI 赋能安全:重构数字防御新范式
人工智能·安全·重构
左左右右左右摇晃17 分钟前
JVM 整理(五) 垃圾回收(GC)
jvm·笔记
大傻^17 分钟前
LangChain4j Agent 模式:ReAct、Plan-and-Solve 与自主决策
人工智能·agent·langchain4j·自主决策
跨境海王哥17 分钟前
ChatGPT降智怎么恢复?GPT-5.4降智原因与恢复方法
人工智能·chatgpt