【VAE】From Pixels to Faces: Building a VAE from Scratch


文章目录

  • [1. 什么是 VAE?](#1. 什么是 VAE?)
  • [2. 整体架构概览](#2. 整体架构概览)
  • [3. 编码器:从图像到概率分布](#3. 编码器:从图像到概率分布)
  • [4. 重参数化技巧 (Reparameterization Trick)](#4. 重参数化技巧 (Reparameterization Trick))
  • [5. 解码器:从隐向量重建图像](#5. 解码器:从隐向量重建图像)
  • [6. 从隐空间采样生成新图像](#6. 从隐空间采样生成新图像)
  • [7. 损失函数:重构 + 正则化](#7. 损失函数:重构 + 正则化)
  • [8. 训练流程](#8. 训练流程)
  • [9. 实验结果:重建与生成](#9. 实验结果:重建与生成)
  • [10. 数据加载:CelebA 数据集](#10. 数据加载:CelebA 数据集)
  • [11. 完整的前向传播流程](#11. 完整的前向传播流程)
  • [12. 关键超参数与调优建议](#12. 关键超参数与调优建议)
  • [13. 扩展方向](#13. 扩展方向)
  • [14. 完整代码清单](#14. 完整代码清单)
  • [15. 完整代码](#15. 完整代码)
  • 参考

从零实现 VAE:用变分自编码器生成人脸图像

本文基于 PyTorch 实现一个完整的 Variational Autoencoder (VAE),在 CelebA 数据集上训练,实现人脸重建与生成。代码已添加详尽注释,适合作为理解生成模型的入门实践。


1. 什么是 VAE?

Variational Autoencoder (VAE) 是由 Kingma 和 Welling 在 2014 年提出的深度生成模型。与普通自编码器不同,VAE 不直接学习一个确定的隐向量,而是学习隐空间的概率分布。这使得 VAE 能够:

  • 生成新样本:从学到的分布中采样,通过解码器生成新数据
  • 平滑的隐空间:相似的隐向量对应相似的输出,支持插值和语义操作

VAE 的核心思想可以用 ELBO(证据下界)概括:

L = E z ∼ q ϕ ( z ∣ x ) log ⁡ p θ ( x ∣ z ) ⏟ 重构损失 − D K L ( q ϕ ( z ∣ x ) ∥ p ( z ) ) ⏟ KL 散度 \mathcal{L} = \underbrace{\mathbb{E}{z \sim q\phi(z|x)}\\log p_\\theta(x\|z)}{\text{重构损失}} - \underbrace{D{KL}(q_\phi(z|x) \| p(z))}_{\text{KL 散度}} L=重构损失 Ez∼qϕ(z∣x)logpθ(x∣z)−KL 散度 DKL(qϕ(z∣x)∥p(z))

其中第一项鼓励解码器准确重建输入,第二项 (KL 散度) 鼓励编码器输出的分布逼近标准正态分布 N ( 0 , I ) \mathcal{N}(0, I) N(0,I)。


2. 整体架构概览

本项目包含三个核心文件:

文件 职责
model.py VAE 模型定义(编码器 + 解码器 + 重参数化)
main.py 训练循环、损失函数、重建与生成逻辑
load_celebA.py CelebA 数据集加载与预处理

输入为 64×64×3 的 RGB 人脸图像,经过编码器压缩为 128 维的隐向量,再通过解码器重建回原始尺寸。


3. 编码器:从图像到概率分布

编码器由 5 层 stride=2 的卷积堆叠而成,逐步将 64×64 的输入压缩为 2×2 的特征图,最后通过两个独立的全连接头分别输出 μ \mu μ 和 log ⁡ σ 2 \log\sigma^2 logσ2。

python 复制代码
class VAE(nn.Module):
    """VAE for 64x64 face generation.

    VAE (Variational Autoencoder) 变分自编码器,用于 64x64 人脸图像的生成。
    与普通自编码器不同,VAE 的编码器输出隐变量的概率分布参数 (μ, σ²),
    而非一个确定的隐向量,从而使得模型具备生成新样本的能力。

    The hidden dimensions can be tuned.
    中间的隐藏层通道数可以根据显存和效果需求进行调整。
    """

    def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
        """
        参数:
            hiddens: 编码器每层卷积的输出通道数,同时也是解码器转置卷积的输入通道数(逆序使用)
                     默认 [16, 32, 64, 128, 256],共 5 层,每次下采样 2 倍
            latent_dim: 隐空间维度,即潜变量 z 的维度,默认 128
        """
        super().__init__()

        # ==================== 编码器 (Encoder) ====================
        # 编码器将输入图像 [B, 3, 64, 64] 逐步压缩为特征图 [B, 256, 2, 2]
        # 然后通过两个全连接层分别输出隐变量的均值和对数方差
        prev_channels = 3          # 输入为 3 通道 RGB 图像
        modules = []               # 存储各层卷积模块
        img_length = 64            # 输入图像的空间尺寸 (高/宽)
        for cur_channels in hiddens:
            # 每层:Conv2d(stride=2) 将空间尺寸减半,通道数翻倍
            # kernel_size=3, stride=2, padding=1 实现 exactly 减半的效果
            # 以 64 → 32 为例: output = (64 - 3 + 2*1) / 2 + 1 = 32 ✓
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels,      # 输入通道数
                              cur_channels,       # 输出通道数
                              kernel_size=3,      # 3×3 卷积核
                              stride=2,           # 步长为 2,替代池化实现下采样
                              padding=1),         # 填充 1,保持边界信息
                    nn.BatchNorm2d(cur_channels), # 批归一化,加速收敛、稳定训练
                    nn.ReLU()))                   # ReLU 激活,引入非线性
            prev_channels = cur_channels
            img_length //= 2                      # 每层空间尺寸减半: 64→32→16→8→4→2
        self.encoder = nn.Sequential(*modules)

        # 编码器最终输出特征图尺寸: [B, 256, 2, 2],展平后为 [B, 1024]
        # 两个独立的全连接头:一个预测均值 μ,一个预测对数方差 log(σ²)
        # 使用 log(σ²) 而非 σ² 的原因:
        #   1. 数值稳定性更好,避免方差趋近于 0
        #   2. 无需约束输出为正(exp 后自然为正)
        self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
                                     latent_dim)  # 1024 → 128,输出均值 μ
        self.var_linear = nn.Linear(prev_channels * img_length * img_length,
                                    latent_dim)   # 1024 → 128,输出对数方差 log(σ²)
        self.latent_dim = latent_dim

编码器维度变化

复制代码
输入: [3, 64, 64]
  ↓ Conv2d(3→16, k3, s2, p1) + BN + ReLU   → [16, 32, 32]
  ↓ Conv2d(16→32, k3, s2, p1) + BN + ReLU  → [32, 16, 16]
  ↓ Conv2d(32→64, k3, s2, p1) + BN + ReLU  → [64,  8,  8]
  ↓ Conv2d(64→128, k3, s2, p1) + BN + ReLU → [128, 4,  4]
  ↓ Conv2d(128→256, k3, s2, p1) + BN + ReLU → [256, 2,  2]
  ↓ Flatten → 1024 维向量
  ↓ mean_linear: 1024 → 128  (均值 μ)
  ↓ var_linear:  1024 → 128  (对数方差 log σ²)

关键设计点:

  1. stride=2 替代池化:使用 stride=2 的卷积代替 MaxPooling,让网络自己学习下采样方式
  2. BatchNorm:加速收敛并稳定训练,防止某一层激活值过大或过小
  3. 两个独立的全连接头 :分别输出均值 μ \mu μ 和对数方差 log ⁡ σ 2 \log \sigma^2 logσ2------使用 log 形式保证数值稳定,且无需约束输出为正


4. 重参数化技巧 (Reparameterization Trick)

这是 VAE 实现中最精妙的部分。我们需要从 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2) 中采样 z z z,但采样操作是不可导 的,会阻断梯度回传到 μ \mu μ 和 σ \sigma σ。

解决方案 :从标准正态分布采样噪声 ε ∼ N ( 0 , I ) \varepsilon \sim \mathcal{N}(0, I) ε∼N(0,I),然后做线性变换:

z = μ + σ ⋅ ε = μ + e log ⁡ σ 2 2 ⋅ ε z = \mu + \sigma \cdot \varepsilon = \mu + e^{\frac{\log \sigma^2}{2}} \cdot \varepsilon z=μ+σ⋅ε=μ+e2logσ2⋅ε

python 复制代码
    def forward(self, x):
        """
        前向传播:输入图像 → 编码 → 重参数化采样 → 解码 → 重建图像

        参数:
            x: 输入图像 [B, 3, 64, 64]

        返回:
            decoded: 重建图像 [B, 3, 64, 64]
            mean: 隐变量分布的均值 μ [B, latent_dim]
            logvar: 隐变量分布的对数方差 log(σ²) [B, latent_dim]
        """
        # ---- 编码阶段 ----
        encoded = self.encoder(x)                         # [B, 3, 64, 64] → [B, 256, 2, 2]
        encoded = torch.flatten(encoded, 1)               # 展平: [B, 256, 2, 2] → [B, 1024]
        mean = self.mean_linear(encoded)                  # 均值 μ: [B, 128]
        logvar = self.var_linear(encoded)                 # 对数方差 log(σ²): [B, 128]

        # ---- 重参数化技巧 (Reparameterization Trick) ----
        # 目标:从 N(μ, σ²) 中采样 z,同时保持梯度可回传
        # 做法:先从 N(0,1) 采样 ε,再计算 z = μ + σ · ε
        # 这样随机性被隔离在 ε 中,μ 和 σ 仍可接收梯度
        eps = torch.randn_like(logvar)                    # ε ~ N(0, I),与 logvar 同形状
        std = torch.exp(logvar / 2)                       # σ = exp(log σ² / 2) = exp(log σ)
        z = eps * std + mean                              # z = μ + σ · ε,重参数化后的隐变量

        # ---- 解码阶段 ----
        x = self.decoder_projection(z)                    # [B, 128] → [B, 1024]
        x = torch.reshape(x, (-1, *self.decoder_input_chw))  # [B, 1024] → [B, 256, 2, 2]
        decoded = self.decoder(x)                         # [B, 256, 2, 2] → [B, 3, 64, 64]

        return decoded, mean, logvar

这样随机性被隔离在 ε \varepsilon ε 中, μ \mu μ 和 log ⁡ σ 2 \log \sigma^2 logσ2 都是确定性的,梯度可以正常回传。


注意,实现中学的是 μ \mu μ 和 l o g σ 2 log\sigma^2 logσ2

为什么网络输出 log(σ²),而不是直接输出 σ²?

答案其实涉及 数值稳定性 + 优化难度 + KL公式简化 三个方面。


(1)首先,σ² 必须大于 0

VAE 假设:

q ( z ∣ x ) = N ( μ , σ 2 ) q(z|x)=N(\mu,\sigma^2) q(z∣x)=N(μ,σ2)

这里:

  • μ 可以任意
  • σ² 必须 > 0

因为方差不可能为负数。


如果直接预测:

python 复制代码
sigma2 = Linear(...)

网络可能输出:

text 复制代码
-3.2
-0.8
-100

这些都是非法值。


所以必须加约束:

例如:

python 复制代码
sigma2 = softplus(raw)

或者:

python 复制代码
sigma2 = exp(raw)

(2)用 log(σ²) 可以天然保证正数

令:

l = log ⁡ ( σ 2 ) l = \log(\sigma^2) l=log(σ2)

网络预测: l l l

而不是:

σ 2 \sigma^2 σ2


恢复方差时:

σ 2 = e l \sigma^2 = e^l σ2=el

无论:

text 复制代码
l = -100
l = 0
l = 100

都有:

e l > 0 e^l > 0 el>0

因此:

不需要额外约束,天然合法。


(3)KL Loss 会变得特别漂亮

VAE的 KL 项:

K L ( q ( z ∣ x ) ∥ p ( z ) ) KL(q(z|x)\parallel p(z)) KL(q(z∣x)∥p(z))

其中:

q ( z ∣ x ) = N ( μ , σ 2 ) q(z|x)=N(\mu,\sigma^2) q(z∣x)=N(μ,σ2)

p ( z ) = N ( 0 , 1 ) p(z)=N(0,1) p(z)=N(0,1)

最终可以推导成(过程这里省略,严格推导出来的):

K L = 1 2 ∑ ( μ 2 + σ 2 − log ⁡ ( σ 2 ) − 1 ) KL = \frac{1}{2} \sum \left( \mu^2 + \sigma^2 - \log(\sigma^2) - 1 \right) KL=21∑(μ2+σ2−log(σ2)−1)


这里直接出现:

log ⁡ ( σ 2 ) \log(\sigma^2) log(σ2)


所以如果网络输出:

python 复制代码
logvar

KL直接写:

python 复制代码
kl = -0.5 * torch.sum(
    1 + logvar - mu.pow(2) - logvar.exp()
)

非常方便。


如果输出的是:

python 复制代码
sigma2

那每次都要:

python 复制代码
torch.log(sigma2)

更麻烦。


(4)数值稳定性更好

这是最重要的原因。


假设真实方差:

σ 2 = 0.000001 \sigma^2=0.000001 σ2=0.000001

如果直接优化:

σ 2 \sigma^2 σ2

梯度会非常奇怪。


再假设:

σ 2 = 1000000 \sigma^2=1000000 σ2=1000000

梯度又会巨大。


动态范围:

text 复制代码
0.000001
~
1000000

跨度:

10 12 10^{12} 1012


优化器很难处理。


而取对数后:

log ⁡ ( σ 2 ) \log(\sigma^2) log(σ2)

变成:

text 复制代码
-13.8
~
13.8

范围小得多。

优化容易很多。


(5)从高斯分布角度看

很多统计模型里:

真正优化的其实不是:

σ \sigma σ

也不是:

σ 2 \sigma^2 σ2

而是:

log ⁡ ( σ ) \log(\sigma) log(σ)

或者:

log ⁡ ( σ 2 ) \log(\sigma^2) log(σ2)

因为:

高斯分布对数似然本身就是 log-space 的。

例如负对数似然中会出现:

log ⁡ ( σ 2 ) \log(\sigma^2) log(σ2)

因此在 log 空间优化更加自然。


(6)Reparameterization Trick 更方便

VAE采样:

z = μ + σ ϵ z=\mu+\sigma\epsilon z=μ+σϵ

其中:

ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0,1) ϵ∼N(0,1)


网络输出:

python 复制代码
logvar

之后:

python 复制代码
std = exp(0.5 * logvar)

因为:

σ = σ 2 = e log ⁡ ( σ 2 ) = e 0.5 ⋅ log ⁡ ( σ 2 ) \sigma = \sqrt{\sigma^2} = \sqrt{e^{\log(\sigma^2)}} = e^{0.5 \cdot \log(\sigma^2)} σ=σ2 =elog(σ2) =e0.5⋅log(σ2)

于是:

python 复制代码
z = mu + std * eps

直接完成。


方差必须为正,而 log(σ²) 既能保证方差合法,又能让 KL Loss 更简洁、梯度更稳定、训练更容易收敛。


5. 解码器:从隐向量重建图像

解码器先将 128 维隐向量通过全连接层投影回 1024 维,reshape 为 2×2 的特征图,再通过 5 层转置卷积(ConvTranspose2d)逐步上采样回 64×64。

python 复制代码
        # ==================== 解码器 (Decoder) ====================
        # 解码器将隐向量 z [B, 128] 映射回原始图像空间 [B, 3, 64, 64]
        modules = []

        # 首先通过全连接层将隐向量投影到与编码器输出相同的维度
        # [B, 128] → [B, 256*2*2=1024] → reshape → [B, 256, 2, 2]
        self.decoder_projection = nn.Linear(
            latent_dim, prev_channels * img_length * img_length)
        self.decoder_input_chw = (prev_channels, img_length, img_length)  # (256, 2, 2)

        # 转置卷积 (ConvTranspose2d) 实现上采样:每次将空间尺寸翻倍
        # 通道数从 256 逐步减半至 16(与编码器对称但逆序)
        # 注意:range(len(hiddens)-1, 0, -1) 即 [4, 3, 2, 1]
        # 对应 hiddens[4]→hiddens[3]: 256→128, 128→64, 64→32, 32→16
        for i in range(len(hiddens) - 1, 0, -1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hiddens[i],       # 输入通道 (深层)
                                       hiddens[i - 1],   # 输出通道 (浅层)
                                       kernel_size=3,    # 3×3 卷积核
                                       stride=2,         # 步长 2,实现 2 倍上采样
                                       padding=1,        # 填充 1
                                       output_padding=1),# 输出填充,确保尺寸精确翻倍
                    nn.BatchNorm2d(hiddens[i - 1]),
                    nn.ReLU()))

        # 最后一层解码:将 16 通道特征图转换为 3 通道 RGB 图像
        # 分两步:先用转置卷积上采样 (16→16, 32→64),再用普通卷积映射通道 (16→3)
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0],           # 输入: 16 通道, 32×32
                                   hiddens[0],           # 输出: 16 通道, 64×64
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]),
                nn.ReLU(),
                # 普通卷积将通道数从 16 映射到 3 (RGB),保持 64×64 尺寸不变
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU()))                              # ReLU 保证输出像素值非负
        self.decoder = nn.Sequential(*modules)

解码器维度变化

text 复制代码
隐向量 z: [128]
  ↓ Linear(128 → 1024)
  ↓ Reshape → [256, 2, 2]
  ↓ ConvTranspose2d(256→128, k3, s2, p1) + BN + ReLU → [128, 4, 4]
  ↓ ConvTranspose2d(128→64, k3, s2, p1) + BN + ReLU  → [64,  8,  8]
  ↓ ConvTranspose2d(64→32, k3, s2, p1) + BN + ReLU   → [32, 16, 16]
  ↓ ConvTranspose2d(32→16, k3, s2, p1) + BN + ReLU   → [16, 32, 32]
  ↓ ConvTranspose2d(16→16, k3, s2, p1) + BN + ReLU   → [16, 64, 64]
  ↓ Conv2d(16→3, k3, s1, p1) + ReLU → [3, 64, 64]

设计细节:

  • output_padding=1 :stride=2 时,output_padding=1 确保输出尺寸精确翻倍(避免 32→63 的情况)。
  • 最后一层分两步:先转置卷积上采样(16 通道 32×32 → 16 通道 64×64),再用普通卷积映射通道(16→3 RGB)。将"上采样"和"通道映射"解耦,让网络更容易优化。


6. 从隐空间采样生成新图像

python 复制代码
    def sample(self, device='cuda'):
        """
        从标准正态分布 N(0, I) 采样一个隐向量,通过解码器生成一张新的人脸图像。

        这是 VAE 作为生成模型的核心功能:无需任何输入图像,
        直接从先验分布采样即可生成新样本。

        参数:
            device: 计算设备,默认 'cuda'

        返回:
            decoded: 生成的图像 [1, 3, 64, 64]
        """
        z = torch.randn(1, self.latent_dim).to(device)    # z ~ N(0, I),batch_size=1
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)
        return decoded

一旦训练完成,生成新图像只需从 N ( 0 , I ) \mathcal{N}(0, I) N(0,I) 采样一个 128 维向量,传入解码器即可。这正是 VAE 作为生成模型的核心能力------无需任何输入,凭空"想象"出新的人脸。


7. 损失函数:重构 + 正则化

python 复制代码
def loss_fn(y, y_hat, mean, logvar):
    """
    VAE 的损失函数 = 重构损失 + KL 散度正则项

    VAE 的优化目标(ELBO):
        L = E[log p(x|z)] - D_KL(q(z|x) || p(z))

    即最大化证据下界等价于:
        最小化 重构误差 + KL 散度

    参数:
        y:      原始输入图像 [B, 3, 64, 64]
        y_hat:  重建图像 [B, 3, 64, 64]
        mean:   编码器输出的均值 μ [B, latent_dim]
        logvar: 编码器输出的对数方差 log(σ²) [B, latent_dim]

    返回:
        loss: 总损失(标量)
    """
    # ---- 重构损失 (Reconstruction Loss) ----
    # 使用 MSE 衡量原始图像与重建图像的像素级差异
    # 等价于假设 p(x|z) 为高斯分布时的负对数似然
    recons_loss = F.mse_loss(y_hat, y)

    # ---- KL 散度 (Kullback-Leibler Divergence) ----
    # D_KL( N(μ, σ²) || N(0, I) ) 的解析形式:
    #   = -0.5 * Σ( 1 + log(σ²) - μ² - σ² )
    #
    # 直观理解:KL 散度约束编码器输出的分布接近标准正态分布
    # - 当 μ=0, σ²=1 时 KL=0(完全匹配先验)
    # - 当 μ 偏离 0 或 σ² 偏离 1 时 KL>0(惩罚偏离先验)
    kl_loss = torch.mean(
        -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)

    # ---- 总损失 ----
    # kl_weight < 1 意味着放松对 KL 散度的约束
    # 允许编码器学习更灵活的分布,换取更好的重建质量
    loss = recons_loss + kl_loss * kl_weight
    return loss

KL 散度的解析形式

当先验 p ( z ) = N ( 0 , I ) p(z) = \mathcal{N}(0, I) p(z)=N(0,I) 且后验 q ( z ∣ x ) = N ( μ , σ 2 ) q(z|x) = \mathcal{N}(\mu, \sigma^2) q(z∣x)=N(μ,σ2) 时,KL 散度有闭式解:

D K L ( N ( μ , σ 2 ) ∥ N ( 0 , I ) ) = − 1 2 ∑ j = 1 d ( 1 + log ⁡ σ j 2 − μ j 2 − σ j 2 ) D_{KL}(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0, I)) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right) DKL(N(μ,σ2)∥N(0,I))=−21j=1∑d(1+logσj2−μj2−σj2)

