VAE的原理及MNIST数据生成

⭐️ 变分自编码器(VAE)

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,通过学习训练数据的概率分布,可以生成与训练数据分布相似的新样本。与经典的自编码器不同,VAE 的目标不仅仅是学习压缩和重构数据,而是通过学习数据的潜在分布来进行概率建模,适合图像生成、异常检测、数据缺失填补等任务。本文将深入探讨 VAE 的原理和实现步骤。

⭐️ VAE 的基本原理

VAE 是一种概率生成模型。与传统的自编码器类似,VAE 也包含一个编码器和一个解码器:

  • 编码器:将输入数据(如图像)映射到一个潜在空间的概率分布中。
  • 解码器:从潜在空间中采样的点生成新的数据。

VAE 的关键在于引入了 概率分布的思想,即假设潜在空间中的数据服从某种分布(通常是高斯分布),并通过学习这个分布来对输入数据进行生成建模。

在生成模型的框架下,VAE 的目标是找到一组参数,使得模型生成的样本分布尽可能接近训练数据的分布。这需要解决的问题是如何从潜在空间中采样数据并计算样本的重构误差。


⭐️ 编码器与解码器结构

VAE 的编码器和解码器结构如下:

  • 编码器 :将输入 x x x 映射到潜在空间的均值 μ ( x ) \mu(x) μ(x) 和标准差 σ ( x ) \sigma(x) σ(x)。这样,我们就可以从正态分布中采样出一个潜在变量 z z z。
    z ∼ N ( μ ( x ) , σ ( x ) 2 ) z \sim \mathcal{N}(\mu(x), \sigma(x)^2) z∼N(μ(x),σ(x)2)

  • 解码器 :从潜在变量 z z z 生成重构的样本 x ′ x' x′,即 p ( x ∣ z ) p(x|z) p(x∣z)。解码器的目标是让生成的 x ′ x' x′ 尽量接近原始输入 x x x。


⭐️ VAE 的损失函数

VAE 的损失函数由两部分组成:

  1. 重构误差 :度量重构数据与原始数据之间的相似度。通常使用二元交叉熵或均方误差来计算。
    Reconstruction Loss = − E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] \text{Reconstruction Loss} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] Reconstruction Loss=−Eq(z∣x)[logp(x∣z)]

  2. KL 散度 :表示生成分布和标准正态分布之间的差异。它将潜在变量的分布约束为标准正态分布,从而使采样出的点在潜在空间上形成连续分布,能够生成平滑的图像。
    KL Divergence = D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) \text{KL Divergence} = D_{KL}(q(z|x) || p(z)) KL Divergence=DKL(q(z∣x)∣∣p(z))

    其中 D K L D_{KL} DKL 是 KL 散度。

综上,VAE 的总损失可以表示为:
L = Reconstruction Loss + KL Divergence \mathcal{L} = \text{Reconstruction Loss} + \text{KL Divergence} L=Reconstruction Loss+KL Divergence


⭐️ 重参数化技巧

VAE 的训练难点在于采样过程不具有可微性。为了解决这个问题,引入了 重参数化技巧

  1. 通过编码器输出均值 μ \mu μ 和方差 σ \sigma σ,得到一个标准正态分布的噪声变量 ϵ \epsilon ϵ。
  2. 使用公式 z = μ + σ ⋅ ϵ z = \mu + \sigma \cdot \epsilon z=μ+σ⋅ϵ 得到潜在变量 z z z。

这样,我们可以通过反向传播来训练模型,因为这个过程可以微分。


⭐️ 训练VAE并随机生成MNIST数据

代码如下

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


# 定义VAE模型
class VAE(nn.Module):

    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        # 编码器部分
        self.encoder = nn.Sequential(
            nn.Linear(784, 400),
            nn.ReLU(),
            nn.Linear(400, 2 * latent_dim)  # 输出均值和对数方差
        )
        # 解码器部分
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 784),
            nn.Sigmoid()  # 将输出值约束到0-1之间
        )
        self.latent_dim = latent_dim

    def encode(self, x):
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)  # 分割成均值和对数方差
        return mu, log_var

    # 重参数化
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    # 前向传播
    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


# 定义VAE的损失函数
def vae_loss(recon_x, x, mu, log_var):
    # 重构损失:二元交叉熵
    recon_loss = nn.functional.binary_cross_entropy(recon_x,
                                                    x.view(-1, 784),
                                                    reduction='sum')
    # KL 散度
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_divergence


# 数据加载
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data',
                               train=True,
                               transform=transform,
                               download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# 初始化模型和优化器
latent_dim = 20
vae = VAE(latent_dim=latent_dim).to('cuda')
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# 训练VAE模型,这里训练100轮
epochs = 100
vae.train()
for epoch in range(epochs):
    train_loss = 0
    for x, _ in train_loader:
        x = x.to('cuda')
        optimizer.zero_grad()
        recon_x, mu, log_var = vae(x)
        loss = vae_loss(recon_x, x, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}")


# 保存模型
torch.save(vae.state_dict(), 'vae_gen_mnist_image.pth')

# 生成图像
vae.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim).to('cuda')  # 生成64个随机latent向量
    generated_images = vae.decode(z).cpu()
    generated_images = generated_images.view(-1, 1, 28, 28)  # 调整维度适应MNIST格式

    # 可视化生成的图像
    grid = make_grid(generated_images, nrow=8, padding=2, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0), cmap='gray')
    plt.axis('off')
    # plt.show()
    plt.savefig('gen_mnist.png', dpi=800)

训练100轮后的结果


⭐️ 总结与展望

VAE 通过概率建模在潜在空间中进行有效采样,生成数据的能力优于经典自编码器。这种方法使得 VAE 在图像生成、数据增强、异常检测和生成对抗网络的预训练等任务中表现出色。通过不断调整网络结构和损失函数,VAE 还可以扩展到其他复杂任务,如自然语言生成、音频生成等。

尽管 VAE 的生成效果可能比 GAN 略逊色,但其稳定的训练过程和概率模型框架使得 VAE 在多个领域得到了广泛应用。

相关推荐
古希腊掌管学习的神1 小时前
[机器学习]XGBoost(3)——确定树的结构
人工智能·机器学习
ZHOU_WUYI1 小时前
4.metagpt中的软件公司智能体 (ProjectManager 角色)
人工智能·metagpt
靴子学长2 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
AI_NEW_COME3 小时前
知识库管理系统可扩展性深度测评
人工智能
海棠AI实验室3 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
hunteritself3 小时前
AI Weekly『12月16-22日』:OpenAI公布o3,谷歌发布首个推理模型,GitHub Copilot免费版上线!
人工智能·gpt·chatgpt·github·openai·copilot
IT古董4 小时前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
centurysee4 小时前
【最佳实践】Anthropic:Agentic系统实践案例
人工智能
mahuifa4 小时前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai
四口鲸鱼爱吃盐4 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类