目录
- [1. VAE回顾](#1. VAE回顾)
- [2. KL的计算公式](#2. KL的计算公式)
- [3. 构建网络](#3. 构建网络)
- [4. 模型训练](#4. 模型训练)
1. VAE回顾
VAE = Variational Auto Encoder,变分自编码器。是一种常见的生成模型,属于无监督学习的范畴。它能够学习一个函数/模型,使得输出数据的分布尽可能的逼近原始数据分布,其基本思路是:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个VAE模型.
下图中的公式,前半部分计算的是重建误差,可以理解为MSE或者是Cross Entropy,而后半部分KL是散度的公式,主要是计算q分布与p分布的相似度。
那么公式的目标就是重建误差越小越好,q和p的分布越接近越好。
2. KL的计算公式
reparametrize trick
按照上图推导公式实现即可。
3. 构建网络
根据公式可以知道,前半部分计算的是重建误差,后半部分是KL,再根据reparametrize trick分别计算z和epison~N(0, 1)
先将encode的[b,20],切分为两个[b,10]分别作为μ和σ,通过μ和σ计算z值,代码如下:
python
mu, sigma = h_.chunk(2, dim=1)
# reparametrize trick, epison~N(0, 1)
z = mu + sigma * torch.randn_like(sigma)
计算KL,根据2中的推导公式写代码即可,代码中batchsz2828意思是计算像素级的kld,1e-8是防止log函数变量为0时,趋于无穷大,这里起到限幅的作用
python
kld = 0.5 * torch.sum(
torch.pow(mu, 2) +
torch.pow(sigma, 2) -
torch.log(1e-8 + torch.pow(sigma, 2)) - 1
) / (batchsz*28*28)
完整代码:
python
import torch
from torch import nn
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
# [b, 784] => [b, 20]
# u: [b, 10]
# sigma: [b, 10]
self.encoder = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, 20),
nn.ReLU()
)
# [b, 20] => [b, 784]
self.decoder = nn.Sequential(
nn.Linear(10, 64),
nn.ReLU(),
nn.Linear(64, 256),
nn.ReLU(),
nn.Linear(256, 784),
nn.Sigmoid()
)
self.criteon = nn.MSELoss()
def forward(self, x):
"""
:param x: [b, 1, 28, 28]
:return:
"""
batchsz = x.size(0)
# flatten
x = x.view(batchsz, 784)
# encoder
# [b, 20], including mean and sigma
h_ = self.encoder(x)
# [b, 20] => [b, 10] and [b, 10]
mu, sigma = h_.chunk(2, dim=1)
# reparametrize trick, epison~N(0, 1)
h = mu + sigma * torch.randn_like(sigma)
# decoder
x_hat = self.decoder(h)
# reshape
x_hat = x_hat.view(batchsz, 1, 28, 28)
kld = 0.5 * torch.sum(
torch.pow(mu, 2) +
torch.pow(sigma, 2) -
torch.log(1e-8 + torch.pow(sigma, 2)) - 1
) / (batchsz*28*28)
return x_hat, kld
4. 模型训练
与上一篇的AutoEncoders步骤相近,这里不再详述
python
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import transforms, datasets
from ae import AE
from vae import VAE
import visdom
def main():
mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
transforms.ToTensor()
]), download=True)
mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
transforms.ToTensor()
]), download=True)
mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
x, _ = iter(mnist_train).next()
print('x:', x.shape)
device = torch.device('cuda')
# model = AE().to(device)
model = VAE().to(device)
criteon = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print(model)
viz = visdom.Visdom()
for epoch in range(1000):
for batchidx, (x, _) in enumerate(mnist_train):
# [b, 1, 28, 28]
x = x.to(device)
x_hat, kld = model(x)
loss = criteon(x_hat, x)
if kld is not None:
elbo = - loss - 1.0 * kld
loss = - elbo
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch, 'loss:', loss.item(), 'kld:', kld.item())
x, _ = iter(mnist_test).next()
x = x.to(device)
with torch.no_grad():
x_hat, kld = model(x)
viz.images(x, nrow=8, win='x', opts=dict(title='x'))
viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
if __name__ == '__main__':
main()