代码中 kl_weight = 0.00025 是一个关键超参数。标准 VAE 的 KL 散度系数为 1,但在这里被设为远小于 1,主要目的有两个:

  1. 防止 posterior collapse(后验坍缩) :当 KL 项过强时,编码器可能直接输出先验 N ( 0 , I ) \mathcal{N}(0,I) N(0,I) 而忽略输入图像,导致解码器学会无视隐变量 z z z,VAE 退化为普通自编码器
  2. 平衡损失量级:对于 64×64×3 的高维图像,MSE 重构损失的数值量级远大于 128 维潜变量的 KL 散度,减小 KL 权重有助于两项损失的平衡优化

8. 训练流程

python 复制代码
"""
VAE 训练与推理脚本

功能:
    1. train():   在 CelebA 数据集上训练 VAE 模型
    2. reconstruct(): 输入真实图像,经过编码-解码后重建,保存对比图
    3. generate(): 从 N(0,I) 随机采样,生成全新的人脸图像并保存

使用方式:
    1. 下载 CelebA Align&Cropped Images 数据集
    2. 修改 load_celebA.py 中 get_dataloader 的 root 路径指向你的数据目录
    3. 运行 python main.py(在 main() 中选择要执行的功能)
"""

# ==================== 超参数 ====================
n_epochs = 10            # 训练轮数(CelebA 约 20 万张图,10 个 epoch 可得到基本可用的结果)
kl_weight = 0.00025      # KL 散度的权重系数
                         # 标准 VAE 中该系数为 1,这里设为远小于 1 是为了防止
                         # posterior collapse(后验坍缩)------即 KL 项过强时,编码器
                         # 直接输出先验 N(0,I),解码器学会忽略隐变量 z,导致 VAE
                         # 退化为普通自编码器。同时,对于高维图像数据,MSE 重构
                         # 损失的数值量级远大于 KL 散度,减小 KL 权重有助于平衡两项
lr = 0.005               # Adam 优化器的学习率


def train(device, dataloader, model):
    """
    训练 VAE 模型

    训练流程:
        1. 从 dataloader 获取一个 batch 的图像
        2. 前向传播:图像 → 编码 → 重参数化 → 解码 → 重建图像
        3. 计算损失:MSE(重建, 原图) + kl_weight * KL( N(μ,σ²) || N(0,I) )
        4. 反向传播、更新参数
        5. 每 epoch 结束后打印损失并保存模型权重
    """
    optimizer = torch.optim.Adam(model.parameters(), lr)
    dataset_len = len(dataloader.dataset)                 # 数据集总样本数

    begin_time = time()
    for i in range(n_epochs):
        loss_sum = 0                                      # 累计损失,用于计算平均
        for x in dataloader:
            x = x.to(device)                              # 将图像移到 GPU
            y_hat, mean, logvar = model(x)                # 前向传播:得到重建图、μ、log(σ²)
            loss = loss_fn(x, y_hat, mean, logvar)        # 计算损失
            optimizer.zero_grad()                         # 清空梯度缓存
            loss.backward()                               # 反向传播计算梯度
            optimizer.step()                              # 更新参数
            loss_sum += loss
        loss_sum /= dataset_len                           # 计算每个样本的平均损失
        training_time = time() - begin_time
        minute = int(training_time // 60)
        second = int(training_time % 60)
        print(f'epoch {i}: loss {loss_sum} {minute}:{second}')
        torch.save(model.state_dict(), 'model.pth')       # 每轮保存一次模型

训练配置精简但有效:10 个 epoch、Adam 优化器、batch size 16。得益于 VAE 的稳定训练特性(相比 GAN 需要小心的 minmax 博弈),无需复杂的调参即可收敛。


9. 实验结果:重建与生成

训练完成后,代码提供了两种评估方式:

重建 (Reconstruct)

python 复制代码
def reconstruct(device, dataloader, model):
    """
    图像重建演示:取一张真实图像,经过 VAE 编码再解码,观察重建质量。

    模型需要先训练好,加载 model.pth 权重后使用。
    结果保存为 reconstruct.jpg(左半:重建,右半:原图)。
    """
    model.eval()                                          # 切换到评估模式(关闭 BN 的运行时统计)
    batch = next(iter(dataloader))                        # 取一个 batch
    x = batch[0:1, ...].to(device)                        # 只取第一张图 [1, 3, 64, 64]
    output = model(x)[0]                                  # 重建图像(丢弃返回的 mean, logvar)
    output = output[0].detach().cpu()                     # 去除 batch 维度,移回 CPU
    input = batch[0].detach().cpu()                       # 原始图像
    combined = torch.cat((output, input), 1)              # 水平拼接:左重建 | 右原图
    img = ToPILImage()(combined)                          # Tensor → PIL Image
    img.save('reconstruct.jpg')

生成 (Generate)

python 复制代码
def generate(device, model):
    """
    随机生成新图像:从标准正态分布 N(0,I) 采样隐向量,通过解码器生成人脸。

    这是 VAE 区别于普通自编码器的关键能力------无需输入图像即可生成新样本。
    结果保存为 generate.jpg。
    """
    model.eval()                                          # 切换到评估模式
    output = model.sample(device)                         # 从 N(0,I) 采样 → 解码 → 生成图像
    output = output[0].detach().cpu()                     # 去除 batch 维度,移回 CPU
    img = ToPILImage()(output)                            # Tensor → PIL Image
    img.save('generate.jpg')

10. 数据加载:CelebA 数据集

CelebA (CelebFaces Attributes Dataset) 包含超过 20 万张名人面部图像。本项目使用 Align&Cropped 版本(已人脸对齐并裁剪),原始尺寸 178×218。

python 复制代码
"""
CelebA 数据集加载模块

CelebA (CelebFaces Attributes Dataset) 是一个大规模人脸属性数据集,
包含超过 20 万张名人面部图像,每张标注了 40 种属性。

本模块使用 Align&Cropped 版本,即已经过人脸对齐和裁剪的图像,
原始尺寸为 178×218,通过预处理转换为 64×64 的 RGB 图像供 VAE 训练使用。
"""

class CelebADataset(Dataset):
    """
    CelebA 自定义数据集类

    继承 torch.utils.data.Dataset,实现 __len__ 和 __getitem__ 方法,
    使其可以被 PyTorch 的 DataLoader 加载。

    预处理流程:
        1. 读取原始图像 (178×218)
        2. CenterCrop(168): 中心裁剪为 168×168 的正方形,去除不均匀的背景
        3. Resize(64×64):   缩放到目标尺寸
        4. ToTensor():      将 PIL Image 转换为 [0, 1] 范围的 torch.Tensor
    """

    def __init__(self, root, img_shape=(64, 64)) -> None:
        """
        参数:
            root: 数据集根目录,包含所有 .jpg 图像文件
            img_shape: 目标图像尺寸 (高, 宽),默认 (64, 64)
        """
        super().__init__()
        self.root = root
        self.img_shape = img_shape
        self.filenames = sorted(os.listdir(root))         # 按文件名排序,保证每次加载顺序一致

    def __len__(self) -> int:
        """返回数据集样本总数"""
        return len(self.filenames)

    def __getitem__(self, index: int):
        """
        根据索引返回预处理后的图像张量

        参数:
            index: 样本索引 (0 ~ len(dataset)-1)

        返回:
            torch.Tensor: 形状为 [3, H, W] 的 RGB 图像,值域 [0, 1]
        """
        path = os.path.join(self.root, self.filenames[index])
        img = Image.open(path).convert('RGB')             # 打开图像并确保为 RGB 三通道

        # 预处理 pipeline
        # CenterCrop(168): CelebA 原始尺寸 178×218,中心裁出 168×168 正方形
        #                 人脸位于图像中央,裁剪后去除大部分背景区域
        # Resize(64,64):  将 168×168 缩放至 64×64,大幅降低计算量
        # ToTensor():     PIL Image [0, 255] → torch.Tensor [0.0, 1.0]
        #                 同时自动将通道维度从 HWC 转为 CHW
        pipeline = transforms.Compose([
            transforms.CenterCrop(168),
            transforms.Resize(self.img_shape),
            transforms.ToTensor()
        ])
        return pipeline(img)


def get_dataloader(root='/data/ym/datasets-face/img_align_celeba', **kwargs):
    """
    创建 CelebA 数据集的 DataLoader

    参数:
        root: 数据集路径,请根据实际存放位置修改
        **kwargs: 传递给 CelebADataset 的额外参数(如 img_shape)

    返回:
        DataLoader: batch_size=16, shuffle=True
                   每次迭代返回形状为 [16, 3, 64, 64] 的图像 batch
    """
    dataset = CelebADataset(root, **kwargs)
    return DataLoader(dataset, 16, shuffle=True)          # batch_size=16,每个 epoch 随机打乱

预处理流程:CenterCrop(168)Resize(64,64)ToTensor()。CelebA 原始对齐图像为 178×218,中心裁剪 168×168 后缩放到 64×64,去除了大部分背景,保留了人脸核心区域。


11. 完整的前向传播流程

将整个 pipeline 串联起来,一个 batch 的完整前向传播如下:

text 复制代码
输入 x: [B, 3, 64, 64]
    │
    ▼
编码器 (5层 stride-2 卷积) → 特征图 [B, 256, 2, 2]
    │
    ▼
Flatten → [B, 1024]
    │
    ├── mean_linear  → μ  [B, 128]
    └── var_linear   → log σ² [B, 128]
    │
    ▼
重参数化: z = μ + exp(log σ²/2) · ε    (ε ~ N(0,I))
    │
    ▼
Decoder Projection: Linear(128 → 1024) → Reshape → [B, 256, 2, 2]
    │
    ▼
解码器 (5层转置卷积) → 重建图像 [B, 3, 64, 64]
    │
    ▼
损失: MSE(x, x̂) + kl_weight · KL( N(μ,σ²) || N(0,I) )

12. 关键超参数与调优建议

超参数 取值 影响
latent_dim 128 越大表达力越强,但训练越慢,过大可能过拟合
kl_weight 0.00025 最关键的超参数。太大会导致 posterior collapse(解码器忽略 z),太小则隐空间不够正则化
hiddens 16,32,64,128,256 控制模型容量,可根据显存调整
lr 0.005 Adam 下略高的学习率,加速收敛
batch_size 16 小 batch 带来的噪声有助于隐空间泛化

13. 扩展方向

基于这个基础实现,可以进一步探索:

  1. 隐空间插值:在两个人脸的隐向量之间线性插值,观察生成的渐变效果
  2. 条件 VAE (CVAE):加入属性标签(性别、发色等),控制生成特定特征的人脸
  3. VQ-VAE:将连续隐空间替换为离散的 codebook,通常能获得更清晰的生成结果
  4. 与 GAN 对比:在同一数据集上训练 DCGAN,对比两种生成范式的差异

14. 完整代码清单

所有代码可在 dldemos/VAE/ 目录下找到:

  • model.py --- VAE 模型定义(含详尽注释)
  • main.py --- 训练与推理(含详尽注释)
  • load_celebA.py --- 数据加载(含详尽注释)

运行方式:

bash 复制代码
# 1. 下载 CelebA Align&Cropped Images
# 2. 修改 main.py 中 get_dataloader 的数据路径
# 3. 运行
python main.py

15. 完整代码

load_celebA.py --- 数据加载

py 复制代码
"""
CelebA 数据集加载模块

CelebA (CelebFaces Attributes Dataset) 是一个大规模人脸属性数据集,
包含超过 20 万张名人面部图像,每张标注了 40 种属性。

本模块使用 Align&Cropped 版本,即已经过人脸对齐和裁剪的图像,
原始尺寸为 178×218,通过预处理转换为 64×64 的 RGB 图像供 VAE 训练使用。
"""

import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CelebADataset(Dataset):
    """
    CelebA 自定义数据集类

    继承 torch.utils.data.Dataset,实现 __len__ 和 __getitem__ 方法,
    使其可以被 PyTorch 的 DataLoader 加载。

    预处理流程:
        1. 读取原始图像 (178×218)
        2. CenterCrop(168): 中心裁剪为 168×168 的正方形,去除不均匀的背景
        3. Resize(64×64):   缩放到目标尺寸
        4. ToTensor():      将 PIL Image 转换为 [0, 1] 范围的 torch.Tensor
    """

    def __init__(self, root, img_shape=(64, 64)) -> None:
        """
        参数:
            root: 数据集根目录,包含所有 .jpg 图像文件
            img_shape: 目标图像尺寸 (高, 宽),默认 (64, 64)
        """
        super().__init__()
        self.root = root
        self.img_shape = img_shape
        self.filenames = sorted(os.listdir(root))         # 按文件名排序,保证每次加载顺序一致

    def __len__(self) -> int:
        """返回数据集样本总数"""
        return len(self.filenames)

    def __getitem__(self, index: int):
        """
        根据索引返回预处理后的图像张量

        参数:
            index: 样本索引 (0 ~ len(dataset)-1)

        返回:
            torch.Tensor: 形状为 [3, H, W] 的 RGB 图像,值域 [0, 1]
        """
        path = os.path.join(self.root, self.filenames[index])
        img = Image.open(path).convert('RGB')             # 打开图像并确保为 RGB 三通道

        # 预处理 pipeline
        # CenterCrop(168): CelebA 原始尺寸 178×218,中心裁出 168×168 正方形
        #                 人脸位于图像中央,裁剪后去除大部分背景区域
        # Resize(64,64):  将 168×168 缩放至 64×64,大幅降低计算量
        # ToTensor():     PIL Image [0, 255] → torch.Tensor [0.0, 1.0]
        #                 同时自动将通道维度从 HWC 转为 CHW
        pipeline = transforms.Compose([
            transforms.CenterCrop(168),
            transforms.Resize(self.img_shape),
            transforms.ToTensor()
        ])
        return pipeline(img)


def get_dataloader(root='/data/bryant/datasets-face/img_align_celeba', **kwargs):
    """
    创建 CelebA 数据集的 DataLoader

    参数:
        root: 数据集路径,请根据实际存放位置修改
        **kwargs: 传递给 CelebADataset 的额外参数(如 img_shape)

    返回:
        DataLoader: batch_size=16, shuffle=True
                   每次迭代返回形状为 [16, 3, 64, 64] 的图像 batch
    """
    dataset = CelebADataset(root, **kwargs)
    return DataLoader(dataset, 16, shuffle=True)          # batch_size=16,每个 epoch 随机打乱


if __name__ == '__main__':
    # 测试数据加载:取一个 batch 的图像,拼接成 4×4 的网格图保存
    # 用于检查数据预处理效果和图像质量
    dataloader = get_dataloader()
    img = next(iter(dataloader))                          # 获取一个 batch: [16, 3, 64, 64]
    print(img.shape)                                      # 打印形状以确认: torch.Size([16, 3, 64, 64])

    # 将 16 张图拼接成 4×4 的网格
    N, C, H, W = img.shape
    assert N == 16                                        # 确保 batch_size=16
    img = torch.permute(img, (1, 0, 2, 3))               # [3, 16, 64, 64]
    img = torch.reshape(img, (C, 4, 4 * H, W))           # [3, 4, 256, 64] --- 4行,每行4张图拼接
    img = torch.permute(img, (0, 2, 1, 3))               # [3, 256, 4, 64]
    img = torch.reshape(img, (C, 4 * H, 4 * W))           # [3, 256, 256] --- 最终 4×4 网格
    img = transforms.ToPILImage()(img)                    # Tensor → PIL Image
    img.save('tmp.jpg')                                   # 保存预览图
    print('Preview grid saved as tmp.jpg')

py 复制代码
import torch
import torch.nn as nn


class VAE(nn.Module):
    """VAE for 64x64 face generation.

    VAE (Variational Autoencoder) 变分自编码器,用于 64x64 人脸图像的生成。
    与普通自编码器不同,VAE 的编码器输出隐变量的概率分布参数 (μ, σ²),
    而非一个确定的隐向量,从而使得模型具备生成新样本的能力。

    The hidden dimensions can be tuned.
    中间的隐藏层通道数可以根据显存和效果需求进行调整。
    """

    def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
        """
        参数:
            hiddens: 编码器每层卷积的输出通道数,同时也是解码器转置卷积的输入通道数(逆序使用)
                     默认 [16, 32, 64, 128, 256],共 5 层,每次下采样 2 倍
            latent_dim: 隐空间维度,即潜变量 z 的维度,默认 128
        """
        super().__init__()

        # ==================== 编码器 (Encoder) ====================
        # 编码器将输入图像 [B, 3, 64, 64] 逐步压缩为特征图 [B, 256, 2, 2]
        # 然后通过两个全连接层分别输出隐变量的均值和对数方差
        prev_channels = 3          # 输入为 3 通道 RGB 图像
        modules = []               # 存储各层卷积模块
        img_length = 64            # 输入图像的空间尺寸 (高/宽)
        for cur_channels in hiddens:
            # 每层:Conv2d(stride=2) 将空间尺寸减半,通道数翻倍
            # kernel_size=3, stride=2, padding=1 实现 exactly 减半的效果
            # 以 64 → 32 为例: output = (64 - 3 + 2*1) / 2 + 1 = 32 ✓
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels,      # 输入通道数
                              cur_channels,       # 输出通道数
                              kernel_size=3,      # 3×3 卷积核
                              stride=2,           # 步长为 2,替代池化实现下采样
                              padding=1),         # 填充 1,保持边界信息
                    nn.BatchNorm2d(cur_channels), # 批归一化,加速收敛、稳定训练
                    nn.ReLU()))                   # ReLU 激活,引入非线性
            prev_channels = cur_channels
            img_length //= 2                      # 每层空间尺寸减半: 64→32→16→8→4→2
        self.encoder = nn.Sequential(*modules)

        # 编码器最终输出特征图尺寸: [B, 256, 2, 2],展平后为 [B, 1024]
        # 两个独立的全连接头:一个预测均值 μ,一个预测对数方差 log(σ²)
        # 使用 log(σ²) 而非 σ² 的原因:
        #   1. 数值稳定性更好,避免方差趋近于 0
        #   2. 无需约束输出为正(exp 后自然为正)
        self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
                                     latent_dim)  # 1024 → 128,输出均值 μ
        self.var_linear = nn.Linear(prev_channels * img_length * img_length,
                                    latent_dim)   # 1024 → 128,输出对数方差 log(σ²)
        self.latent_dim = latent_dim

        # ==================== 解码器 (Decoder) ====================
        # 解码器将隐向量 z [B, 128] 映射回原始图像空间 [B, 3, 64, 64]
        modules = []

        # 首先通过全连接层将隐向量投影到与编码器输出相同的维度
        # [B, 128] → [B, 256*2*2=1024] → reshape → [B, 256, 2, 2]
        self.decoder_projection = nn.Linear(
            latent_dim, prev_channels * img_length * img_length)
        self.decoder_input_chw = (prev_channels, img_length, img_length)  # (256, 2, 2)

        # 转置卷积 (ConvTranspose2d) 实现上采样:每次将空间尺寸翻倍
        # 通道数从 256 逐步减半至 16(与编码器对称但逆序)
        # 注意:range(len(hiddens)-1, 0, -1) 即 [4, 3, 2, 1]
        # 对应 hiddens[4]→hiddens[3]: 256→128, 128→64, 64→32, 32→16
        for i in range(len(hiddens) - 1, 0, -1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hiddens[i],       # 输入通道 (深层)
                                       hiddens[i - 1],   # 输出通道 (浅层)
                                       kernel_size=3,    # 3×3 卷积核
                                       stride=2,         # 步长 2,实现 2 倍上采样
                                       padding=1,        # 填充 1
                                       output_padding=1),# 输出填充,确保尺寸精确翻倍
                    nn.BatchNorm2d(hiddens[i - 1]),
                    nn.ReLU()))

        # 最后一层解码:将 16 通道特征图转换为 3 通道 RGB 图像
        # 分两步:先用转置卷积上采样 (16→16, 32→64),再用普通卷积映射通道 (16→3)
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0],           # 输入: 16 通道, 32×32
                                   hiddens[0],           # 输出: 16 通道, 64×64
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]),
                nn.ReLU(),
                # 普通卷积将通道数从 16 映射到 3 (RGB),保持 64×64 尺寸不变
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU()))                              # ReLU 保证输出像素值非负
        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """
        前向传播:输入图像 → 编码 → 重参数化采样 → 解码 → 重建图像

        参数:
            x: 输入图像 [B, 3, 64, 64]

        返回:
            decoded: 重建图像 [B, 3, 64, 64]
            mean: 隐变量分布的均值 μ [B, latent_dim]
            logvar: 隐变量分布的对数方差 log(σ²) [B, latent_dim]
        """
        # ---- 编码阶段 ----
        encoded = self.encoder(x)                         # [B, 3, 64, 64] → [B, 256, 2, 2]
        encoded = torch.flatten(encoded, 1)               # 展平: [B, 256, 2, 2] → [B, 1024]
        mean = self.mean_linear(encoded)                  # 均值 μ: [B, 128]
        logvar = self.var_linear(encoded)                 # 对数方差 log(σ²): [B, 128]

        # ---- 重参数化技巧 (Reparameterization Trick) ----
        # 目标:从 N(μ, σ²) 中采样 z,同时保持梯度可回传
        # 做法:先从 N(0,1) 采样 ε,再计算 z = μ + σ · ε
        # 这样随机性被隔离在 ε 中,μ 和 σ 仍可接收梯度
        eps = torch.randn_like(logvar)                    # ε ~ N(0, I),与 logvar 同形状
        std = torch.exp(logvar / 2)                       # σ = exp(log σ² / 2) = exp(log σ)
        z = eps * std + mean                              # z = μ + σ · ε,重参数化后的隐变量

        # ---- 解码阶段 ----
        x = self.decoder_projection(z)                    # [B, 128] → [B, 1024]
        x = torch.reshape(x, (-1, *self.decoder_input_chw))  # [B, 1024] → [B, 256, 2, 2]
        decoded = self.decoder(x)                         # [B, 256, 2, 2] → [B, 3, 64, 64]

        return decoded, mean, logvar

    def sample(self, device='cuda'):
        """
        从标准正态分布 N(0, I) 采样一个隐向量,通过解码器生成一张新的人脸图像。

        这是 VAE 作为生成模型的核心功能:无需任何输入图像,
        直接从先验分布采样即可生成新样本。

        参数:
            device: 计算设备,默认 'cuda'

        返回:
            decoded: 生成的图像 [1, 3, 64, 64]
        """
        z = torch.randn(1, self.latent_dim).to(device)    # z ~ N(0, I),batch_size=1
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)
        return decoded

