[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言:

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器

其训练效果如下

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

vae.py

main.py

目录:

  1. vae
  2. main

一 vae

文件名: vae.py

作用: Variational Autoencoders (VAE)

训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE) ,它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。

variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023

@author: chengxf2
"""

import torch
from torch import nn

#ae: AutoEncoder

class VAE(nn.Module):
    
    def __init__(self,hidden_size=20):
        
        super(VAE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()
            )
         # hidden [batch_size, 10]
         
        h_dim = int(hidden_size/2)
        self.hDim = h_dim

        self.decoder = nn.Sequential(
             nn.Linear(in_features=h_dim, out_features=64),
             nn.ReLU(),
             nn.Linear(in_features=64, out_features=128),
             nn.ReLU(),
             nn.Linear(in_features=128, out_features=256),
             nn.ReLU(),
             nn.Linear(in_features=256, out_features=784),
             nn.Sigmoid()
             )
        
        
    def forward(self, x):
            '''
            param x:[batch, 1,28,28]
            return 
        
            '''
      
            batchSz= x.size(0)
            #flatten
            x = x.view(batchSz, 784)
            
            #encoder
            h= self.encoder(x)
     
            #在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigma
            u, sigma = h.chunk(2,dim=1)
            
            #Reparameterize trick:
            #randn_like:产生一个正太分布 ~ N(0,1)
            #h.shape [batchSize,self.hDim]
            h = u+sigma* torch.randn_like(sigma)
           
            #kld :1e-8 防止sigma 平方为0
            kld = 0.5*torch.sum(
                torch.pow(u,2)+
                torch.pow(sigma,2)-
                torch.log(1e-8+torch.pow(sigma,2))-
                1
                )
            
            #MSE loss 是平均loss, 所以kld 也要算一个平均值
            kld = kld/(batchSz*32*32)
            xHat =   self.decoder(h)
            
            #reshape
            xHat = xHat.view(batchSz,1,28,28)
            
            return xHat,kld

二 main

文件名: main.py

作用: 训练,测试数据集

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023

@author: chengxf2
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdom





def main():
   
   batchNum = 32
   lr = 1e-3
   epochs = 20
   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
   torch.manual_seed(1234)
   viz = visdom.Visdom()
   viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))

    
   

   tf= transforms.Compose([ transforms.ToTensor()])
   mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
   train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
   
   mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
   test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
   global_step =0

   
   

   
  
   model =VAE().to(device)
   criteon = nn.MSELoss().to(device) #损失函数
   optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
   
   print("\n ----main-----")
   for epoch in range(epochs):
       
       start = time.perf_counter()
       for step ,(x,y) in enumerate(train_data):
           #[b,1,28,28]
           x = x.to(device)
           x_hat,kld = model(x)
           
           loss = criteon(x_hat, x)
           
           if kld is not None:
              
               
               elbo = -loss -1.0*kld
               loss = -elbo
           #backprop
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
           viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
           global_step +=1



    
       end = time.perf_counter()    
       interval = int(end - start)
  
       print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())
       
       x,target = iter(test_data).next()
       x = x.to(device)
       with torch.no_grad():
           x_hat,kld = model(x)
       
       tip = 'hat'+str(epoch)
       viz.images(x,nrow=8, win='x',opts=dict(title='x'))
       viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
           
           
           
           
   

if __name__ == '__main__':
    
    main()

参考:

课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

相关推荐
文心快码BaiduComate25 分钟前
百度云与光本位签署战略合作:用AI Agent 重构芯片研发流程
前端·人工智能·架构
风象南1 小时前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
Mintopia2 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮2 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬3 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia3 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区3 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两6 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪6 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain