pytorch-AutoEncoders实战之VAE

目录

  • [1. VAE回顾](#1. VAE回顾)
  • [2. KL的计算公式](#2. KL的计算公式)
  • [3. 构建网络](#3. 构建网络)
  • [4. 模型训练](#4. 模型训练)

1. VAE回顾

VAE = Variational Auto Encoder,变分自编码器。是一种常见的生成模型,属于无监督学习的范畴。它能够学习一个函数/模型,使得输出数据的分布尽可能的逼近原始数据分布,其基本思路是:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个VAE模型.

下图中的公式,前半部分计算的是重建误差,可以理解为MSE或者是Cross Entropy,而后半部分KL是散度的公式,主要是计算q分布与p分布的相似度。

那么公式的目标就是重建误差越小越好,q和p的分布越接近越好。

2. KL的计算公式

reparametrize trick

按照上图推导公式实现即可。

3. 构建网络

根据公式可以知道,前半部分计算的是重建误差,后半部分是KL,再根据reparametrize trick分别计算z和epison~N(0, 1)

先将encode的[b,20],切分为两个[b,10]分别作为μ和σ,通过μ和σ计算z值,代码如下:

python 复制代码
 mu, sigma = h_.chunk(2, dim=1)
 # reparametrize trick, epison~N(0, 1)
 z = mu + sigma * torch.randn_like(sigma)

计算KL,根据2中的推导公式写代码即可,代码中batchsz2828意思是计算像素级的kld,1e-8是防止log函数变量为0时,趋于无穷大,这里起到限幅的作用

python 复制代码
kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

完整代码:

python 复制代码
import  torch
from    torch import nn

class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

4. 模型训练

与上一篇的AutoEncoders步骤相近,这里不再详述

python 复制代码
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets

from    ae import AE
from    vae import VAE

import  visdom

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = iter(mnist_train).next()
    print('x:', x.shape)

    device = torch.device('cuda')
    # model = AE().to(device)
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()
相关推荐
FreakStudio42 分钟前
全网最适合入门的面向对象编程教程:50 Python函数方法与接口-接口和抽象基类
python·嵌入式·面向对象·电子diy
redcocal2 小时前
地平线秋招
python·嵌入式硬件·算法·fpga开发·求职招聘
artificiali2 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python
RaidenQ3 小时前
2024.9.13 Python与图像处理新国大EE5731课程大作业,索贝尔算子计算边缘,高斯核模糊边缘,Haar小波计算边缘
图像处理·python·算法·课程设计
花生了什么树~.3 小时前
python基础知识(六)--字典遍历、公共运算符、公共方法、函数、变量分类、参数分类、拆包、引用
开发语言·python
酱香编程,风雨兼程3 小时前
深度学习——基础知识
人工智能·深度学习
Lossya3 小时前
【机器学习】参数学习的基本概念以及贝叶斯网络的参数学习和马尔可夫随机场的参数学习
人工智能·学习·机器学习·贝叶斯网络·马尔科夫随机场·参数学习
Trouvaille ~3 小时前
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧
图像处理·python·机器学习·numpy·信号处理·时间序列分析·科学计算
爆更小小刘3 小时前
Python基础语法(3)下
开发语言·python
哪 吒3 小时前
华为OD机试 - 第 K 个字母在原来字符串的索引(Python/JS/C/C++ 2024 E卷 100分)
javascript·python·华为od