
文章目录
- [1. 什么是 VAE?](#1. 什么是 VAE?)
- [2. 整体架构概览](#2. 整体架构概览)
- [3. 编码器:从图像到概率分布](#3. 编码器:从图像到概率分布)
- [4. 重参数化技巧 (Reparameterization Trick)](#4. 重参数化技巧 (Reparameterization Trick))
- [5. 解码器:从隐向量重建图像](#5. 解码器:从隐向量重建图像)
- [6. 从隐空间采样生成新图像](#6. 从隐空间采样生成新图像)
- [7. 损失函数:重构 + 正则化](#7. 损失函数:重构 + 正则化)
- [8. 训练流程](#8. 训练流程)
- [9. 实验结果:重建与生成](#9. 实验结果:重建与生成)
- [10. 数据加载:CelebA 数据集](#10. 数据加载:CelebA 数据集)
- [11. 完整的前向传播流程](#11. 完整的前向传播流程)
- [12. 关键超参数与调优建议](#12. 关键超参数与调优建议)
- [13. 扩展方向](#13. 扩展方向)
- [14. 完整代码清单](#14. 完整代码清单)
- [15. 完整代码](#15. 完整代码)
- 参考
从零实现 VAE:用变分自编码器生成人脸图像
本文基于 PyTorch 实现一个完整的 Variational Autoencoder (VAE),在 CelebA 数据集上训练,实现人脸重建与生成。代码已添加详尽注释,适合作为理解生成模型的入门实践。
1. 什么是 VAE?
Variational Autoencoder (VAE) 是由 Kingma 和 Welling 在 2014 年提出的深度生成模型。与普通自编码器不同,VAE 不直接学习一个确定的隐向量,而是学习隐空间的概率分布。这使得 VAE 能够:
- 生成新样本:从学到的分布中采样,通过解码器生成新数据
- 平滑的隐空间:相似的隐向量对应相似的输出,支持插值和语义操作
VAE 的核心思想可以用 ELBO(证据下界)概括:
L = E z ∼ q ϕ ( z ∣ x ) log p θ ( x ∣ z ) ⏟ 重构损失 − D K L ( q ϕ ( z ∣ x ) ∥ p ( z ) ) ⏟ KL 散度 \mathcal{L} = \underbrace{\mathbb{E}{z \sim q\phi(z|x)}\\log p_\\theta(x\|z)}{\text{重构损失}} - \underbrace{D{KL}(q_\phi(z|x) \| p(z))}_{\text{KL 散度}} L=重构损失 Ez∼qϕ(z∣x)logpθ(x∣z)−KL 散度 DKL(qϕ(z∣x)∥p(z))
其中第一项鼓励解码器准确重建输入,第二项 (KL 散度) 鼓励编码器输出的分布逼近标准正态分布 N ( 0 , I ) \mathcal{N}(0, I) N(0,I)。
2. 整体架构概览
本项目包含三个核心文件:
| 文件 | 职责 |
|---|---|
model.py |
VAE 模型定义(编码器 + 解码器 + 重参数化) |
main.py |
训练循环、损失函数、重建与生成逻辑 |
load_celebA.py |
CelebA 数据集加载与预处理 |
输入为 64×64×3 的 RGB 人脸图像,经过编码器压缩为 128 维的隐向量,再通过解码器重建回原始尺寸。
3. 编码器:从图像到概率分布
编码器由 5 层 stride=2 的卷积堆叠而成,逐步将 64×64 的输入压缩为 2×2 的特征图,最后通过两个独立的全连接头分别输出 μ \mu μ 和 log σ 2 \log\sigma^2 logσ2。
python
class VAE(nn.Module):
"""VAE for 64x64 face generation.
VAE (Variational Autoencoder) 变分自编码器,用于 64x64 人脸图像的生成。
与普通自编码器不同,VAE 的编码器输出隐变量的概率分布参数 (μ, σ²),
而非一个确定的隐向量,从而使得模型具备生成新样本的能力。
The hidden dimensions can be tuned.
中间的隐藏层通道数可以根据显存和效果需求进行调整。
"""
def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
"""
参数:
hiddens: 编码器每层卷积的输出通道数,同时也是解码器转置卷积的输入通道数(逆序使用)
默认 [16, 32, 64, 128, 256],共 5 层,每次下采样 2 倍
latent_dim: 隐空间维度,即潜变量 z 的维度,默认 128
"""
super().__init__()
# ==================== 编码器 (Encoder) ====================
# 编码器将输入图像 [B, 3, 64, 64] 逐步压缩为特征图 [B, 256, 2, 2]
# 然后通过两个全连接层分别输出隐变量的均值和对数方差
prev_channels = 3 # 输入为 3 通道 RGB 图像
modules = [] # 存储各层卷积模块
img_length = 64 # 输入图像的空间尺寸 (高/宽)
for cur_channels in hiddens:
# 每层:Conv2d(stride=2) 将空间尺寸减半,通道数翻倍
# kernel_size=3, stride=2, padding=1 实现 exactly 减半的效果
# 以 64 → 32 为例: output = (64 - 3 + 2*1) / 2 + 1 = 32 ✓
modules.append(
nn.Sequential(
nn.Conv2d(prev_channels, # 输入通道数
cur_channels, # 输出通道数
kernel_size=3, # 3×3 卷积核
stride=2, # 步长为 2,替代池化实现下采样
padding=1), # 填充 1,保持边界信息
nn.BatchNorm2d(cur_channels), # 批归一化,加速收敛、稳定训练
nn.ReLU())) # ReLU 激活,引入非线性
prev_channels = cur_channels
img_length //= 2 # 每层空间尺寸减半: 64→32→16→8→4→2
self.encoder = nn.Sequential(*modules)
# 编码器最终输出特征图尺寸: [B, 256, 2, 2],展平后为 [B, 1024]
# 两个独立的全连接头:一个预测均值 μ,一个预测对数方差 log(σ²)
# 使用 log(σ²) 而非 σ² 的原因:
# 1. 数值稳定性更好,避免方差趋近于 0
# 2. 无需约束输出为正(exp 后自然为正)
self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim) # 1024 → 128,输出均值 μ
self.var_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim) # 1024 → 128,输出对数方差 log(σ²)
self.latent_dim = latent_dim
编码器维度变化
输入: [3, 64, 64]
↓ Conv2d(3→16, k3, s2, p1) + BN + ReLU → [16, 32, 32]
↓ Conv2d(16→32, k3, s2, p1) + BN + ReLU → [32, 16, 16]
↓ Conv2d(32→64, k3, s2, p1) + BN + ReLU → [64, 8, 8]
↓ Conv2d(64→128, k3, s2, p1) + BN + ReLU → [128, 4, 4]
↓ Conv2d(128→256, k3, s2, p1) + BN + ReLU → [256, 2, 2]
↓ Flatten → 1024 维向量
↓ mean_linear: 1024 → 128 (均值 μ)
↓ var_linear: 1024 → 128 (对数方差 log σ²)
关键设计点:
- stride=2 替代池化:使用 stride=2 的卷积代替 MaxPooling,让网络自己学习下采样方式
- BatchNorm:加速收敛并稳定训练,防止某一层激活值过大或过小
- 两个独立的全连接头 :分别输出均值 μ \mu μ 和对数方差 log σ 2 \log \sigma^2 logσ2------使用 log 形式保证数值稳定,且无需约束输出为正

4. 重参数化技巧 (Reparameterization Trick)
这是 VAE 实现中最精妙的部分。我们需要从 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2) 中采样 z z z,但采样操作是不可导 的,会阻断梯度回传到 μ \mu μ 和 σ \sigma σ。
解决方案 :从标准正态分布采样噪声 ε ∼ N ( 0 , I ) \varepsilon \sim \mathcal{N}(0, I) ε∼N(0,I),然后做线性变换:
z = μ + σ ⋅ ε = μ + e log σ 2 2 ⋅ ε z = \mu + \sigma \cdot \varepsilon = \mu + e^{\frac{\log \sigma^2}{2}} \cdot \varepsilon z=μ+σ⋅ε=μ+e2logσ2⋅ε
python
def forward(self, x):
"""
前向传播:输入图像 → 编码 → 重参数化采样 → 解码 → 重建图像
参数:
x: 输入图像 [B, 3, 64, 64]
返回:
decoded: 重建图像 [B, 3, 64, 64]
mean: 隐变量分布的均值 μ [B, latent_dim]
logvar: 隐变量分布的对数方差 log(σ²) [B, latent_dim]
"""
# ---- 编码阶段 ----
encoded = self.encoder(x) # [B, 3, 64, 64] → [B, 256, 2, 2]
encoded = torch.flatten(encoded, 1) # 展平: [B, 256, 2, 2] → [B, 1024]
mean = self.mean_linear(encoded) # 均值 μ: [B, 128]
logvar = self.var_linear(encoded) # 对数方差 log(σ²): [B, 128]
# ---- 重参数化技巧 (Reparameterization Trick) ----
# 目标:从 N(μ, σ²) 中采样 z,同时保持梯度可回传
# 做法:先从 N(0,1) 采样 ε,再计算 z = μ + σ · ε
# 这样随机性被隔离在 ε 中,μ 和 σ 仍可接收梯度
eps = torch.randn_like(logvar) # ε ~ N(0, I),与 logvar 同形状
std = torch.exp(logvar / 2) # σ = exp(log σ² / 2) = exp(log σ)
z = eps * std + mean # z = μ + σ · ε,重参数化后的隐变量
# ---- 解码阶段 ----
x = self.decoder_projection(z) # [B, 128] → [B, 1024]
x = torch.reshape(x, (-1, *self.decoder_input_chw)) # [B, 1024] → [B, 256, 2, 2]
decoded = self.decoder(x) # [B, 256, 2, 2] → [B, 3, 64, 64]
return decoded, mean, logvar
这样随机性被隔离在 ε \varepsilon ε 中, μ \mu μ 和 log σ 2 \log \sigma^2 logσ2 都是确定性的,梯度可以正常回传。
注意,实现中学的是 μ \mu μ 和 l o g σ 2 log\sigma^2 logσ2
为什么网络输出 log(σ²),而不是直接输出 σ²?
答案其实涉及 数值稳定性 + 优化难度 + KL公式简化 三个方面。
(1)首先,σ² 必须大于 0
VAE 假设:
q ( z ∣ x ) = N ( μ , σ 2 ) q(z|x)=N(\mu,\sigma^2) q(z∣x)=N(μ,σ2)
这里:
- μ 可以任意
- σ² 必须 > 0
因为方差不可能为负数。
如果直接预测:
python
sigma2 = Linear(...)
网络可能输出:
text
-3.2
-0.8
-100
这些都是非法值。
所以必须加约束:
例如:
python
sigma2 = softplus(raw)
或者:
python
sigma2 = exp(raw)
(2)用 log(σ²) 可以天然保证正数
令:
l = log ( σ 2 ) l = \log(\sigma^2) l=log(σ2)
网络预测: l l l
而不是:
σ 2 \sigma^2 σ2
恢复方差时:
σ 2 = e l \sigma^2 = e^l σ2=el
无论:
text
l = -100
l = 0
l = 100
都有:
e l > 0 e^l > 0 el>0
因此:
不需要额外约束,天然合法。
(3)KL Loss 会变得特别漂亮
VAE的 KL 项:
K L ( q ( z ∣ x ) ∥ p ( z ) ) KL(q(z|x)\parallel p(z)) KL(q(z∣x)∥p(z))
其中:
q ( z ∣ x ) = N ( μ , σ 2 ) q(z|x)=N(\mu,\sigma^2) q(z∣x)=N(μ,σ2)
p ( z ) = N ( 0 , 1 ) p(z)=N(0,1) p(z)=N(0,1)
最终可以推导成(过程这里省略,严格推导出来的):
K L = 1 2 ∑ ( μ 2 + σ 2 − log ( σ 2 ) − 1 ) KL = \frac{1}{2} \sum \left( \mu^2 + \sigma^2 - \log(\sigma^2) - 1 \right) KL=21∑(μ2+σ2−log(σ2)−1)
这里直接出现:
log ( σ 2 ) \log(\sigma^2) log(σ2)
所以如果网络输出:
python
logvar
KL直接写:
python
kl = -0.5 * torch.sum(
1 + logvar - mu.pow(2) - logvar.exp()
)
非常方便。
如果输出的是:
python
sigma2
那每次都要:
python
torch.log(sigma2)
更麻烦。
(4)数值稳定性更好
这是最重要的原因。
假设真实方差:
σ 2 = 0.000001 \sigma^2=0.000001 σ2=0.000001
如果直接优化:
σ 2 \sigma^2 σ2
梯度会非常奇怪。
再假设:
σ 2 = 1000000 \sigma^2=1000000 σ2=1000000
梯度又会巨大。
动态范围:
text
0.000001
~
1000000
跨度:
10 12 10^{12} 1012
优化器很难处理。
而取对数后:
log ( σ 2 ) \log(\sigma^2) log(σ2)
变成:
text
-13.8
~
13.8
范围小得多。
优化容易很多。
(5)从高斯分布角度看
很多统计模型里:
真正优化的其实不是:
σ \sigma σ
也不是:
σ 2 \sigma^2 σ2
而是:
log ( σ ) \log(\sigma) log(σ)
或者:
log ( σ 2 ) \log(\sigma^2) log(σ2)
因为:
高斯分布对数似然本身就是 log-space 的。
例如负对数似然中会出现:
log ( σ 2 ) \log(\sigma^2) log(σ2)
因此在 log 空间优化更加自然。
(6)Reparameterization Trick 更方便
VAE采样:
z = μ + σ ϵ z=\mu+\sigma\epsilon z=μ+σϵ
其中:
ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0,1) ϵ∼N(0,1)
网络输出:
python
logvar
之后:
python
std = exp(0.5 * logvar)
因为:
σ = σ 2 = e log ( σ 2 ) = e 0.5 ⋅ log ( σ 2 ) \sigma = \sqrt{\sigma^2} = \sqrt{e^{\log(\sigma^2)}} = e^{0.5 \cdot \log(\sigma^2)} σ=σ2 =elog(σ2) =e0.5⋅log(σ2)
于是:
python
z = mu + std * eps
直接完成。
方差必须为正,而 log(σ²) 既能保证方差合法,又能让 KL Loss 更简洁、梯度更稳定、训练更容易收敛。
5. 解码器:从隐向量重建图像
解码器先将 128 维隐向量通过全连接层投影回 1024 维,reshape 为 2×2 的特征图,再通过 5 层转置卷积(ConvTranspose2d)逐步上采样回 64×64。
python
# ==================== 解码器 (Decoder) ====================
# 解码器将隐向量 z [B, 128] 映射回原始图像空间 [B, 3, 64, 64]
modules = []
# 首先通过全连接层将隐向量投影到与编码器输出相同的维度
# [B, 128] → [B, 256*2*2=1024] → reshape → [B, 256, 2, 2]
self.decoder_projection = nn.Linear(
latent_dim, prev_channels * img_length * img_length)
self.decoder_input_chw = (prev_channels, img_length, img_length) # (256, 2, 2)
# 转置卷积 (ConvTranspose2d) 实现上采样:每次将空间尺寸翻倍
# 通道数从 256 逐步减半至 16(与编码器对称但逆序)
# 注意:range(len(hiddens)-1, 0, -1) 即 [4, 3, 2, 1]
# 对应 hiddens[4]→hiddens[3]: 256→128, 128→64, 64→32, 32→16
for i in range(len(hiddens) - 1, 0, -1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[i], # 输入通道 (深层)
hiddens[i - 1], # 输出通道 (浅层)
kernel_size=3, # 3×3 卷积核
stride=2, # 步长 2,实现 2 倍上采样
padding=1, # 填充 1
output_padding=1),# 输出填充,确保尺寸精确翻倍
nn.BatchNorm2d(hiddens[i - 1]),
nn.ReLU()))
# 最后一层解码:将 16 通道特征图转换为 3 通道 RGB 图像
# 分两步:先用转置卷积上采样 (16→16, 32→64),再用普通卷积映射通道 (16→3)
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[0], # 输入: 16 通道, 32×32
hiddens[0], # 输出: 16 通道, 64×64
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hiddens[0]),
nn.ReLU(),
# 普通卷积将通道数从 16 映射到 3 (RGB),保持 64×64 尺寸不变
nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
nn.ReLU())) # ReLU 保证输出像素值非负
self.decoder = nn.Sequential(*modules)
解码器维度变化
text
隐向量 z: [128]
↓ Linear(128 → 1024)
↓ Reshape → [256, 2, 2]
↓ ConvTranspose2d(256→128, k3, s2, p1) + BN + ReLU → [128, 4, 4]
↓ ConvTranspose2d(128→64, k3, s2, p1) + BN + ReLU → [64, 8, 8]
↓ ConvTranspose2d(64→32, k3, s2, p1) + BN + ReLU → [32, 16, 16]
↓ ConvTranspose2d(32→16, k3, s2, p1) + BN + ReLU → [16, 32, 32]
↓ ConvTranspose2d(16→16, k3, s2, p1) + BN + ReLU → [16, 64, 64]
↓ Conv2d(16→3, k3, s1, p1) + ReLU → [3, 64, 64]
设计细节:
- output_padding=1 :stride=2 时,
output_padding=1确保输出尺寸精确翻倍(避免 32→63 的情况)。 - 最后一层分两步:先转置卷积上采样(16 通道 32×32 → 16 通道 64×64),再用普通卷积映射通道(16→3 RGB)。将"上采样"和"通道映射"解耦,让网络更容易优化。

6. 从隐空间采样生成新图像
python
def sample(self, device='cuda'):
"""
从标准正态分布 N(0, I) 采样一个隐向量,通过解码器生成一张新的人脸图像。
这是 VAE 作为生成模型的核心功能:无需任何输入图像,
直接从先验分布采样即可生成新样本。
参数:
device: 计算设备,默认 'cuda'
返回:
decoded: 生成的图像 [1, 3, 64, 64]
"""
z = torch.randn(1, self.latent_dim).to(device) # z ~ N(0, I),batch_size=1
x = self.decoder_projection(z)
x = torch.reshape(x, (-1, *self.decoder_input_chw))
decoded = self.decoder(x)
return decoded
一旦训练完成,生成新图像只需从 N ( 0 , I ) \mathcal{N}(0, I) N(0,I) 采样一个 128 维向量,传入解码器即可。这正是 VAE 作为生成模型的核心能力------无需任何输入,凭空"想象"出新的人脸。
7. 损失函数:重构 + 正则化
python
def loss_fn(y, y_hat, mean, logvar):
"""
VAE 的损失函数 = 重构损失 + KL 散度正则项
VAE 的优化目标(ELBO):
L = E[log p(x|z)] - D_KL(q(z|x) || p(z))
即最大化证据下界等价于:
最小化 重构误差 + KL 散度
参数:
y: 原始输入图像 [B, 3, 64, 64]
y_hat: 重建图像 [B, 3, 64, 64]
mean: 编码器输出的均值 μ [B, latent_dim]
logvar: 编码器输出的对数方差 log(σ²) [B, latent_dim]
返回:
loss: 总损失(标量)
"""
# ---- 重构损失 (Reconstruction Loss) ----
# 使用 MSE 衡量原始图像与重建图像的像素级差异
# 等价于假设 p(x|z) 为高斯分布时的负对数似然
recons_loss = F.mse_loss(y_hat, y)
# ---- KL 散度 (Kullback-Leibler Divergence) ----
# D_KL( N(μ, σ²) || N(0, I) ) 的解析形式:
# = -0.5 * Σ( 1 + log(σ²) - μ² - σ² )
#
# 直观理解:KL 散度约束编码器输出的分布接近标准正态分布
# - 当 μ=0, σ²=1 时 KL=0(完全匹配先验)
# - 当 μ 偏离 0 或 σ² 偏离 1 时 KL>0(惩罚偏离先验)
kl_loss = torch.mean(
-0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)
# ---- 总损失 ----
# kl_weight < 1 意味着放松对 KL 散度的约束
# 允许编码器学习更灵活的分布,换取更好的重建质量
loss = recons_loss + kl_loss * kl_weight
return loss
KL 散度的解析形式
当先验 p ( z ) = N ( 0 , I ) p(z) = \mathcal{N}(0, I) p(z)=N(0,I) 且后验 q ( z ∣ x ) = N ( μ , σ 2 ) q(z|x) = \mathcal{N}(\mu, \sigma^2) q(z∣x)=N(μ,σ2) 时,KL 散度有闭式解:
D K L ( N ( μ , σ 2 ) ∥ N ( 0 , I ) ) = − 1 2 ∑ j = 1 d ( 1 + log σ j 2 − μ j 2 − σ j 2 ) D_{KL}(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0, I)) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right) DKL(N(μ,σ2)∥N(0,I))=−21j=1∑d(1+logσj2−μj2−σj2)
代码中 kl_weight = 0.00025 是一个关键超参数。标准 VAE 的 KL 散度系数为 1,但在这里被设为远小于 1,主要目的有两个:
- 防止 posterior collapse(后验坍缩) :当 KL 项过强时,编码器可能直接输出先验 N ( 0 , I ) \mathcal{N}(0,I) N(0,I) 而忽略输入图像,导致解码器学会无视隐变量 z z z,VAE 退化为普通自编码器
- 平衡损失量级:对于 64×64×3 的高维图像,MSE 重构损失的数值量级远大于 128 维潜变量的 KL 散度,减小 KL 权重有助于两项损失的平衡优化
8. 训练流程
python
"""
VAE 训练与推理脚本
功能:
1. train(): 在 CelebA 数据集上训练 VAE 模型
2. reconstruct(): 输入真实图像,经过编码-解码后重建,保存对比图
3. generate(): 从 N(0,I) 随机采样,生成全新的人脸图像并保存
使用方式:
1. 下载 CelebA Align&Cropped Images 数据集
2. 修改 load_celebA.py 中 get_dataloader 的 root 路径指向你的数据目录
3. 运行 python main.py(在 main() 中选择要执行的功能)
"""
# ==================== 超参数 ====================
n_epochs = 10 # 训练轮数(CelebA 约 20 万张图,10 个 epoch 可得到基本可用的结果)
kl_weight = 0.00025 # KL 散度的权重系数
# 标准 VAE 中该系数为 1,这里设为远小于 1 是为了防止
# posterior collapse(后验坍缩)------即 KL 项过强时,编码器
# 直接输出先验 N(0,I),解码器学会忽略隐变量 z,导致 VAE
# 退化为普通自编码器。同时,对于高维图像数据,MSE 重构
# 损失的数值量级远大于 KL 散度,减小 KL 权重有助于平衡两项
lr = 0.005 # Adam 优化器的学习率
def train(device, dataloader, model):
"""
训练 VAE 模型
训练流程:
1. 从 dataloader 获取一个 batch 的图像
2. 前向传播:图像 → 编码 → 重参数化 → 解码 → 重建图像
3. 计算损失:MSE(重建, 原图) + kl_weight * KL( N(μ,σ²) || N(0,I) )
4. 反向传播、更新参数
5. 每 epoch 结束后打印损失并保存模型权重
"""
optimizer = torch.optim.Adam(model.parameters(), lr)
dataset_len = len(dataloader.dataset) # 数据集总样本数
begin_time = time()
for i in range(n_epochs):
loss_sum = 0 # 累计损失,用于计算平均
for x in dataloader:
x = x.to(device) # 将图像移到 GPU
y_hat, mean, logvar = model(x) # 前向传播:得到重建图、μ、log(σ²)
loss = loss_fn(x, y_hat, mean, logvar) # 计算损失
optimizer.zero_grad() # 清空梯度缓存
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新参数
loss_sum += loss
loss_sum /= dataset_len # 计算每个样本的平均损失
training_time = time() - begin_time
minute = int(training_time // 60)
second = int(training_time % 60)
print(f'epoch {i}: loss {loss_sum} {minute}:{second}')
torch.save(model.state_dict(), 'model.pth') # 每轮保存一次模型
训练配置精简但有效:10 个 epoch、Adam 优化器、batch size 16。得益于 VAE 的稳定训练特性(相比 GAN 需要小心的 minmax 博弈),无需复杂的调参即可收敛。

9. 实验结果:重建与生成
训练完成后,代码提供了两种评估方式:
重建 (Reconstruct)
python
def reconstruct(device, dataloader, model):
"""
图像重建演示:取一张真实图像,经过 VAE 编码再解码,观察重建质量。
模型需要先训练好,加载 model.pth 权重后使用。
结果保存为 reconstruct.jpg(左半:重建,右半:原图)。
"""
model.eval() # 切换到评估模式(关闭 BN 的运行时统计)
batch = next(iter(dataloader)) # 取一个 batch
x = batch[0:1, ...].to(device) # 只取第一张图 [1, 3, 64, 64]
output = model(x)[0] # 重建图像(丢弃返回的 mean, logvar)
output = output[0].detach().cpu() # 去除 batch 维度,移回 CPU
input = batch[0].detach().cpu() # 原始图像
combined = torch.cat((output, input), 1) # 水平拼接:左重建 | 右原图
img = ToPILImage()(combined) # Tensor → PIL Image
img.save('reconstruct.jpg')

生成 (Generate)
python
def generate(device, model):
"""
随机生成新图像:从标准正态分布 N(0,I) 采样隐向量,通过解码器生成人脸。
这是 VAE 区别于普通自编码器的关键能力------无需输入图像即可生成新样本。
结果保存为 generate.jpg。
"""
model.eval() # 切换到评估模式
output = model.sample(device) # 从 N(0,I) 采样 → 解码 → 生成图像
output = output[0].detach().cpu() # 去除 batch 维度,移回 CPU
img = ToPILImage()(output) # Tensor → PIL Image
img.save('generate.jpg')

10. 数据加载:CelebA 数据集
CelebA (CelebFaces Attributes Dataset) 包含超过 20 万张名人面部图像。本项目使用 Align&Cropped 版本(已人脸对齐并裁剪),原始尺寸 178×218。
python
"""
CelebA 数据集加载模块
CelebA (CelebFaces Attributes Dataset) 是一个大规模人脸属性数据集,
包含超过 20 万张名人面部图像,每张标注了 40 种属性。
本模块使用 Align&Cropped 版本,即已经过人脸对齐和裁剪的图像,
原始尺寸为 178×218,通过预处理转换为 64×64 的 RGB 图像供 VAE 训练使用。
"""
class CelebADataset(Dataset):
"""
CelebA 自定义数据集类
继承 torch.utils.data.Dataset,实现 __len__ 和 __getitem__ 方法,
使其可以被 PyTorch 的 DataLoader 加载。
预处理流程:
1. 读取原始图像 (178×218)
2. CenterCrop(168): 中心裁剪为 168×168 的正方形,去除不均匀的背景
3. Resize(64×64): 缩放到目标尺寸
4. ToTensor(): 将 PIL Image 转换为 [0, 1] 范围的 torch.Tensor
"""
def __init__(self, root, img_shape=(64, 64)) -> None:
"""
参数:
root: 数据集根目录,包含所有 .jpg 图像文件
img_shape: 目标图像尺寸 (高, 宽),默认 (64, 64)
"""
super().__init__()
self.root = root
self.img_shape = img_shape
self.filenames = sorted(os.listdir(root)) # 按文件名排序,保证每次加载顺序一致
def __len__(self) -> int:
"""返回数据集样本总数"""
return len(self.filenames)
def __getitem__(self, index: int):
"""
根据索引返回预处理后的图像张量
参数:
index: 样本索引 (0 ~ len(dataset)-1)
返回:
torch.Tensor: 形状为 [3, H, W] 的 RGB 图像,值域 [0, 1]
"""
path = os.path.join(self.root, self.filenames[index])
img = Image.open(path).convert('RGB') # 打开图像并确保为 RGB 三通道
# 预处理 pipeline
# CenterCrop(168): CelebA 原始尺寸 178×218,中心裁出 168×168 正方形
# 人脸位于图像中央,裁剪后去除大部分背景区域
# Resize(64,64): 将 168×168 缩放至 64×64,大幅降低计算量
# ToTensor(): PIL Image [0, 255] → torch.Tensor [0.0, 1.0]
# 同时自动将通道维度从 HWC 转为 CHW
pipeline = transforms.Compose([
transforms.CenterCrop(168),
transforms.Resize(self.img_shape),
transforms.ToTensor()
])
return pipeline(img)
def get_dataloader(root='/data/ym/datasets-face/img_align_celeba', **kwargs):
"""
创建 CelebA 数据集的 DataLoader
参数:
root: 数据集路径,请根据实际存放位置修改
**kwargs: 传递给 CelebADataset 的额外参数(如 img_shape)
返回:
DataLoader: batch_size=16, shuffle=True
每次迭代返回形状为 [16, 3, 64, 64] 的图像 batch
"""
dataset = CelebADataset(root, **kwargs)
return DataLoader(dataset, 16, shuffle=True) # batch_size=16,每个 epoch 随机打乱
预处理流程:CenterCrop(168) → Resize(64,64) → ToTensor()。CelebA 原始对齐图像为 178×218,中心裁剪 168×168 后缩放到 64×64,去除了大部分背景,保留了人脸核心区域。

11. 完整的前向传播流程
将整个 pipeline 串联起来,一个 batch 的完整前向传播如下:
text
输入 x: [B, 3, 64, 64]
│
▼
编码器 (5层 stride-2 卷积) → 特征图 [B, 256, 2, 2]
│
▼
Flatten → [B, 1024]
│
├── mean_linear → μ [B, 128]
└── var_linear → log σ² [B, 128]
│
▼
重参数化: z = μ + exp(log σ²/2) · ε (ε ~ N(0,I))
│
▼
Decoder Projection: Linear(128 → 1024) → Reshape → [B, 256, 2, 2]
│
▼
解码器 (5层转置卷积) → 重建图像 [B, 3, 64, 64]
│
▼
损失: MSE(x, x̂) + kl_weight · KL( N(μ,σ²) || N(0,I) )
12. 关键超参数与调优建议
| 超参数 | 取值 | 影响 |
|---|---|---|
latent_dim |
128 | 越大表达力越强,但训练越慢,过大可能过拟合 |
kl_weight |
0.00025 | 最关键的超参数。太大会导致 posterior collapse(解码器忽略 z),太小则隐空间不够正则化 |
hiddens |
16,32,64,128,256 | 控制模型容量,可根据显存调整 |
lr |
0.005 | Adam 下略高的学习率,加速收敛 |
batch_size |
16 | 小 batch 带来的噪声有助于隐空间泛化 |
13. 扩展方向
基于这个基础实现,可以进一步探索:
- 隐空间插值:在两个人脸的隐向量之间线性插值,观察生成的渐变效果
- 条件 VAE (CVAE):加入属性标签(性别、发色等),控制生成特定特征的人脸
- VQ-VAE:将连续隐空间替换为离散的 codebook,通常能获得更清晰的生成结果
- 与 GAN 对比:在同一数据集上训练 DCGAN,对比两种生成范式的差异
14. 完整代码清单
所有代码可在 dldemos/VAE/ 目录下找到:
model.py--- VAE 模型定义(含详尽注释)main.py--- 训练与推理(含详尽注释)load_celebA.py--- 数据加载(含详尽注释)
运行方式:
bash
# 1. 下载 CelebA Align&Cropped Images
# 2. 修改 main.py 中 get_dataloader 的数据路径
# 3. 运行
python main.py
15. 完整代码
load_celebA.py --- 数据加载
py
"""
CelebA 数据集加载模块
CelebA (CelebFaces Attributes Dataset) 是一个大规模人脸属性数据集,
包含超过 20 万张名人面部图像,每张标注了 40 种属性。
本模块使用 Align&Cropped 版本,即已经过人脸对齐和裁剪的图像,
原始尺寸为 178×218,通过预处理转换为 64×64 的 RGB 图像供 VAE 训练使用。
"""
import os
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class CelebADataset(Dataset):
"""
CelebA 自定义数据集类
继承 torch.utils.data.Dataset,实现 __len__ 和 __getitem__ 方法,
使其可以被 PyTorch 的 DataLoader 加载。
预处理流程:
1. 读取原始图像 (178×218)
2. CenterCrop(168): 中心裁剪为 168×168 的正方形,去除不均匀的背景
3. Resize(64×64): 缩放到目标尺寸
4. ToTensor(): 将 PIL Image 转换为 [0, 1] 范围的 torch.Tensor
"""
def __init__(self, root, img_shape=(64, 64)) -> None:
"""
参数:
root: 数据集根目录,包含所有 .jpg 图像文件
img_shape: 目标图像尺寸 (高, 宽),默认 (64, 64)
"""
super().__init__()
self.root = root
self.img_shape = img_shape
self.filenames = sorted(os.listdir(root)) # 按文件名排序,保证每次加载顺序一致
def __len__(self) -> int:
"""返回数据集样本总数"""
return len(self.filenames)
def __getitem__(self, index: int):
"""
根据索引返回预处理后的图像张量
参数:
index: 样本索引 (0 ~ len(dataset)-1)
返回:
torch.Tensor: 形状为 [3, H, W] 的 RGB 图像,值域 [0, 1]
"""
path = os.path.join(self.root, self.filenames[index])
img = Image.open(path).convert('RGB') # 打开图像并确保为 RGB 三通道
# 预处理 pipeline
# CenterCrop(168): CelebA 原始尺寸 178×218,中心裁出 168×168 正方形
# 人脸位于图像中央,裁剪后去除大部分背景区域
# Resize(64,64): 将 168×168 缩放至 64×64,大幅降低计算量
# ToTensor(): PIL Image [0, 255] → torch.Tensor [0.0, 1.0]
# 同时自动将通道维度从 HWC 转为 CHW
pipeline = transforms.Compose([
transforms.CenterCrop(168),
transforms.Resize(self.img_shape),
transforms.ToTensor()
])
return pipeline(img)
def get_dataloader(root='/data/bryant/datasets-face/img_align_celeba', **kwargs):
"""
创建 CelebA 数据集的 DataLoader
参数:
root: 数据集路径,请根据实际存放位置修改
**kwargs: 传递给 CelebADataset 的额外参数(如 img_shape)
返回:
DataLoader: batch_size=16, shuffle=True
每次迭代返回形状为 [16, 3, 64, 64] 的图像 batch
"""
dataset = CelebADataset(root, **kwargs)
return DataLoader(dataset, 16, shuffle=True) # batch_size=16,每个 epoch 随机打乱
if __name__ == '__main__':
# 测试数据加载:取一个 batch 的图像,拼接成 4×4 的网格图保存
# 用于检查数据预处理效果和图像质量
dataloader = get_dataloader()
img = next(iter(dataloader)) # 获取一个 batch: [16, 3, 64, 64]
print(img.shape) # 打印形状以确认: torch.Size([16, 3, 64, 64])
# 将 16 张图拼接成 4×4 的网格
N, C, H, W = img.shape
assert N == 16 # 确保 batch_size=16
img = torch.permute(img, (1, 0, 2, 3)) # [3, 16, 64, 64]
img = torch.reshape(img, (C, 4, 4 * H, W)) # [3, 4, 256, 64] --- 4行,每行4张图拼接
img = torch.permute(img, (0, 2, 1, 3)) # [3, 256, 4, 64]
img = torch.reshape(img, (C, 4 * H, 4 * W)) # [3, 256, 256] --- 最终 4×4 网格
img = transforms.ToPILImage()(img) # Tensor → PIL Image
img.save('tmp.jpg') # 保存预览图
print('Preview grid saved as tmp.jpg')
model.py--- VAE 模型定义
py
import torch
import torch.nn as nn
class VAE(nn.Module):
"""VAE for 64x64 face generation.
VAE (Variational Autoencoder) 变分自编码器,用于 64x64 人脸图像的生成。
与普通自编码器不同,VAE 的编码器输出隐变量的概率分布参数 (μ, σ²),
而非一个确定的隐向量,从而使得模型具备生成新样本的能力。
The hidden dimensions can be tuned.
中间的隐藏层通道数可以根据显存和效果需求进行调整。
"""
def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
"""
参数:
hiddens: 编码器每层卷积的输出通道数,同时也是解码器转置卷积的输入通道数(逆序使用)
默认 [16, 32, 64, 128, 256],共 5 层,每次下采样 2 倍
latent_dim: 隐空间维度,即潜变量 z 的维度,默认 128
"""
super().__init__()
# ==================== 编码器 (Encoder) ====================
# 编码器将输入图像 [B, 3, 64, 64] 逐步压缩为特征图 [B, 256, 2, 2]
# 然后通过两个全连接层分别输出隐变量的均值和对数方差
prev_channels = 3 # 输入为 3 通道 RGB 图像
modules = [] # 存储各层卷积模块
img_length = 64 # 输入图像的空间尺寸 (高/宽)
for cur_channels in hiddens:
# 每层:Conv2d(stride=2) 将空间尺寸减半,通道数翻倍
# kernel_size=3, stride=2, padding=1 实现 exactly 减半的效果
# 以 64 → 32 为例: output = (64 - 3 + 2*1) / 2 + 1 = 32 ✓
modules.append(
nn.Sequential(
nn.Conv2d(prev_channels, # 输入通道数
cur_channels, # 输出通道数
kernel_size=3, # 3×3 卷积核
stride=2, # 步长为 2,替代池化实现下采样
padding=1), # 填充 1,保持边界信息
nn.BatchNorm2d(cur_channels), # 批归一化,加速收敛、稳定训练
nn.ReLU())) # ReLU 激活,引入非线性
prev_channels = cur_channels
img_length //= 2 # 每层空间尺寸减半: 64→32→16→8→4→2
self.encoder = nn.Sequential(*modules)
# 编码器最终输出特征图尺寸: [B, 256, 2, 2],展平后为 [B, 1024]
# 两个独立的全连接头:一个预测均值 μ,一个预测对数方差 log(σ²)
# 使用 log(σ²) 而非 σ² 的原因:
# 1. 数值稳定性更好,避免方差趋近于 0
# 2. 无需约束输出为正(exp 后自然为正)
self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim) # 1024 → 128,输出均值 μ
self.var_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim) # 1024 → 128,输出对数方差 log(σ²)
self.latent_dim = latent_dim
# ==================== 解码器 (Decoder) ====================
# 解码器将隐向量 z [B, 128] 映射回原始图像空间 [B, 3, 64, 64]
modules = []
# 首先通过全连接层将隐向量投影到与编码器输出相同的维度
# [B, 128] → [B, 256*2*2=1024] → reshape → [B, 256, 2, 2]
self.decoder_projection = nn.Linear(
latent_dim, prev_channels * img_length * img_length)
self.decoder_input_chw = (prev_channels, img_length, img_length) # (256, 2, 2)
# 转置卷积 (ConvTranspose2d) 实现上采样:每次将空间尺寸翻倍
# 通道数从 256 逐步减半至 16(与编码器对称但逆序)
# 注意:range(len(hiddens)-1, 0, -1) 即 [4, 3, 2, 1]
# 对应 hiddens[4]→hiddens[3]: 256→128, 128→64, 64→32, 32→16
for i in range(len(hiddens) - 1, 0, -1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[i], # 输入通道 (深层)
hiddens[i - 1], # 输出通道 (浅层)
kernel_size=3, # 3×3 卷积核
stride=2, # 步长 2,实现 2 倍上采样
padding=1, # 填充 1
output_padding=1),# 输出填充,确保尺寸精确翻倍
nn.BatchNorm2d(hiddens[i - 1]),
nn.ReLU()))
# 最后一层解码:将 16 通道特征图转换为 3 通道 RGB 图像
# 分两步:先用转置卷积上采样 (16→16, 32→64),再用普通卷积映射通道 (16→3)
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[0], # 输入: 16 通道, 32×32
hiddens[0], # 输出: 16 通道, 64×64
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hiddens[0]),
nn.ReLU(),
# 普通卷积将通道数从 16 映射到 3 (RGB),保持 64×64 尺寸不变
nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
nn.ReLU())) # ReLU 保证输出像素值非负
self.decoder = nn.Sequential(*modules)
def forward(self, x):
"""
前向传播:输入图像 → 编码 → 重参数化采样 → 解码 → 重建图像
参数:
x: 输入图像 [B, 3, 64, 64]
返回:
decoded: 重建图像 [B, 3, 64, 64]
mean: 隐变量分布的均值 μ [B, latent_dim]
logvar: 隐变量分布的对数方差 log(σ²) [B, latent_dim]
"""
# ---- 编码阶段 ----
encoded = self.encoder(x) # [B, 3, 64, 64] → [B, 256, 2, 2]
encoded = torch.flatten(encoded, 1) # 展平: [B, 256, 2, 2] → [B, 1024]
mean = self.mean_linear(encoded) # 均值 μ: [B, 128]
logvar = self.var_linear(encoded) # 对数方差 log(σ²): [B, 128]
# ---- 重参数化技巧 (Reparameterization Trick) ----
# 目标:从 N(μ, σ²) 中采样 z,同时保持梯度可回传
# 做法:先从 N(0,1) 采样 ε,再计算 z = μ + σ · ε
# 这样随机性被隔离在 ε 中,μ 和 σ 仍可接收梯度
eps = torch.randn_like(logvar) # ε ~ N(0, I),与 logvar 同形状
std = torch.exp(logvar / 2) # σ = exp(log σ² / 2) = exp(log σ)
z = eps * std + mean # z = μ + σ · ε,重参数化后的隐变量
# ---- 解码阶段 ----
x = self.decoder_projection(z) # [B, 128] → [B, 1024]
x = torch.reshape(x, (-1, *self.decoder_input_chw)) # [B, 1024] → [B, 256, 2, 2]
decoded = self.decoder(x) # [B, 256, 2, 2] → [B, 3, 64, 64]
return decoded, mean, logvar
def sample(self, device='cuda'):
"""
从标准正态分布 N(0, I) 采样一个隐向量,通过解码器生成一张新的人脸图像。
这是 VAE 作为生成模型的核心功能:无需任何输入图像,
直接从先验分布采样即可生成新样本。
参数:
device: 计算设备,默认 'cuda'
返回:
decoded: 生成的图像 [1, 3, 64, 64]
"""
z = torch.randn(1, self.latent_dim).to(device) # z ~ N(0, I),batch_size=1
x = self.decoder_projection(z)
x = torch.reshape(x, (-1, *self.decoder_input_chw))
decoded = self.decoder(x)
return decoded
main.py--- 训练与推理
py
"""
VAE 训练与推理脚本
功能:
1. train(): 在 CelebA 数据集上训练 VAE 模型
2. reconstruct(): 输入真实图像,经过编码-解码后重建,保存对比图
3. generate(): 从 N(0,I) 随机采样,生成全新的人脸图像并保存
使用方式:
1. 下载 CelebA Align&Cropped Images 数据集
2. 修改 load_celebA.py 中 get_dataloader 的 root 路径指向你的数据目录
3. 运行 python main.py(在 main() 中选择要执行的功能)
"""
from time import time
import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
# 从本地模块导入数据加载器和模型
# 注:如果作为包运行,可改为 from dldemos.VAE.load_celebA import get_dataloader
from load_celebA import get_dataloader
from model import VAE
# ==================== 超参数 ====================
n_epochs = 10 # 训练轮数(CelebA 约 20 万张图,10 个 epoch 可得到基本可用的结果)
kl_weight = 0.00025 # KL 散度的权重(β-VAE 中的 β)
# 设为远小于 1 的值,让模型更注重重建质量
# 若设为 1 则等价于标准 VAE,但重建效果通常较差
lr = 0.005 # Adam 优化器的学习率
def loss_fn(y, y_hat, mean, logvar):
"""
VAE 的损失函数 = 重构损失 + KL 散度正则项
VAE 的优化目标(ELBO):
L = E[log p(x|z)] - D_KL(q(z|x) || p(z))
即最大化证据下界等价于:
最小化 重构误差 + KL 散度
参数:
y: 原始输入图像 [B, 3, 64, 64]
y_hat: 重建图像 [B, 3, 64, 64]
mean: 编码器输出的均值 μ [B, latent_dim]
logvar: 编码器输出的对数方差 log(σ²) [B, latent_dim]
返回:
loss: 总损失(标量)
"""
# ---- 重构损失 (Reconstruction Loss) ----
# 使用 MSE 衡量原始图像与重建图像的像素级差异
# 等价于假设 p(x|z) 为高斯分布时的负对数似然
recons_loss = F.mse_loss(y_hat, y)
# ---- KL 散度 (Kullback-Leibler Divergence) ----
# D_KL( N(μ, σ²) || N(0, I) ) 的解析形式:
# = -0.5 * Σ( 1 + log(σ²) - μ² - σ² )
#
# 直观理解:KL 散度约束编码器输出的分布接近标准正态分布
# - 当 μ=0, σ²=1 时 KL=0(完全匹配先验)
# - 当 μ 偏离 0 或 σ² 偏离 1 时 KL>0(惩罚偏离先验)
kl_loss = torch.mean(
-0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)
# ---- 总损失(β-VAE 形式)----
# kl_weight < 1 意味着放松对 KL 散度的约束
# 允许编码器学习更灵活的分布,换取更好的重建质量
loss = recons_loss + kl_loss * kl_weight
return loss
def train(device, dataloader, model):
"""
训练 VAE 模型
训练流程:
1. 从 dataloader 获取一个 batch 的图像
2. 前向传播:图像 → 编码 → 重参数化 → 解码 → 重建图像
3. 计算损失:MSE(重建, 原图) + β * KL( N(μ,σ²) || N(0,I) )
4. 反向传播、更新参数
5. 每 epoch 结束后打印损失并保存模型权重
"""
optimizer = torch.optim.Adam(model.parameters(), lr)
dataset_len = len(dataloader.dataset) # 数据集总样本数
begin_time = time()
# train
for i in range(n_epochs):
loss_sum = 0 # 累计损失,用于计算平均
for x in dataloader:
x = x.to(device) # 将图像移到 GPU
y_hat, mean, logvar = model(x) # 前向传播:得到重建图、μ、log(σ²)
loss = loss_fn(x, y_hat, mean, logvar) # 计算损失
optimizer.zero_grad() # 清空梯度缓存
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新参数
loss_sum += loss
loss_sum /= dataset_len # 计算每个样本的平均损失
training_time = time() - begin_time
minute = int(training_time // 60)
second = int(training_time % 60)
print(f'epoch {i}: loss {loss_sum} {minute}:{second}')
torch.save(model.state_dict(), 'model.pth') # 每轮保存一次模型
def reconstruct(device, dataloader, model):
"""
图像重建演示:取一张真实图像,经过 VAE 编码再解码,观察重建质量。
模型需要先训练好,加载 model.pth 权重后使用。
结果保存为 reconstruct.jpg(左半:重建,右半:原图)。
"""
model.eval() # 切换到评估模式(关闭 BN 的运行时统计)
batch = next(iter(dataloader)) # 取一个 batch
x = batch[0:1, ...].to(device) # 只取第一张图 [1, 3, 64, 64]
output = model(x)[0] # 重建图像(丢弃返回的 mean, logvar)
output = output[0].detach().cpu() # 去除 batch 维度,移回 CPU
input = batch[0].detach().cpu() # 原始图像
combined = torch.cat((output, input), 1) # 水平拼接:左重建 | 右原图
img = ToPILImage()(combined) # Tensor → PIL Image
img.save('reconstruct.jpg')
def generate(device, model):
"""
随机生成新图像:从标准正态分布 N(0,I) 采样隐向量,通过解码器生成人脸。
这是 VAE 区别于普通自编码器的关键能力------无需输入图像即可生成新样本。
结果保存为 generate.jpg。
"""
model.eval() # 切换到评估模式
output = model.sample(device) # 从 N(0,I) 采样 → 解码 → 生成图像
output = output[0].detach().cpu() # 去除 batch 维度,移回 CPU
img = ToPILImage()(output) # Tensor → PIL Image
img.save('generate.jpg')
def main():
"""
主函数:根据需求选择执行 train / reconstruct / generate。
使用方法:
1. 首次运行:取消 train() 的注释,训练模型
2. 训练完成后:取消 reconstruct() 和 generate() 的注释,查看效果
"""
device = 'cuda:2' # 指定使用的 GPU(可根据实际环境修改)
dataloader = get_dataloader() # 加载 CelebA 数据集
model = VAE().to(device) # 实例化模型并移至 GPU
# If you obtain the ckpt, load it
model.load_state_dict(torch.load('model.pth', 'cuda:0')) # 加载预训练权重
# Choose the function
# train(device, dataloader, model)
reconstruct(device, dataloader, model) # 重建演示
generate(device, model) # 生成演示
if __name__ == '__main__':
main()
参考
- Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. ICLR.
- PyTorch-VAE --- 本项目代码的主要参考来源
- CelebA Dataset
- 【VAE】《Variational Auto-Encoder》
- 【VAE】《Variational Auto-Encoder vs Auto-Encoder》
本文是对 DL-Demos 项目中 VAE 实现的详细解读,旨在帮助初学者理解变分自编码器的原理与 PyTorch 实现细节。

