pytorch-AutoEncoders实战之VAE

目录

  • [1. VAE回顾](#1. VAE回顾)
  • [2. KL的计算公式](#2. KL的计算公式)
  • [3. 构建网络](#3. 构建网络)
  • [4. 模型训练](#4. 模型训练)

1. VAE回顾

VAE = Variational Auto Encoder,变分自编码器。是一种常见的生成模型,属于无监督学习的范畴。它能够学习一个函数/模型,使得输出数据的分布尽可能的逼近原始数据分布,其基本思路是:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个VAE模型.

下图中的公式,前半部分计算的是重建误差,可以理解为MSE或者是Cross Entropy,而后半部分KL是散度的公式,主要是计算q分布与p分布的相似度。

那么公式的目标就是重建误差越小越好,q和p的分布越接近越好。

2. KL的计算公式

reparametrize trick

按照上图推导公式实现即可。

3. 构建网络

根据公式可以知道,前半部分计算的是重建误差,后半部分是KL,再根据reparametrize trick分别计算z和epison~N(0, 1)

先将encode的[b,20],切分为两个[b,10]分别作为μ和σ,通过μ和σ计算z值,代码如下:

python 复制代码
 mu, sigma = h_.chunk(2, dim=1)
 # reparametrize trick, epison~N(0, 1)
 z = mu + sigma * torch.randn_like(sigma)

计算KL,根据2中的推导公式写代码即可,代码中batchsz2828意思是计算像素级的kld,1e-8是防止log函数变量为0时,趋于无穷大,这里起到限幅的作用

python 复制代码
kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

完整代码:

python 复制代码
import  torch
from    torch import nn

class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

4. 模型训练

与上一篇的AutoEncoders步骤相近,这里不再详述

python 复制代码
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets

from    ae import AE
from    vae import VAE

import  visdom

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = iter(mnist_train).next()
    print('x:', x.shape)

    device = torch.device('cuda')
    # model = AE().to(device)
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()
相关推荐
Christo316 小时前
TKDE-2026《Efficient Co-Clustering via Bipartite Graph Factorization》
人工智能·算法·机器学习·数据挖掘
jackylzh16 小时前
PyTorch 2.x 中 `torch.load` 的 `FutureWarning` 与 `weights_only=False` 参数分析
人工智能·pytorch·python
LCG米16 小时前
基于PyTorch的Transformer-CNN时序预测实战:从特征工程到服务化部署
pytorch·cnn·transformer
子夜江寒16 小时前
基于PyTorch的语言模型实现详解
pytorch·语言模型
叶庭云16 小时前
AI Agent KernelCAT:深耕算子开发和模型迁移的 “计算加速专家”
人工智能·运筹优化·算子·ai agent·kernelcat·模型迁移适配·生态壁垒
码农三叔16 小时前
(8-2)传感器系统与信息获取:外部环境传感
人工智能·嵌入式硬件·数码相机·机器人·人形机器人
MACKEI16 小时前
服务器流式传输接口问题排查与解决方案
python·nginx·流式
小宇的天下16 小时前
innovus/virtuoso/ICC2 三大工具的工艺文件有什么区别?
人工智能
产品经理邹继强16 小时前
VTC营销与增长篇④:增长战略全景图——构建自驱进化的VTC增长飞轮
人工智能
2401_8322981016 小时前
阿里云倚天ECS实例,Arm架构重构算力性价比范式
人工智能