py 复制代码
"""
VAE 训练与推理脚本

功能:
    1. train():   在 CelebA 数据集上训练 VAE 模型
    2. reconstruct(): 输入真实图像,经过编码-解码后重建,保存对比图
    3. generate(): 从 N(0,I) 随机采样,生成全新的人脸图像并保存

使用方式:
    1. 下载 CelebA Align&Cropped Images 数据集
    2. 修改 load_celebA.py 中 get_dataloader 的 root 路径指向你的数据目录
    3. 运行 python main.py(在 main() 中选择要执行的功能)
"""

from time import time

import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage

# 从本地模块导入数据加载器和模型
# 注:如果作为包运行,可改为 from dldemos.VAE.load_celebA import get_dataloader
from load_celebA import get_dataloader
from model import VAE

# ==================== 超参数 ====================
n_epochs = 10            # 训练轮数(CelebA 约 20 万张图,10 个 epoch 可得到基本可用的结果)
kl_weight = 0.00025      # KL 散度的权重(β-VAE 中的 β)
                         # 设为远小于 1 的值,让模型更注重重建质量
                         # 若设为 1 则等价于标准 VAE,但重建效果通常较差
lr = 0.005               # Adam 优化器的学习率


def loss_fn(y, y_hat, mean, logvar):
    """
    VAE 的损失函数 = 重构损失 + KL 散度正则项

    VAE 的优化目标(ELBO):
        L = E[log p(x|z)] - D_KL(q(z|x) || p(z))

    即最大化证据下界等价于:
        最小化 重构误差 + KL 散度

    参数:
        y:      原始输入图像 [B, 3, 64, 64]
        y_hat:  重建图像 [B, 3, 64, 64]
        mean:   编码器输出的均值 μ [B, latent_dim]
        logvar: 编码器输出的对数方差 log(σ²) [B, latent_dim]

    返回:
        loss: 总损失(标量)
    """
    # ---- 重构损失 (Reconstruction Loss) ----
    # 使用 MSE 衡量原始图像与重建图像的像素级差异
    # 等价于假设 p(x|z) 为高斯分布时的负对数似然
    recons_loss = F.mse_loss(y_hat, y)

    # ---- KL 散度 (Kullback-Leibler Divergence) ----
    # D_KL( N(μ, σ²) || N(0, I) ) 的解析形式:
    #   = -0.5 * Σ( 1 + log(σ²) - μ² - σ² )
    #
    # 直观理解:KL 散度约束编码器输出的分布接近标准正态分布
    # - 当 μ=0, σ²=1 时 KL=0(完全匹配先验)
    # - 当 μ 偏离 0 或 σ² 偏离 1 时 KL>0(惩罚偏离先验)
    kl_loss = torch.mean(
        -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)

    # ---- 总损失(β-VAE 形式)----
    # kl_weight < 1 意味着放松对 KL 散度的约束
    # 允许编码器学习更灵活的分布,换取更好的重建质量
    loss = recons_loss + kl_loss * kl_weight
    return loss


