pytorch-AutoEncoders实战之VAE

目录

  • [1. VAE回顾](#1. VAE回顾)
  • [2. KL的计算公式](#2. KL的计算公式)
  • [3. 构建网络](#3. 构建网络)
  • [4. 模型训练](#4. 模型训练)

1. VAE回顾

VAE = Variational Auto Encoder,变分自编码器。是一种常见的生成模型,属于无监督学习的范畴。它能够学习一个函数/模型,使得输出数据的分布尽可能的逼近原始数据分布,其基本思路是:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个VAE模型.

下图中的公式,前半部分计算的是重建误差,可以理解为MSE或者是Cross Entropy,而后半部分KL是散度的公式,主要是计算q分布与p分布的相似度。

那么公式的目标就是重建误差越小越好,q和p的分布越接近越好。

2. KL的计算公式

reparametrize trick

按照上图推导公式实现即可。

3. 构建网络

根据公式可以知道,前半部分计算的是重建误差,后半部分是KL,再根据reparametrize trick分别计算z和epison~N(0, 1)

先将encode的[b,20],切分为两个[b,10]分别作为μ和σ,通过μ和σ计算z值,代码如下:

python 复制代码
 mu, sigma = h_.chunk(2, dim=1)
 # reparametrize trick, epison~N(0, 1)
 z = mu + sigma * torch.randn_like(sigma)

计算KL,根据2中的推导公式写代码即可,代码中batchsz2828意思是计算像素级的kld,1e-8是防止log函数变量为0时,趋于无穷大,这里起到限幅的作用

python 复制代码
kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

完整代码:

python 复制代码
import  torch
from    torch import nn

class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

4. 模型训练

与上一篇的AutoEncoders步骤相近,这里不再详述

python 复制代码
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets

from    ae import AE
from    vae import VAE

import  visdom

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = iter(mnist_train).next()
    print('x:', x.shape)

    device = torch.device('cuda')
    # model = AE().to(device)
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()
相关推荐
DKNG4 小时前
【Windows Host】 hosts配置增加访问github流畅度
人工智能·git·github
white-persist4 小时前
【攻防世界】reverse | simple-check-100 详细题解 WP
c语言·开发语言·汇编·数据结构·c++·python·算法
昨日之日20064 小时前
Fun-ASR - 多语言多方言的高精度语音识别软件 支持50系显卡 一键整合包下载
人工智能·音视频·语音识别
王大傻09284 小时前
Series的属性简介
python·pandas
AIGC科技4 小时前
焕新而来,境由AI生|AIRender升级更名“渲境AI”,重新定义设计渲染效率
人工智能·深度学习·图形渲染
出来吧皮卡丘4 小时前
A2UI:让 AI Agent 自主构建用户界面的新范式
前端·人工智能·aigc
nju_spy5 小时前
深度强化学习 TRPO 置信域策略优化实验(sb3_contrib / 手搓 + CartPole-v1 / Breakout-v5)
人工智能·强化学习·共轭梯度法·策略网络·trpo·sb3_contrib·breakout游戏
A0_張張5 小时前
记录一个PDF盖章工具(PyQt5 + PyMuPDF)
开发语言·python·qt·pdf
Faker66363aaa5 小时前
Arive-Dantu叶片识别系统:基于cascade-mask-rcnn_regnetx-400MF_fpn_ms-3x_coco模型实现_1
python
程序员欣宸5 小时前
LangChain4j实战之四:集成到spring-boot
java·人工智能·spring boot