[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言:

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器

其训练效果如下

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

vae.py

main.py

目录:

  1. vae
  2. main

一 vae

文件名: vae.py

作用: Variational Autoencoders (VAE)

训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE) ,它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。

variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023

@author: chengxf2
"""

import torch
from torch import nn

#ae: AutoEncoder

class VAE(nn.Module):
    
    def __init__(self,hidden_size=20):
        
        super(VAE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()
            )
         # hidden [batch_size, 10]
         
        h_dim = int(hidden_size/2)
        self.hDim = h_dim

        self.decoder = nn.Sequential(
             nn.Linear(in_features=h_dim, out_features=64),
             nn.ReLU(),
             nn.Linear(in_features=64, out_features=128),
             nn.ReLU(),
             nn.Linear(in_features=128, out_features=256),
             nn.ReLU(),
             nn.Linear(in_features=256, out_features=784),
             nn.Sigmoid()
             )
        
        
    def forward(self, x):
            '''
            param x:[batch, 1,28,28]
            return 
        
            '''
      
            batchSz= x.size(0)
            #flatten
            x = x.view(batchSz, 784)
            
            #encoder
            h= self.encoder(x)
     
            #在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigma
            u, sigma = h.chunk(2,dim=1)
            
            #Reparameterize trick:
            #randn_like:产生一个正太分布 ~ N(0,1)
            #h.shape [batchSize,self.hDim]
            h = u+sigma* torch.randn_like(sigma)
           
            #kld :1e-8 防止sigma 平方为0
            kld = 0.5*torch.sum(
                torch.pow(u,2)+
                torch.pow(sigma,2)-
                torch.log(1e-8+torch.pow(sigma,2))-
                1
                )
            
            #MSE loss 是平均loss, 所以kld 也要算一个平均值
            kld = kld/(batchSz*32*32)
            xHat =   self.decoder(h)
            
            #reshape
            xHat = xHat.view(batchSz,1,28,28)
            
            return xHat,kld

二 main

文件名: main.py

作用: 训练,测试数据集

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023

@author: chengxf2
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdom





def main():
   
   batchNum = 32
   lr = 1e-3
   epochs = 20
   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
   torch.manual_seed(1234)
   viz = visdom.Visdom()
   viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))

    
   

   tf= transforms.Compose([ transforms.ToTensor()])
   mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
   train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
   
   mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
   test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
   global_step =0

   
   

   
  
   model =VAE().to(device)
   criteon = nn.MSELoss().to(device) #损失函数
   optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
   
   print("\n ----main-----")
   for epoch in range(epochs):
       
       start = time.perf_counter()
       for step ,(x,y) in enumerate(train_data):
           #[b,1,28,28]
           x = x.to(device)
           x_hat,kld = model(x)
           
           loss = criteon(x_hat, x)
           
           if kld is not None:
              
               
               elbo = -loss -1.0*kld
               loss = -elbo
           #backprop
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
           viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
           global_step +=1



    
       end = time.perf_counter()    
       interval = int(end - start)
  
       print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())
       
       x,target = iter(test_data).next()
       x = x.to(device)
       with torch.no_grad():
           x_hat,kld = model(x)
       
       tip = 'hat'+str(epoch)
       viz.images(x,nrow=8, win='x',opts=dict(title='x'))
       viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
           
           
           
           
   

if __name__ == '__main__':
    
    main()

参考:

课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

相关推荐
美狐美颜sdk1 小时前
跨平台直播美颜SDK集成实录:Android/iOS如何适配贴纸功能
android·人工智能·ios·架构·音视频·美颜sdk·第三方美颜sdk
DeepSeek-大模型系统教程1 小时前
推荐 7 个本周 yyds 的 GitHub 项目。
人工智能·ai·语言模型·大模型·github·ai大模型·大模型学习
有Li1 小时前
通过具有一致性嵌入的大语言模型实现端到端乳腺癌放射治疗计划制定|文献速递-最新论文分享
论文阅读·深度学习·分类·医学生
郭庆汝2 小时前
pytorch、torchvision与python版本对应关系
人工智能·pytorch·python
小雷FansUnion4 小时前
深入理解MCP架构:智能服务编排、上下文管理与动态路由实战
人工智能·架构·大模型·mcp
资讯分享周4 小时前
扣子空间PPT生产力升级:AI智能生成与多模态创作新时代
人工智能·powerpoint
叶子爱分享5 小时前
计算机视觉与图像处理的关系
图像处理·人工智能·计算机视觉
鱼摆摆拜拜5 小时前
第 3 章:神经网络如何学习
人工智能·神经网络·学习
一只鹿鹿鹿5 小时前
信息化项目验收,软件工程评审和检查表单
大数据·人工智能·后端·智慧城市·软件工程
张较瘦_5 小时前
[论文阅读] 人工智能 | 深度学习系统崩溃恢复新方案:DaiFu框架的原位修复技术
论文阅读·人工智能·深度学习