def train(device, dataloader, model):
    """
    训练 VAE 模型

    训练流程:
        1. 从 dataloader 获取一个 batch 的图像
        2. 前向传播:图像 → 编码 → 重参数化 → 解码 → 重建图像
        3. 计算损失:MSE(重建, 原图) + β * KL( N(μ,σ²) || N(0,I) )
        4. 反向传播、更新参数
        5. 每 epoch 结束后打印损失并保存模型权重
    """
    optimizer = torch.optim.Adam(model.parameters(), lr)
    dataset_len = len(dataloader.dataset)                 # 数据集总样本数

    begin_time = time()
    # train
    for i in range(n_epochs):
        loss_sum = 0                                      # 累计损失,用于计算平均
        for x in dataloader:
            x = x.to(device)                              # 将图像移到 GPU
            y_hat, mean, logvar = model(x)                # 前向传播:得到重建图、μ、log(σ²)
            loss = loss_fn(x, y_hat, mean, logvar)        # 计算损失
            optimizer.zero_grad()                         # 清空梯度缓存
            loss.backward()                               # 反向传播计算梯度
            optimizer.step()                              # 更新参数
            loss_sum += loss
        loss_sum /= dataset_len                           # 计算每个样本的平均损失
        training_time = time() - begin_time
        minute = int(training_time // 60)
        second = int(training_time % 60)
        print(f'epoch {i}: loss {loss_sum} {minute}:{second}')
        torch.save(model.state_dict(), 'model.pth')       # 每轮保存一次模型


def reconstruct(device, dataloader, model):
    """
    图像重建演示:取一张真实图像,经过 VAE 编码再解码,观察重建质量。

    模型需要先训练好,加载 model.pth 权重后使用。
    结果保存为 reconstruct.jpg(左半:重建,右半:原图)。
    """
    model.eval()                                          # 切换到评估模式(关闭 BN 的运行时统计)
    batch = next(iter(dataloader))                        # 取一个 batch
    x = batch[0:1, ...].to(device)                        # 只取第一张图 [1, 3, 64, 64]
    output = model(x)[0]                                  # 重建图像(丢弃返回的 mean, logvar)
    output = output[0].detach().cpu()                     # 去除 batch 维度,移回 CPU
    input = batch[0].detach().cpu()                       # 原始图像
    combined = torch.cat((output, input), 1)              # 水平拼接:左重建 | 右原图
    img = ToPILImage()(combined)                          # Tensor → PIL Image
    img.save('reconstruct.jpg')


def generate(device, model):
    """
    随机生成新图像:从标准正态分布 N(0,I) 采样隐向量,通过解码器生成人脸。

    这是 VAE 区别于普通自编码器的关键能力------无需输入图像即可生成新样本。
    结果保存为 generate.jpg。
    """
    model.eval()                                          # 切换到评估模式
    output = model.sample(device)                         # 从 N(0,I) 采样 → 解码 → 生成图像
    output = output[0].detach().cpu()                     # 去除 batch 维度,移回 CPU
    img = ToPILImage()(output)                            # Tensor → PIL Image
    img.save('generate.jpg')


def main():
    """
    主函数:根据需求选择执行 train / reconstruct / generate。

    使用方法:
        1. 首次运行:取消 train() 的注释,训练模型
        2. 训练完成后:取消 reconstruct() 和 generate() 的注释,查看效果
    """
    device = 'cuda:2'                                     # 指定使用的 GPU(可根据实际环境修改)
    dataloader = get_dataloader()                         # 加载 CelebA 数据集

    model = VAE().to(device)                              # 实例化模型并移至 GPU

    # If you obtain the ckpt, load it
    model.load_state_dict(torch.load('model.pth', 'cuda:0'))  # 加载预训练权重

    # Choose the function
    # train(device, dataloader, model)
    reconstruct(device, dataloader, model)                # 重建演示
    generate(device, model)                               # 生成演示


if __name__ == '__main__':
    main()

参考


本文是对 DL-Demos 项目中 VAE 实现的详细解读,旨在帮助初学者理解变分自编码器的原理与 PyTorch 实现细节。

相关推荐
冷小鱼2 小时前
TensorFlow 2.21 进阶实战:从训练优化到生产部署的完整指南
人工智能·pytorch·python·tensorflow
冷小鱼3 小时前
PyTorch 2.12 完全指南:从动态图到编译优化的深度学习框架演进
人工智能·pytorch·深度学习
盼小辉丶3 小时前
PyTorch强化学习实战(14)——优先经验回放机制
pytorch·python·深度学习·强化学习
装不满的克莱因瓶3 小时前
【工业领域】了解目标检测评估指标——从mAP到IoU的完整评价体系解析
人工智能·pytorch·python·深度学习·目标检测·计算机视觉·目标跟踪
闵孚龙12 小时前
动态图机制:为什么 PyTorch 调试起来更舒服
人工智能·pytorch·python
Kobebryant-Manba13 小时前
RNN从0实现
pytorch·rnn·深度学习
闵孚龙17 小时前
PyTorch 系列 之 nn.Module:所有模型的骨架
人工智能·pytorch·python
去伪存真21 小时前
如何将没有字幕的英文视频转换成中文视频?
前端·pytorch·llm
装不满的克莱因瓶1 天前
了解3D卷积原理——从空间感知到时空建模的深度学习核心算子
人工智能·pytorch·python·深度学习·机器学习·3d·ai