在 CelebA 数据集上训练的 PyTorch 中的基本变分自动编码器

摩西·西珀博士

一、说明

我最近发现自己需要一种方法将图像编码到潜在嵌入中, 调整 嵌入,然后生成 新图像。有一些强大的方法可以创建嵌入从嵌入生成。如果你想同时做到这两点,一种自然且相当简单的方法是使用变分自动编码器。

这样的深度网络不仅可以进行编码和解码,而且相当简单,我可以在以后的研究中使用它,而不必过多担心编码解码阶段的各种隐藏复杂性。我也更喜欢对软件内部有尽可能多的控制。

因此,考虑到所有这些规范,我从 GitHub 收集了一些零碎的东西,施展了一些我自己的魔法,最终得到了一个漂亮、简单的变分自动编码器。我将在下面描述主要部分,完整的包可在以下位置找到:
vae-torch-celeba,PyTorch 中 CelebA 数据集的变分自动编码器,下载vae-torch-celeba的源码_GitHub_帮酷
PyTorch 中用于 CelebA 数据集的变分自动编码器 - GitHub - moshesipper/vae-torch-celeba:变分...
它相当小并且完全独立------这就是我的意图!

二、自动编码器

为了使本文简短易懂,我将避免提供变分自动编码器的冗长概述。此外,您还可以在 Medium 上找到有关基础知识的优秀文章。我只提供三张快速图片。

这是基本自动编码器的样子:

来源: https: //commons.wikimedia.org/wiki/File :Autoencoder_schema.png

简而言之,网络将输入数据压缩为潜在向量(也称为嵌入),然后将其解压缩回来。这两个阶段称为编码解码

变分自动编码器(VAE)看起来非常相似,除了中间的嵌入部分。对于每个输入,VAE 的编码器输出潜在空间中预定义分布的参数,而不是潜在空间中的向量:

来源:https ://commons.wikimedia.org/wiki/File:Reparameterized_Variational_Autoencoder.png

最后一张图片:如果我们处理的是图像输入,我们需要一个卷积VAE,如下所示:

来源:https://github.com/arthurmeyer/Saliency_Detection_Convolutional_Autoencoder

注意#1:观察编码器部分如何在每一层中添加越来越多的滤波器,图像变得越来越小;解码器则相反。

注意#2:注意符号。如果只有一个通道,则术语"过滤器"和"内核"基本相同。对于多个通道,每个过滤器都是一组内核。查看这篇很棒的 Medium 文章:"直观地理解深度学习的卷积"。

三、CelebA数据集

我将使用的数据集是 CelebA,其中包含 202,599 张名人面孔图像。
CelebA 数据集
CelebFaces Attributes Dataset (CelebA) 是一个大规模人脸属性数据集,包含超过 20 万张名人图像......
可以通过以下方式访问它torchvision:

ba 复制代码
from torchvision.datasets import CelebA

train_dataset = CelebA(path, split='train')
test_dataset = CelebA(path, split='valid') # or 'test'

四、VAE类

我的 VAE 基于此PyTorch 示例和存储库的普通 VAE模型(将我使用的普通 VAE 替换为中的任何其他模型PyTorch-VAE应该不会太难)。PyTorch-VAE

该文件vae.py包含VAE类以及图像大小的定义、两个潜在向量的维度(均值和方差)以及数据集的路径:

ba 复制代码
CELEB_PATH = './data/'
IMAGE_SIZE = 150
LATENT_DIM = 128
image_dim = 3 * IMAGE_SIZE * IMAGE_SIZE

在课堂上VAE,我使用了以下隐藏过滤器维度:

ba 复制代码
hidden_dims = [32, 64, 128, 256, 512]

编码器看起来像这样:

ba 复制代码
in_channels = 3
modules = []
for h_dim in hidden_dims:
    modules.append(
        nn.Sequential(
            nn.Conv2d(in_channels, out_channels=h_dim,
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(h_dim),
            nn.LeakyReLU())
    )
    in_channels = h_dim
self.encoder = nn.Sequential(*modules)

然后是潜在向量:

ba 复制代码
self.fc_mu = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)
self.fc_var = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)

最后我们用解码器"倒退":

ba 复制代码
hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
   modules.append(
      nn.Sequential(
         nn.ConvTranspose2d(hidden_dims[i],
                            hidden_dims[i + 1],
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            output_padding=1),
         nn.BatchNorm2d(hidden_dims[i + 1]),
         nn.LeakyReLU())
   )

self.decoder = nn.Sequential(*modules)

这就是它的要点------还有一些零碎的内容vae.py可以完成这VAE门课。

五、训练

该文件trainvae.py包含训练我们刚刚编码的 VAE 的代码。老实说,没什么花哨的......有 3 个主要函数:(train随着训练的进行,它也输出损失值),test(它还构建一个重建图像的小样本)和loss_function。训练和测试相当普通,损失函数是标准 VAE,带有重建组件 (MSE) 和 KL 散度组件。

epoch 上的主循环执行 4 个操作:1) train、2) test、3) 生成随机潜在向量并调用decode以输出相应的输出图像,以及 4) 将 epoch 的模型保存到文件中pth
以下是示例运行的输出。通过 20 个训练周期,您最终会得到 20 个重建图像文件、20 个潜在采样文件和 20 个 python 模型文件:

这里reconstruction_20.png,顶行显示 8 张原始图片,底行显示经过训练的 VAE 的相应重建。

在 epoch 20 时从模型重建(输出)图像。

这里的sample_20.png,显示了从随机潜在向量生成的 64 张图像:

只是为了好玩,我添加了一小段代码 --- genpics.py--- 从数据集中挑选一个随机图像并生成 7 个重建。以下是一些示例(最左边的图像是原始图像):




最后,我再次放置 GitHub 链接。享受!

相关推荐
九章云极AladdinEdu1 小时前
临床数据挖掘与分析:利用GPU加速Pandas和Scikit-learn处理大规模数据集
人工智能·pytorch·数据挖掘·pandas·scikit-learn·paddlepaddle·gpu算力
上海锝秉工控1 小时前
超声波风向传感器:以科技之翼,捕捉风的每一次呼吸
大数据·人工智能·科技
说私域1 小时前
基于开源AI智能名片、链动2+1模式与S2B2C商城小程序的流量运营与个人IP构建研究
人工智能·小程序·流量运营
xiaoxiaoxiaolll3 小时前
期刊速递 | 《Light Sci. Appl.》超宽带光热电机理研究,推动碳纳米管传感器在制药质控中的实际应用
人工智能·学习
练习两年半的工程师3 小时前
AWS TechFest 2025: 风险模型的转变、流程设计的转型、生成式 AI 从实验走向实施的三大关键要素、评估生成式 AI 用例的适配度
人工智能·科技·金融·aws
Elastic 中国社区官方博客5 小时前
Elasticsearch:智能搜索的 MCP
大数据·人工智能·elasticsearch·搜索引擎·全文检索
stbomei5 小时前
从“能说话”到“会做事”:AI Agent如何重构日常工作流?
人工智能
yzx9910136 小时前
生活在数字世界:一份人人都能看懂的网络安全生存指南
运维·开发语言·网络·人工智能·自动化
许泽宇的技术分享7 小时前
LangGraph深度解析:构建下一代智能Agent的架构革命——从Pregel到现代AI工作流的技术飞跃
人工智能·架构
乔巴先生247 小时前
LLMCompiler:基于LangGraph的并行化Agent架构高效实现
人工智能·python·langchain·人机交互