变分自编码器(VAE, Variational Autoencoder)

代码说明

VAE 模型结构:

编码器将输入数据(如 MNIST 图像)映射到潜在空间,生成均值 (mu) 和对数方差 (logvar)。

通过重新参数化技巧 (reparameterize) 从正态分布中采样潜在向量 z。

解码器将潜在向量 z 映射回原始空间,生成重构数据。

损失函数:

重构误差(BCE):衡量重构数据和原始数据的差异。

KL 散度(KLD):衡量潜在向量分布与标准正态分布的接近程度。

数据加载:

MNIST 数据集被用作示例,图像被标准化为 [0, 1] 范围。

生成结果:

测试阶段通过潜在空间随机采样生成新样本,并用 Matplotlib 可视化。

代码

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 超参数
input_dim = 784  # 输入维度 (28x28 图像展开为向量)
hidden_dim = 400  # 隐藏层维度
latent_dim = 20   # 潜在空间维度
batch_size = 128  # 批量大小
num_epochs = 20   # 训练轮数
learning_rate = 1e-3  # 学习率

# 数据加载
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
transform = transforms.Compose([
    transforms.ToTensor()  # 将像素值直接归一化到 [0, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# VAE 模型定义
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        # 编码器
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # 解码器
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = torch.relu(self.fc1(x))
        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h2 = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h2))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# 构造模型、损失函数和优化器
model = VAE(input_dim, hidden_dim, latent_dim)
criterion = nn.BCELoss(reduction='sum')  # 二元交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练
def loss_function(recon_x, x, mu, logvar):
    BCE = criterion(recon_x, x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for data, _ in train_loader:
        data = data.view(-1, input_dim)  # 展平输入图像
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss / len(train_loader.dataset):.4f}")

# 测试(生成样本)
model.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim)  # 随机采样潜在向量
    samples = model.decode(z).view(-1, 1, 28, 28)  # 生成样本

# 可视化生成结果
plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(samples[i][0].numpy(), cmap='gray')
    plt.axis('off')
plt.suptitle('Generated Samples from VAE')
plt.show()
相关推荐
lihuayong4 分钟前
OpenClaw 系统提示词
人工智能·prompt·提示词·openclaw
黑客说17 分钟前
AI驱动剧情,解锁无限可能——AI游戏发展解析
人工智能·游戏
踩着两条虫23 分钟前
AI驱动的Vue3应用开发平台深入探究(十):物料系统之内置组件库
android·前端·vue.js·人工智能·低代码·系统架构·rxjava
小仙女的小稀罕28 分钟前
听不清重要会议录音急疯?这款常见AI工具听脑AI精准转译
开发语言·人工智能·python
reesn35 分钟前
qwen3.5 0.8B纠正任务实践
人工智能·语言模型
实在智能RPA37 分钟前
实在Agent 制造业落地案例:探寻工业大模型从实验室走向车间的实战路径
人工智能·ai
阿酷tony1 小时前
Nano Banna 提示词:创意超逼真的3D商业风格产品图
人工智能·3d·gemini·图片生成
披着羊皮不是狼1 小时前
MSE、MAE、Binary/Categorical Cross-Entropy、HingeLoss五种损失函数的典型应用场景
人工智能·损失函数
guslegend1 小时前
大模型RAG进阶多格式文档解析
人工智能·大模型
独角鲸网络安全实验室1 小时前
惊魂零点击!OpenClaw漏洞(ClawJacked)突袭,开发者AI Agent遭无声劫持
人工智能·网络安全·数据安全·漏洞·openclaw·clawjacked·cve-2026-25253