使用CNN构建VAE

使用CNN构建VAE

    • [0. 前言](#0. 前言)
    • [1. 网络架构](#1. 网络架构)
    • [2. 模型构建](#2. 模型构建)
    • [3. 生成结果](#3. 生成结果)

0. 前言

在原始变分自编码器 (Variational Autoencoder, VAE)中,VAE 网络采用全连接网络实现。本节将使用卷积神经网络 (Convolutional Neural Network, CNN)提升生成数字的质量,同时将参数数量大幅减少至 134165 个。

1. 网络架构

下图展示了 CNN 编码器模型的两个输出——潜向量的均值与方差。其中的 lambda 函数实现了重参数化技巧,将随机潜编码的采样过程移至 VAE 网络外部:

下图展示了 CNN 解码器模型。其二维输入来自 lambda 函数,输出为重构的 MNIST 数字图像:

下图展示了完整的 CNN-VAE 模型,由编码器与解码器连接构成:

2. 模型构建

编码器由两层 CNN 和两层全连接层组成,用于生成潜编码。其输出结构与原始变分自编码器 (Variational Autoencoder, VAE)类似。解码器则包含一层全连接层和三层转置卷积层。

python
def sampling(args):
    """Reparameterization trick: z = mean + sigma * eps.

    Draws a latent sample from N(z_mean, exp(z_log_var)) using an external
    standard-normal noise source, keeping the sampling step differentiable
    with respect to the mean and log-variance.

    Args:
        args: pair (z_mean, z_log_var) of tensors with shape (batch, dim).

    Returns:
        Tensor of shape (batch, dim): the sampled latent vectors.
    """
    mean, log_var = args
    latent_shape = keras.backend.shape(mean)
    # eps ~ N(0, I), same (batch, dim) shape as the mean tensor
    eps = keras.backend.random_normal(shape=(latent_shape[0], latent_shape[1]))
    # sigma = exp(0.5 * log_var); shift-and-scale the unit noise
    return mean + keras.backend.exp(0.5 * log_var) * eps

def plot_results(models,
        data,
        batch_size=128,
        model_name='vae_mnist'):
    """Visualize a 2-D latent space: class scatter plot and decoded-digit grid.

    Saves two figures under *model_name/*: ``vae_mean.png`` (test digits
    plotted at their latent means, annotated with their labels) and
    ``digits_over_latent.png`` (a 30x30 grid of digits decoded from a
    uniform sweep of the latent plane).

    Args:
        models: tuple ``(encoder, decoder)`` of trained Keras models; the
            encoder's first output must be the 2-D latent mean.
        data: tuple ``(x_test, y_test)`` of test images and integer labels.
        batch_size: batch size used for the encoder prediction.
        model_name: output directory name (created if missing).
    """
    encoder, decoder = models
    x_test, y_test = data
    xmin = ymin = -4
    xmax = ymax = +4
    os.makedirs(model_name, exist_ok=True)

    filename = os.path.join(model_name, 'vae_mean.png')
    # display a 2D plot of the digit classes in the latent space
    z, _, _ = encoder.predict(x_test, batch_size=batch_size)
    plt.figure(figsize=(12, 10))

    # fix the x and y axis ranges so runs are comparable
    axes = plt.gca()
    axes.set_xlim([xmin, xmax])
    axes.set_ylim([ymin, ymax])

    # subsample every other point to reduce density on the plot
    z = z[0::2]
    y_test = y_test[0::2]
    # empty marker: the annotated digit labels act as the markers
    plt.scatter(z[:, 0], z[:, 1], marker='')
    for i, digit in enumerate(y_test):
        axes.annotate(digit, (z[i, 0], z[i, 1]))
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.savefig(filename)
    plt.show()

    filename = os.path.join(model_name, 'digits_over_latent.png')
    # display a 30*30 2D manifold of digits
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot of digit
    # classes in the latent space; grid_y is reversed so z[1] increases
    # upward in the rendered image
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]

    # BUGFIX: rows (i) must sweep grid_y and columns (j) grid_x so the
    # decoded digits line up with the tick labels below (grid_x on the
    # x-axis, grid_y on the y-axis); the original iterated them swapped.
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size:(i + 1) * digit_size,
                   j * digit_size:(j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    # tick positions at the center of each digit cell
    start_range = digit_size // 2
    end_range = (n - 1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()

# Load the MNIST dataset (grayscale 28x28 digit images with labels).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

image_size = x_train.shape[1]
# original_dim = image_size * image_size
# Reshape to (N, H, W, 1) for the Conv2D layers and scale pixels to [0, 1].
x_train = np.reshape(x_train, [-1, image_size,image_size,1])
x_test = np.reshape(x_test, [-1, image_size,image_size,1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Hyperparameters
input_shape = (image_size,image_size,1)
batch_size = 128
kernel_size = 3
filters = 16      # doubled before each encoder Conv2D below (16 -> 32 -> 64)
latent_dim = 2    # 2-D latent space so it can be plotted directly
epochs = 50

#VAE model
# Encoder: two strided Conv2D layers, then Dense layers producing the mean
# and log-variance of the latent Gaussian.
inputs = keras.layers.Input(shape=input_shape,name='encoder_input')
x = inputs
for i in range(2):
    filters *= 2
    x = keras.layers.Conv2D(filters=filters,
            kernel_size=kernel_size,
            activation='relu',
            strides=2,
            padding='same')(x)
# Remember the pre-flatten feature-map shape so the decoder can mirror it.
shape = keras.backend.int_shape(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(16,activation='relu')(x)
z_mean = keras.layers.Dense(latent_dim,name='z_mean')(x)
z_log_var = keras.layers.Dense(latent_dim,name='z_log_var')(x)

# Lambda wraps the reparameterization trick so that random sampling happens
# outside the trainable path: gradients flow through z_mean and z_log_var.
z = keras.layers.Lambda(sampling,output_shape=(latent_dim,),name='z')([z_mean,z_log_var])

encoder = keras.Model(inputs,[z_mean,z_log_var,z],name='encoder')
encoder.summary()
keras.utils.plot_model(encoder,to_file='vae_cnn_encoder.png',show_shapes=True)

# Decoder: Dense layer back up to the saved feature-map size, two strided
# Conv2DTranspose layers (filters halve: 64 -> 32, continuing the mutated
# `filters` value left by the encoder loop), and a final 1-filter sigmoid
# Conv2DTranspose that reconstructs the image.
latent_inputs = keras.layers.Input(shape=(latent_dim,),name='z_sampling')
x = keras.layers.Dense(shape[1]*shape[2]*shape[3],activation='relu')(latent_inputs)
x = keras.layers.Reshape((shape[1],shape[2],shape[3]))(x)
for i in range(2):
    x = keras.layers.Conv2DTranspose(filters=filters,
            kernel_size=kernel_size,
            activation='relu',
            strides=2,
            padding='same')(x)
    filters //= 2
outputs = keras.layers.Conv2DTranspose(filters=1,
        kernel_size=kernel_size,
        activation='sigmoid',
        padding='same',
        name='decoder_output')(x)
decoder = keras.Model(latent_inputs,outputs,name='decoder')
decoder.summary()

keras.utils.plot_model(decoder,to_file='vae_cnn_decoder.png',show_shapes=True)

# Full VAE: decoder applied to the sampled z (third encoder output).
outputs = decoder(encoder(inputs)[2])
vae = keras.Model(inputs,outputs,name='vae_cnn')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load tf model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Use binary cross entropy instead of mse (default)"
    parser.add_argument("--bce", help=help_, action='store_true')
    args = parser.parse_args()
    models = (encoder, decoder)
    data = (x_test, y_test)

    # VAE loss = reconstruction loss (mse or binary cross-entropy) + KL loss
    if args.bce:
        reconstruction_loss = keras.losses.binary_crossentropy(keras.backend.flatten(inputs),
                keras.backend.flatten(outputs))
    else:
        reconstruction_loss = keras.losses.mse(keras.backend.flatten(inputs),
                keras.backend.flatten(outputs))

    # Scale the per-pixel mean loss back up to a per-image magnitude so the
    # reconstruction term is not dwarfed by the KL term.
    reconstruction_loss *= image_size * image_size
    # KL divergence between N(z_mean, exp(z_log_var)) and the unit Gaussian.
    kl_loss = 1 + z_log_var - keras.backend.square(z_mean) - keras.backend.exp(z_log_var)
    kl_loss = keras.backend.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = keras.backend.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    vae.summary()
    keras.utils.plot_model(vae, to_file='vae_cnn.png', show_shapes=True)

    save_dir = 'vae_cnn_weights'
    os.makedirs(save_dir, exist_ok=True)
    if args.weights:
        filepath = os.path.join(save_dir, args.weights)
        # BUGFIX: load_weights does not return the model, so its result must
        # not be assigned back to `vae` (the original did `vae = vae.load_...`,
        # which clobbered the model object).
        vae.load_weights(filepath)
    else:
        # Train; no targets are passed because the loss was attached with
        # add_loss, so validation_data uses None as its label component.
        vae.fit(x_train,
                epochs=epochs,
                batch_size=batch_size,
                validation_data=(x_test, None))
        filepath = os.path.join(save_dir, 'vae_cnn.mnist.tf')
        vae.save_weights(filepath)
    plot_results(models, data, batch_size=batch_size, model_name='vae_cnn')

3. 生成结果

VAE 模型经过 30 个 epoch 训练后,下图展示了在连续潜空间中遍历时数字的分布情况。例如从中心区域向右移动时,数字会从 2 逐渐转变为 0。

下图展示了生成模型的输出结果。从质量上看,与基于 MLP 实现 VAE相比,模糊难辨的数字数量明显减少:

相关推荐
蓝博AI6 小时前
基于卷积神经网络的香蕉成熟度识别系统,resnet50,vgg16,resnet34【pytorch框架,python代码】
人工智能·pytorch·python·神经网络·cnn
格林威7 小时前
近红外工业相机的简单介绍和场景应用
人工智能·深度学习·数码相机·计算机视觉·视觉检测·制造·工业相机
Pocker_Spades_A7 小时前
机器学习之生成对抗网络(GAN)
人工智能·深度学习·生成对抗网络
Theodore_10227 小时前
机器学习(2) 线性回归和代价函数
人工智能·深度学习·机器学习·线性回归·代价函数
技术闲聊DD8 小时前
深度学习(13)-PyTorch 数据转换
人工智能·pytorch·深度学习
星期天要睡觉8 小时前
深度学习——基于 PyTorch 的蔬菜图像分类
人工智能·pytorch·python·深度学习·分类
AI浩8 小时前
Yolo分割数据集错误数据删除
人工智能·深度学习·yolo
康kang9 小时前
Transformer神经网络模型
深度学习·神经网络·transformer