前言
本文展示了在 MNIST
数据集上训练 Convolutional Variational AutoEncoder (VAE)
。 VAE
是自动编码器的概率模型,它会将高维输入数据压缩为维度较小的表示形式,但是实现方式与将输入映射到潜在向量的传统自动编码器不同,VAE 将输入数据映射到概率分布的参数
,最经典的方式莫过于高斯分布的均值和方差
。这种方法会产生一个连续的、结构化的潜在空间,这对于图像生成的多样化很有用。
模型原理
VAE
的框架图如下所示,在训练期间,输入图片数据 x
到编码器 encoder
,就像 AE 一样,编码器中一般都是多层卷积神经网络
,然而与 AE 的编码器不同的是不直接输出潜在向量 latent vector ,而是输出每个潜在变量的平均值和标准差
,然后从该均值和标准差中对潜在向量进行采样,然后将其发送到解码器 decoder
以重建输入图片。VAE 中的解码器的工作原理与 AE 中的解码器类似。
由于中间编码增加了 latent distribution
,所以损失函数不仅有和 AE 类似的 reconstruction loss
,还有 KL Divergence
来衡量 latent distribution
和 standard gaussian
的相似度。
数据处理
数据处理方面主要就是加载 MNIST 数据集,然后将测试集和训练集进行合并,并且将整个数据集大小调整为 【batch_size, weight,height,channel】,并且对所有的数据进行标准化。
scss
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255.0
模型搭建
这段代码定义了一个编码器 encoder
,用于将输入图像编码为潜在分布中的均值 z_mean
和对数方差 z_log_var
,并从中采样函数中得到潜在向量表示 z
。encoder 中间还是常见的卷积神经网络,进行了一系列的下采样操作。
ini
latent_dim = 2
encoder_inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, 3, activation='relu', strides=2, padding='same')(encoder_inputs)
x = tf.keras.layers.Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(16, activation='relu')(x)
z_mean = tf.keras.layers.Dense(latent_dim, name='z_mean')(x)
z_log_var = tf.keras.layers.Dense(latent_dim, name='z_log_var')(x)
z = Sampling()([z_mean, z_log_var])
encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
这段代码定义了解码器 decoder
部分,输入就是采样得到的潜在向量
,然后经过常见的一系列反卷积层,即针对 encoder 对应的下采样操作进行反向的上采样
操作,最终得到重构的图片。
ini
latent_inputs = tf.keras.Input(shape=(latent_dim,))
x = tf.keras.layers.Dense(7 * 7 * 64, activation='relu')(latent_inputs)
x = tf.keras.layers.Reshape((7, 7, 64))(x)
x = tf.keras.layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)
decoder_outputs = tf.keras.layers.Conv2DTranspose(1, 3, activation='sigmoid', strides=2, padding='same')(x)
decoder = tf.keras.Model(latent_inputs, decoder_outputs, name='decoder')
这段代码就是将上面的 encoder
和 decoder
两个部分组合起来创建 VAE 类
,主要定义了训练过程中的损失函数,也就是上面提到的 reconstruction_loss
和 kl_loss
的和。reconstruction_loss
是输入图片和解码器重建图片的均方损失, kl_loss
是潜在空间分布和标准高斯分布(零均值和单位方差)之间的 KL 散度,最终损失函数就是这两个损失的总和。
python
class VAE(tf.keras.Model):
def __init__(self, encoder, decoder, **kwargs):
super().__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
self.total_loss_tracker = tf.keras.metrics.Mean(name='total_loss')
self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name='reconstruction_loss')
self.kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')
@property
def metrics(self):
return [self.total_loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]
def train_step(self, data):
with tf.GradientTape() as tape:
z_mean, z_log_var, z = self.encoder(data)
reconstruction = self.decoder(z) # [B, 28, 28, 1]
tmp = keras.losses.binary_crossentropy(data, reconstruction) # [B, 28, 28]
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(tmp, axis=(1, 2)))
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) #[B, 2]
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
grads = tape.gradient(total_loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
return { "loss": self.total_loss_tracker.result(), "reconstruction_loss": self.reconstruction_loss_tracker.result(), "kl_loss": self.kl_loss_tracker.result() }
模型训练
- 这里主要定义了一个
采样类
,将encoder
得到的潜在分布的对数方差和均值
进行重采样得到潜在向量
- 使用
adam
作为优化器 - 加入了
EarlyStopping
回调函数,当超过3
次损失值没有下降,就停止训练 - 总共训练
30
个 epoch ,每个 batch_size 为128
ini
class Sampling(tf.keras.layers.Layer):
def call(self, inputs):
z_mean, z_log_var = inputs
batch = tf.shape(z_mean)[0]
dim = tf.shape(z_mean)[1]
epsilon = tf.random.normal(shape=(batch, dim))
return z_mean + tf.exp(0.5 * z_log_var) * epsilon
vae = VAE(encoder, decoder)
vae.compile(optimizer='adam')
callbacks = [tf.keras.callbacks.EarlyStopping(patience=3)]
vae.fit(mnist_digits, epochs=30, batch_size=128, callbacks=callbacks)
日志打印:
arduino
Epoch 1/30
547/547 [==============================] - 114s 2ms/step - loss: 285.0059 - reconstruction_loss: 216.4261 - kl_loss: 4.6019
Epoch 2/30
547/547 [==============================] - 1s 2ms/step - loss: 178.1703 - reconstruction_loss: 169.7719 - kl_loss: 4.4494
Epoch 3/30
547/547 [==============================] - 1s 2ms/step - loss: 168.2746 - reconstruction_loss: 162.4766 - kl_loss: 4.7840
...
Epoch 28/30
547/547 [==============================] - 2s 4ms/step - loss: 153.3385 - reconstruction_loss: 147.1945 - kl_loss: 5.8508
Epoch 29/30
547/547 [==============================] - 2s 4ms/step - loss: 153.0213 - reconstruction_loss: 147.0176 - kl_loss: 5.8779
Epoch 30/30
547/547 [==============================] - 2s 4ms/step - loss: 152.6355 - reconstruction_loss: 146.9703 - kl_loss: 5.8860
生成效果展示
这段函数代码的目的是可视化 VAE 的潜在空间
。具体而言,它生成一个二维平面的网格,在这个平面上均匀地采样得到潜在向量
(在这里直观展示出来的就是 x,y 坐标点
),然后使用解码器将每个潜在向量解码为图像,最后将所有解码得到的图像拼接成一个大的图像展示出来。
通过观察这个图像,可以发现潜在空间中不同区域的图像表现出的特征和分布规律。对照下图直观来说就是:
- 每个数字都有自己的集中分布区域
- 相邻区域的数字具有相近的特征,比如 9 和 8 的分布区域,3 和 2 的分布区域等
- 不相近的数字则分布相距较远,比如 6 和 9 的分布区域,2 和 1 的分布区域
- 在图像相邻分布区域接壤过程中会有
图像渐变的平滑过渡过程
这种平滑的过渡在生成模型中具有重要意义:
生成图像的连续性
:通过沿着潜在空间中连续的路径移动,我们可以生成具有渐变特征的图像。这使得生成的图像在视觉上连贯且具有连续性。插值和生成新图像
:利用潜在空间中的平滑过渡,我们可以执行插值操作,即在两个点之间进行线性插值,生成介于这两个点之间的新图像,例如生成两个数字中间的过渡形态"新数字",或者生成又像桌子又像椅子的中间物体"椅桌"探索潜在空间
:通过观察图像在潜在空间中的分布和过渡,我们可以更好地理解模型学到的表示空间。
ini
def plot_latent_space(vae, n=40, figsize=15):
digit_size = 28
scale = 1.0
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = np.linspace(-scale, scale, n)
grid_y = np.linspace(-scale, scale, n)[::-1]
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
z_sample = np.array([[xi, yi]])
x_decoded = vae.decoder.predict(z_sample, verbose=0)
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size, ] = digit
plt.figure(figsize=(figsize, figsize))
start_range = digit_size // 2
end_range = n * digit_size + start_range
pixel_range = np.arange(start_range, end_range, digit_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap="Greys_r")
plt.show()