直观感受卷积 VAE 模型的潜在分布空间

前言

本文展示了在 MNIST 数据集上训练 Convolutional Variational AutoEncoder (VAE)VAE 是自动编码器的概率模型,它会将高维输入数据压缩为维度较小的表示形式,但是实现方式与将输入映射到潜在向量的传统自动编码器不同,VAE 将输入数据映射到概率分布的参数,最经典的方式莫过于高斯分布的均值和方差。这种方法会产生一个连续的、结构化的潜在空间,这对于图像生成的多样化很有用。

模型原理

VAE 的框架图如下所示,在训练期间,输入图片数据 x 到编码器 encoder ,就像 AE 一样,编码器中一般都是多层卷积神经网络,然而与 AE 的编码器不同的是不直接输出潜在向量 latent vector ,而是输出每个潜在变量的平均值和标准差,然后从该均值和标准差中对潜在向量进行采样,然后将其发送到解码器 decoder 以重建输入图片。VAE 中的解码器的工作原理与 AE 中的解码器类似。

由于中间编码增加了 latent distribution ,所以损失函数不仅有和 AE 类似的 reconstruction loss ,还有 KL Divergence 来衡量 latent distributionstandard gaussian 的相似度。

数据处理

数据处理方面主要就是加载 MNIST 数据集,然后将测试集和训练集进行合并,并且将整个数据集大小调整为 【batch_size, weight,height,channel】,并且对所有的数据进行标准化。

scss 复制代码
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255.0

模型搭建

这段代码定义了一个编码器 encoder ,用于将输入图像编码为潜在分布中的均值 z_mean对数方差 z_log_var ,并从中采样函数中得到潜在向量表示 z 。encoder 中间还是常见的卷积神经网络,进行了一系列的下采样操作。

ini 复制代码
latent_dim = 2
encoder_inputs = tf.keras.Input(shape=(28, 28, 1))   
x = tf.keras.layers.Conv2D(32, 3, activation='relu', strides=2, padding='same')(encoder_inputs) 
x = tf.keras.layers.Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)   
x = tf.keras.layers.Flatten()(x)   
x = tf.keras.layers.Dense(16, activation='relu')(x)  
z_mean = tf.keras.layers.Dense(latent_dim, name='z_mean')(x)   
z_log_var = tf.keras.layers.Dense(latent_dim, name='z_log_var')(x)   
z = Sampling()([z_mean, z_log_var])   
encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')

这段代码定义了解码器 decoder 部分,输入就是采样得到的潜在向量,然后经过常见的一系列反卷积层,即针对 encoder 对应的下采样操作进行反向的上采样操作,最终得到重构的图片。

ini 复制代码
latent_inputs = tf.keras.Input(shape=(latent_dim,))  
x = tf.keras.layers.Dense(7 * 7 * 64, activation='relu')(latent_inputs)   
x = tf.keras.layers.Reshape((7, 7, 64))(x)   
x = tf.keras.layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)   
decoder_outputs = tf.keras.layers.Conv2DTranspose(1, 3, activation='sigmoid', strides=2, padding='same')(x)   
decoder = tf.keras.Model(latent_inputs, decoder_outputs, name='decoder')

这段代码就是将上面的 encoderdecoder 两个部分组合起来创建 VAE 类,主要定义了训练过程中的损失函数,也就是上面提到的 reconstruction_losskl_loss 的和。reconstruction_loss 是输入图片和解码器重建图片的均方损失, kl_loss 是潜在空间分布和标准高斯分布(零均值和单位方差)之间的 KL 散度,最终损失函数就是这两个损失的总和。

python 复制代码
class VAE(tf.keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = tf.keras.metrics.Mean(name='total_loss')
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name='reconstruction_loss')
        self.kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')

    @property
    def metrics(self):
        return [self.total_loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)  # [B, 28, 28, 1]
            tmp = keras.losses.binary_crossentropy(data, reconstruction)  # [B, 28, 28]
            reconstruction_loss = tf.reduce_mean(tf.reduce_sum(tmp, axis=(1, 2)))
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))  #[B, 2]
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return { "loss": self.total_loss_tracker.result(),  "reconstruction_loss": self.reconstruction_loss_tracker.result(), "kl_loss": self.kl_loss_tracker.result()  }

