前言
本文展示了在 MNIST 数据集上训练 Convolutional Variational AutoEncoder (VAE) 。 VAE 是自动编码器的概率模型,它会将高维输入数据压缩为维度较小的表示形式,但是实现方式与将输入映射到潜在向量的传统自动编码器不同,VAE 将输入数据映射到概率分布的参数,最经典的方式莫过于高斯分布的均值和方差。这种方法会产生一个连续的、结构化的潜在空间,这对于图像生成的多样化很有用。
模型原理
VAE 的框架图如下所示,在训练期间,输入图片数据 x 到编码器 encoder ,就像 AE 一样,编码器中一般都是多层卷积神经网络,然而与 AE 的编码器不同的是不直接输出潜在向量 latent vector ,而是输出每个潜在变量的平均值和标准差,然后从该均值和标准差中对潜在向量进行采样,然后将其发送到解码器 decoder 以重建输入图片。VAE 中的解码器的工作原理与 AE 中的解码器类似。
由于中间编码增加了 latent distribution ,所以损失函数不仅有和 AE 类似的 reconstruction loss ,还有 KL Divergence 来衡量 latent distribution 和 standard gaussian 的相似度。

数据处理
数据处理方面主要就是加载 MNIST 数据集,然后将测试集和训练集进行合并,并且将整个数据集大小调整为 【batch_size, weight,height,channel】,并且对所有的数据进行标准化。
scss
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255.0
模型搭建
这段代码定义了一个编码器 encoder ,用于将输入图像编码为潜在分布中的均值 z_mean 和对数方差 z_log_var ,并从中采样函数中得到潜在向量表示 z 。encoder 中间还是常见的卷积神经网络,进行了一系列的下采样操作。
ini
latent_dim = 2
encoder_inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, 3, activation='relu', strides=2, padding='same')(encoder_inputs)
x = tf.keras.layers.Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(16, activation='relu')(x)
z_mean = tf.keras.layers.Dense(latent_dim, name='z_mean')(x)
z_log_var = tf.keras.layers.Dense(latent_dim, name='z_log_var')(x)
z = Sampling()([z_mean, z_log_var])
encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
这段代码定义了解码器 decoder 部分,输入就是采样得到的潜在向量,然后经过常见的一系列反卷积层,即针对 encoder 对应的下采样操作进行反向的上采样操作,最终得到重构的图片。
ini
latent_inputs = tf.keras.Input(shape=(latent_dim,))
x = tf.keras.layers.Dense(7 * 7 * 64, activation='relu')(latent_inputs)
x = tf.keras.layers.Reshape((7, 7, 64))(x)
x = tf.keras.layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)
decoder_outputs = tf.keras.layers.Conv2DTranspose(1, 3, activation='sigmoid', strides=2, padding='same')(x)
decoder = tf.keras.Model(latent_inputs, decoder_outputs, name='decoder')
这段代码就是将上面的 encoder 和 decoder 两个部分组合起来创建 VAE 类,主要定义了训练过程中的损失函数,也就是上面提到的 reconstruction_loss 和 kl_loss 的和。reconstruction_loss 是输入图片和解码器重建图片的均方损失, kl_loss 是潜在空间分布和标准高斯分布(零均值和单位方差)之间的 KL 散度,最终损失函数就是这两个损失的总和。
python
class VAE(tf.keras.Model):
def __init__(self, encoder, decoder, **kwargs):
super().__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
self.total_loss_tracker = tf.keras.metrics.Mean(name='total_loss')
self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name='reconstruction_loss')
self.kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')
@property
def metrics(self):
return [self.total_loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]
def train_step(self, data):
with tf.GradientTape() as tape:
z_mean, z_log_var, z = self.encoder(data)
reconstruction = self.decoder(z) # [B, 28, 28, 1]
tmp = keras.losses.binary_crossentropy(data, reconstruction) # [B, 28, 28]
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(tmp, axis=(1, 2)))
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) #[B, 2]
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
grads = tape.gradient(total_loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
return { "loss": self.total_loss_tracker.result(), "reconstruction_loss": self.reconstruction_loss_tracker.result(), "kl_loss": self.kl_loss_tracker.result() }
模型训练
- 这里主要定义了一个
采样类,将encoder得到的潜在分布的对数方差和均值进行重采样得到潜在向量 - 使用
adam作为优化器 - 加入了
EarlyStopping回调函数,当超过3次损失值没有下降,就停止训练 - 总共训练
30个 epoch ,每个 batch_size 为128
ini
class Sampling(tf.keras.layers.Layer):
def call(self, inputs):
z_mean, z_log_var = inputs
batch = tf.shape(z_mean)[0]
dim = tf.shape(z_mean)[1]
epsilon = tf.random.normal(shape=(batch, dim))
return z_mean + tf.exp(0.5 * z_log_var) * epsilon
vae = VAE(encoder, decoder)
vae.compile(optimizer='adam')
callbacks = [tf.keras.callbacks.EarlyStopping(patience=3)]
vae.fit(mnist_digits, epochs=30, batch_size=128, callbacks=callbacks)
日志打印:
arduino
Epoch 1/30
547/547 [==============================] - 114s 2ms/step - loss: 285.0059 - reconstruction_loss: 216.4261 - kl_loss: 4.6019
Epoch 2/30
547/547 [==============================] - 1s 2ms/step - loss: 178.1703 - reconstruction_loss: 169.7719 - kl_loss: 4.4494
Epoch 3/30
547/547 [==============================] - 1s 2ms/step - loss: 168.2746 - reconstruction_loss: 162.4766 - kl_loss: 4.7840
...
Epoch 28/30
547/547 [==============================] - 2s 4ms/step - loss: 153.3385 - reconstruction_loss: 147.1945 - kl_loss: 5.8508
Epoch 29/30
547/547 [==============================] - 2s 4ms/step - loss: 153.0213 - reconstruction_loss: 147.0176 - kl_loss: 5.8779
Epoch 30/30
547/547 [==============================] - 2s 4ms/step - loss: 152.6355 - reconstruction_loss: 146.9703 - kl_loss: 5.8860
生成效果展示
这段函数代码的目的是可视化 VAE 的潜在空间。具体而言,它生成一个二维平面的网格,在这个平面上均匀地采样得到潜在向量(在这里直观展示出来的就是 x,y 坐标点),然后使用解码器将每个潜在向量解码为图像,最后将所有解码得到的图像拼接成一个大的图像展示出来。
通过观察这个图像,可以发现潜在空间中不同区域的图像表现出的特征和分布规律。对照下图直观来说就是:
- 每个数字都有自己的集中分布区域
- 相邻区域的数字具有相近的特征,比如 9 和 8 的分布区域,3 和 2 的分布区域等
- 不相近的数字则分布相距较远,比如 6 和 9 的分布区域,2 和 1 的分布区域
- 在图像相邻分布区域接壤过程中会有
图像渐变的平滑过渡过程
这种平滑的过渡在生成模型中具有重要意义:
生成图像的连续性:通过沿着潜在空间中连续的路径移动,我们可以生成具有渐变特征的图像。这使得生成的图像在视觉上连贯且具有连续性。插值和生成新图像:利用潜在空间中的平滑过渡,我们可以执行插值操作,即在两个点之间进行线性插值,生成介于这两个点之间的新图像,例如生成两个数字中间的过渡形态"新数字",或者生成又像桌子又像椅子的中间物体"椅桌"探索潜在空间:通过观察图像在潜在空间中的分布和过渡,我们可以更好地理解模型学到的表示空间。
ini
def plot_latent_space(vae, n=40, figsize=15):
digit_size = 28
scale = 1.0
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = np.linspace(-scale, scale, n)
grid_y = np.linspace(-scale, scale, n)[::-1]
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
z_sample = np.array([[xi, yi]])
x_decoded = vae.decoder.predict(z_sample, verbose=0)
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size, ] = digit
plt.figure(figsize=(figsize, figsize))
start_range = digit_size // 2
end_range = n * digit_size + start_range
pixel_range = np.arange(start_range, end_range, digit_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap="Greys_r")
plt.show()
