pytorch-AutoEncoders实战之VAE

目录

  • [1. VAE回顾](#1. VAE回顾)
  • [2. KL的计算公式](#2. KL的计算公式)
  • [3. 构建网络](#3. 构建网络)
  • [4. 模型训练](#4. 模型训练)

1. VAE回顾

VAE = Variational Auto Encoder,变分自编码器。是一种常见的生成模型,属于无监督学习的范畴。它能够学习一个函数/模型,使得输出数据的分布尽可能的逼近原始数据分布,其基本思路是:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个VAE模型.

下图中的公式,前半部分计算的是重建误差,可以理解为MSE或者是Cross Entropy,而后半部分KL是散度的公式,主要是计算q分布与p分布的相似度。

那么公式的目标就是重建误差越小越好,q和p的分布越接近越好。

2. KL的计算公式

reparametrize trick

按照上图推导公式实现即可。

3. 构建网络

根据公式可以知道,前半部分计算的是重建误差,后半部分是KL,再根据reparametrize trick分别计算z和epison~N(0, 1)

先将encode的[b,20],切分为两个[b,10]分别作为μ和σ,通过μ和σ计算z值,代码如下:

python 复制代码
 mu, sigma = h_.chunk(2, dim=1)
 # reparametrize trick, epison~N(0, 1)
 z = mu + sigma * torch.randn_like(sigma)

计算KL,根据2中的推导公式写代码即可,代码中batchsz2828意思是计算像素级的kld,1e-8是防止log函数变量为0时,趋于无穷大,这里起到限幅的作用

python 复制代码
kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

完整代码:

python 复制代码
import  torch
from    torch import nn

class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

4. 模型训练

与上一篇的AutoEncoders步骤相近,这里不再详述

python 复制代码
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets

from    ae import AE
from    vae import VAE

import  visdom

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = iter(mnist_train).next()
    print('x:', x.shape)

    device = torch.device('cuda')
    # model = AE().to(device)
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()
相关推荐
一个天蝎座 白勺 程序猿2 分钟前
Python爬虫(29)Python爬虫高阶:动态页面处理与云原生部署全链路实践(Selenium、Scrapy、K8s)
redis·爬虫·python·selenium·scrapy·云原生·k8s
90后小陈老师3 分钟前
WebXR教学 09 项目7 使用python从0搭建一个简易个人博客
开发语言·python·web
weixin-WNXZ021815 分钟前
闲上淘 自动上货工具运行原理解析
爬虫·python·自动化·软件工程·软件需求
正在走向自律30 分钟前
Conda 完全指南:从环境管理到工具集成
开发语言·python·conda·numpy·fastapi·pip·开发工具
东临碣石8239 分钟前
【AI论文】EnerVerse-AC:用行动条件来构想具身环境
人工智能
lqjun08271 小时前
PyTorch实现CrossEntropyLoss示例
人工智能·pytorch·python
心灵彼岸-诗和远方1 小时前
芯片生态链深度解析(三):芯片设计篇——数字文明的造物主战争
人工智能·制造
DpHard1 小时前
Vscode 配置python调试环境
ide·vscode·python
小蜗笔记1 小时前
显卡、Cuda和pytorch兼容问题
人工智能·pytorch·python
高建伟-joe1 小时前
内容安全:使用开源框架Caffe实现上传图片进行敏感内容识别
人工智能·python·深度学习·flask·开源·html5·caffe