直观感受卷积 VAE 模型的潜在分布空间

前言

本文展示了在 MNIST 数据集上训练 Convolutional Variational AutoEncoder (VAE)VAE 是自动编码器的概率模型,它会将高维输入数据压缩为维度较小的表示形式,但是实现方式与将输入映射到潜在向量的传统自动编码器不同,VAE 将输入数据映射到概率分布的参数,最经典的方式莫过于高斯分布的均值和方差。这种方法会产生一个连续的、结构化的潜在空间,这对于图像生成的多样化很有用。

模型原理

VAE 的框架图如下所示,在训练期间,输入图片数据 x 到编码器 encoder ,就像 AE 一样,编码器中一般都是多层卷积神经网络,然而与 AE 的编码器不同的是不直接输出潜在向量 latent vector ,而是输出每个潜在变量的平均值和标准差,然后从该均值和标准差中对潜在向量进行采样,然后将其发送到解码器 decoder 以重建输入图片。VAE 中的解码器的工作原理与 AE 中的解码器类似。

由于中间编码增加了 latent distribution ,所以损失函数不仅有和 AE 类似的 reconstruction loss ,还有 KL Divergence 来衡量 latent distributionstandard gaussian 的相似度。

数据处理

数据处理方面主要就是加载 MNIST 数据集,然后将测试集和训练集进行合并,并且将整个数据集大小调整为 【batch_size, weight,height,channel】,并且对所有的数据进行标准化。

scss 复制代码
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255.0

模型搭建

这段代码定义了一个编码器 encoder ,用于将输入图像编码为潜在分布中的均值 z_mean对数方差 z_log_var ,并从中采样函数中得到潜在向量表示 z 。encoder 中间还是常见的卷积神经网络,进行了一系列的下采样操作。

ini 复制代码
latent_dim = 2
encoder_inputs = tf.keras.Input(shape=(28, 28, 1))   
x = tf.keras.layers.Conv2D(32, 3, activation='relu', strides=2, padding='same')(encoder_inputs) 
x = tf.keras.layers.Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)   
x = tf.keras.layers.Flatten()(x)   
x = tf.keras.layers.Dense(16, activation='relu')(x)  
z_mean = tf.keras.layers.Dense(latent_dim, name='z_mean')(x)   
z_log_var = tf.keras.layers.Dense(latent_dim, name='z_log_var')(x)   
z = Sampling()([z_mean, z_log_var])   
encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')

这段代码定义了解码器 decoder 部分,输入就是采样得到的潜在向量,然后经过常见的一系列反卷积层,即针对 encoder 对应的下采样操作进行反向的上采样操作,最终得到重构的图片。

ini 复制代码
latent_inputs = tf.keras.Input(shape=(latent_dim,))  
x = tf.keras.layers.Dense(7 * 7 * 64, activation='relu')(latent_inputs)   
x = tf.keras.layers.Reshape((7, 7, 64))(x)   
x = tf.keras.layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)   
decoder_outputs = tf.keras.layers.Conv2DTranspose(1, 3, activation='sigmoid', strides=2, padding='same')(x)   
decoder = tf.keras.Model(latent_inputs, decoder_outputs, name='decoder')

这段代码就是将上面的 encoderdecoder 两个部分组合起来创建 VAE 类,主要定义了训练过程中的损失函数,也就是上面提到的 reconstruction_losskl_loss 的和。reconstruction_loss 是输入图片和解码器重建图片的均方损失, kl_loss 是潜在空间分布和标准高斯分布(零均值和单位方差)之间的 KL 散度,最终损失函数就是这两个损失的总和。

python 复制代码
class VAE(tf.keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = tf.keras.metrics.Mean(name='total_loss')
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name='reconstruction_loss')
        self.kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')

    @property
    def metrics(self):
        return [self.total_loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)  # [B, 28, 28, 1]
            tmp = keras.losses.binary_crossentropy(data, reconstruction)  # [B, 28, 28]
            reconstruction_loss = tf.reduce_mean(tf.reduce_sum(tmp, axis=(1, 2)))
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))  #[B, 2]
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return { "loss": self.total_loss_tracker.result(),  "reconstruction_loss": self.reconstruction_loss_tracker.result(), "kl_loss": self.kl_loss_tracker.result()  }

