Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

复制代码
import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

相关推荐
晚烛几秒前
CANN 数据增强 on NPU:训练数据增强的 NPU 加速实战
人工智能·python·深度学习·缓存·数据挖掘
FunTester1 分钟前
当 SDD 遇见 BDD:AI 时代 QA 范式的彻底重构
人工智能·重构·大语言模型·sdd·ai时代qa范式重构
英辰朗迪AI获客6 分钟前
WordPress 7.0 新手极速部署与实战指南
人工智能
ujainu6 分钟前
CANN pto-isa:为什么 AI 编译需要一层虚拟指令集
人工智能·ascend
SEO_juper6 分钟前
高转化英文产品页:SEO 友好 + GEO 易引用
人工智能·seo·跨境电商·外贸·geo·2026·谷歌算法更新
迁旭7 分钟前
Claude Code /status 功能技术文档
前端·javascript·人工智能·react.js·机器学习·gpt-3·文心一言
2601_957786778 分钟前
2026年企业级AI矩阵系统技术演进:从“群控分发“到“智能增长中台“的架构跃迁
人工智能·ai矩阵系统
南屹川8 分钟前
【架构设计】微服务架构设计模式:从理论到实践
人工智能
心中有国也有家9 分钟前
CANN 学习新范式:cann-learning-hub 如何让昇腾入门不再「劝退」
人工智能·经验分享·笔记·学习·算法
bboyHan9 分钟前
AI重构工程质量检测:从多模态感知到全流程闭环的技术实践
大数据·人工智能