模型训练

  • 这里主要定义了一个采样类,将 encoder 得到的潜在分布的对数方差和均值进行重采样得到潜在向量
  • 使用 adam 作为优化器
  • 加入了 EarlyStopping 回调函数,当超过 3 次损失值没有下降,就停止训练
  • 总共训练 30 个 epoch ,每个 batch_size 为 128
ini 复制代码
class Sampling(tf.keras.layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

vae = VAE(encoder, decoder)
vae.compile(optimizer='adam')
callbacks = [tf.keras.callbacks.EarlyStopping(patience=3)]
vae.fit(mnist_digits, epochs=30, batch_size=128, callbacks=callbacks)

日志打印:

arduino 复制代码
Epoch 1/30
547/547 [==============================] - 114s 2ms/step - loss: 285.0059 - reconstruction_loss: 216.4261 - kl_loss: 4.6019
Epoch 2/30
547/547 [==============================] - 1s 2ms/step - loss: 178.1703 - reconstruction_loss: 169.7719 - kl_loss: 4.4494
Epoch 3/30
547/547 [==============================] - 1s 2ms/step - loss: 168.2746 - reconstruction_loss: 162.4766 - kl_loss: 4.7840
...
Epoch 28/30
547/547 [==============================] - 2s 4ms/step - loss: 153.3385 - reconstruction_loss: 147.1945 - kl_loss: 5.8508
Epoch 29/30
547/547 [==============================] - 2s 4ms/step - loss: 153.0213 - reconstruction_loss: 147.0176 - kl_loss: 5.8779
Epoch 30/30
547/547 [==============================] - 2s 4ms/step - loss: 152.6355 - reconstruction_loss: 146.9703 - kl_loss: 5.8860

生成效果展示

这段函数代码的目的是可视化 VAE 的潜在空间。具体而言,它生成一个二维平面的网格,在这个平面上均匀地采样得到潜在向量(在这里直观展示出来的就是 x,y 坐标点),然后使用解码器将每个潜在向量解码为图像,最后将所有解码得到的图像拼接成一个大的图像展示出来。

通过观察这个图像,可以发现潜在空间中不同区域的图像表现出的特征和分布规律。对照下图直观来说就是:

  • 每个数字都有自己的集中分布区域
  • 相邻区域的数字具有相近的特征,比如 9 和 8 的分布区域,3 和 2 的分布区域等
  • 不相近的数字则分布相距较远,比如 6 和 9 的分布区域,2 和 1 的分布区域
  • 在图像相邻分布区域接壤过程中会有图像渐变的平滑过渡过程

这种平滑的过渡在生成模型中具有重要意义:

  1. 生成图像的连续性:通过沿着潜在空间中连续的路径移动,我们可以生成具有渐变特征的图像。这使得生成的图像在视觉上连贯且具有连续性。
  2. 插值和生成新图像:利用潜在空间中的平滑过渡,我们可以执行插值操作,即在两个点之间进行线性插值,生成介于这两个点之间的新图像,例如生成两个数字中间的过渡形态"新数字",或者生成又像桌子又像椅子的中间物体"椅桌"
  3. 探索潜在空间:通过观察图像在潜在空间中的分布和过渡,我们可以更好地理解模型学到的表示空间。
ini 复制代码
def plot_latent_space(vae, n=40, figsize=15):
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = vae.decoder.predict(z_sample, verbose=0)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size, ] = digit
    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

参考

相关推荐
陈苏同学6 分钟前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
吾名招财24 分钟前
yolov5-7.0模型DNN加载函数及参数详解(重要)
c++·人工智能·yolo·dnn
FL162386312935 分钟前
[深度学习][python]yolov11+bytetrack+pyqt5实现目标追踪
深度学习·qt·yolo
羊小猪~~41 分钟前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
鼠鼠龙年发大财1 小时前
【鼠鼠学AI代码合集#7】概率
人工智能
龙的爹23331 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
工业机器视觉设计和实现1 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
醒了就刷牙1 小时前
58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习
985小水博一枚呀1 小时前
【对于Python爬虫的理解】数据挖掘、信息聚合、价格监控、新闻爬取等,附代码。
爬虫·python·深度学习·数据挖掘
想要打 Acm 的小周同学呀2 小时前
实现mnist手写数字识别
深度学习·tensorflow·实现mnist手写数字识别