模型训练

  • 这里主要定义了一个采样类,将 encoder 得到的潜在分布的对数方差和均值进行重采样得到潜在向量
  • 使用 adam 作为优化器
  • 加入了 EarlyStopping 回调函数,当超过 3 次损失值没有下降,就停止训练
  • 总共训练 30 个 epoch ,每个 batch_size 为 128
ini 复制代码
class Sampling(tf.keras.layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

vae = VAE(encoder, decoder)
vae.compile(optimizer='adam')
callbacks = [tf.keras.callbacks.EarlyStopping(patience=3)]
vae.fit(mnist_digits, epochs=30, batch_size=128, callbacks=callbacks)

日志打印:

arduino 复制代码
Epoch 1/30
547/547 [==============================] - 114s 2ms/step - loss: 285.0059 - reconstruction_loss: 216.4261 - kl_loss: 4.6019
Epoch 2/30
547/547 [==============================] - 1s 2ms/step - loss: 178.1703 - reconstruction_loss: 169.7719 - kl_loss: 4.4494
Epoch 3/30
547/547 [==============================] - 1s 2ms/step - loss: 168.2746 - reconstruction_loss: 162.4766 - kl_loss: 4.7840
...
Epoch 28/30
547/547 [==============================] - 2s 4ms/step - loss: 153.3385 - reconstruction_loss: 147.1945 - kl_loss: 5.8508
Epoch 29/30
547/547 [==============================] - 2s 4ms/step - loss: 153.0213 - reconstruction_loss: 147.0176 - kl_loss: 5.8779
Epoch 30/30
547/547 [==============================] - 2s 4ms/step - loss: 152.6355 - reconstruction_loss: 146.9703 - kl_loss: 5.8860

生成效果展示

这段函数代码的目的是可视化 VAE 的潜在空间。具体而言,它生成一个二维平面的网格,在这个平面上均匀地采样得到潜在向量(在这里直观展示出来的就是 x,y 坐标点),然后使用解码器将每个潜在向量解码为图像,最后将所有解码得到的图像拼接成一个大的图像展示出来。

通过观察这个图像,可以发现潜在空间中不同区域的图像表现出的特征和分布规律。对照下图直观来说就是:

  • 每个数字都有自己的集中分布区域
  • 相邻区域的数字具有相近的特征,比如 9 和 8 的分布区域,3 和 2 的分布区域等
  • 不相近的数字则分布相距较远,比如 6 和 9 的分布区域,2 和 1 的分布区域
  • 在图像相邻分布区域接壤过程中会有图像渐变的平滑过渡过程

这种平滑的过渡在生成模型中具有重要意义:

  1. 生成图像的连续性:通过沿着潜在空间中连续的路径移动,我们可以生成具有渐变特征的图像。这使得生成的图像在视觉上连贯且具有连续性。
  2. 插值和生成新图像:利用潜在空间中的平滑过渡,我们可以执行插值操作,即在两个点之间进行线性插值,生成介于这两个点之间的新图像,例如生成两个数字中间的过渡形态"新数字",或者生成又像桌子又像椅子的中间物体"椅桌"
  3. 探索潜在空间:通过观察图像在潜在空间中的分布和过渡,我们可以更好地理解模型学到的表示空间。
ini 复制代码
def plot_latent_space(vae, n=40, figsize=15):
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = vae.decoder.predict(z_sample, verbose=0)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size, ] = digit
    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

参考

相关推荐
余炜yw9 分钟前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐25 分钟前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
967727 分钟前
对抗样本存在的原因
深度学习
如若1231 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr1 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络
ChaseDreamRunner1 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习
Guofu_Liao1 小时前
大语言模型---梯度的简单介绍;梯度的定义;梯度计算的方法
人工智能·语言模型·矩阵·llama
我爱学Python!1 小时前
大语言模型与图结构的融合: 推荐系统中的新兴范式
人工智能·语言模型·自然语言处理·langchain·llm·大语言模型·推荐系统
果冻人工智能1 小时前
OpenAI 是怎么“压力测试”大型语言模型的?
人工智能·语言模型·压力测试
日出等日落1 小时前
Windows电脑本地部署llamafile并接入Qwen大语言模型远程AI对话实战
人工智能·语言模型·自然语言处理