Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

复制代码
import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

相关推荐
ZhengEnCi8 分钟前
M3-markconv库找不到wkhtmltopdf问题
python
华农DrLai3 小时前
什么是LLM做推荐的三种范式?Prompt-based、Embedding-based、Fine-tuning深度解析
人工智能·深度学习·prompt·transformer·知识图谱·embedding
2301_764441333 小时前
LISA时空跃迁分析,地理时空分析
数据结构·python·算法
东北洗浴王子讲AI3 小时前
GPT-5.4辅助算法设计与优化:从理论到实践的系统方法
人工智能·gpt·算法·chatgpt
超低空3 小时前
OpenClaw Windows 安装详细教程
人工智能·程序员·ai编程
恋猫de小郭3 小时前
你的代理归我了:AI 大模型恶意中间人攻击,钱包都被转走了
前端·人工智能·ai编程
yongyoudayee3 小时前
2026 AI CRM选型大比拼:四大架构路线实测对比
人工智能·架构
大连好光景4 小时前
PYG从入门到放弃
笔记·学习
chushiyunen4 小时前
python rest请求、requests
开发语言·python
cTz6FE7gA4 小时前
Python异步编程:从协程到Asyncio的底层揭秘
python