变分自编码器VAE的Pytorch实现

一、导入第三方库

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

二、手写数字数据集准备

python 复制代码
#手写数字数据集
class MINISTDataset(Dataset):
    def __init__(self,files,root_dir,transform=None):
        self.files=files
        self.root_dir=root_dir
        self.transform=transform
        self.labels=[]
        for f in files:
            parts=f.split("_")
            p=parts[2].split(".")[0]
            self.labels.append(int(p))

    def __len__(self):
        return len(self.files)

    def __getitem__(self,idx):
        img_path=os.path.join(self.root_dir,self.files[idx])
        img=Image.open(img_path).convert("L")

        if self.transform:
            img=self.transform(img)

        label=self.labels[idx]
        return img,label

三、VAE模型的pytorch代码

python 复制代码
#编码器
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(1,10,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.conv2=nn.Sequential(
            nn.Conv2d(10,20,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc1=nn.Linear(320,160)
        self.fc21=nn.Linear(160,80)  #均值
        self.fc22=nn.Linear(160,80)  #方差
        self.relu=nn.ReLU()

    def forward(self,x):
        batch_size=x.size(0)
        x=self.conv1(x)
        x=self.conv2(x)
        x=x.view(batch_size,-1)
        h=self.relu(self.fc1(x))
        mu=self.fc21(h)
        log_var=self.fc22(h)
        return mu,log_var

#解码器
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main=nn.Sequential(
            nn.Linear(80,160),
            nn.ReLU(),
            nn.Linear(160,320),
            nn.ReLU(),
            nn.Linear(320,28*28),
            nn.Sigmoid()
        )

    def forward(self,z):
        return self.main(z)

#变分自编码器
class VAE(nn.Module):
    def __init__(self,encoder,decoder):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    #重参数化
    def reparameterize(self,mu,log_var):
        std=torch.exp(0.5*log_var)  #计算标准差
        eps=torch.randn_like(std)   #从标准正态分布中采样噪声
        z=mu+eps*std  #重参数化
        return z

    def forward(self,x):
        mu,log_var=self.encoder(x)
        z=self.reparameterize(mu,log_var)
        return self.decoder(z),mu,log_var

四、主程序

python 复制代码
if __name__=="__main__":

    #对数据做归一化处理
    transforms=transforms.Compose([
        transforms.Resize((28,28)),
        transforms.ToTensor()
    ])

    #路径
    base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir=os.path.join(base_dir,"minist_train")

    #获取文件夹里图像的名称
    train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    #创建数据集和数据加载器
    train_dataset=MINISTDataset(train_files,train_dir,transform=transforms)
    train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)

    #参数
    num_epochs=50
    lr=0.001

    #模型初始化
    encoder=Encoder()
    decoder=Decoder()
    vae=VAE(encoder,decoder)
    criterion=nn.BCELoss()
    optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))

    #记录损失函数值
    epoch_loss=[]

    for epoch in range(num_epochs):
        total_loss=0.0

        for data in train_loader:
            images,_=data
            #images=images.view(images.size(0),-1)

            optimizer.zero_grad()

            outputs,mu,logvar=vae(images)

            #计算重构损失和KL散度
            reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))
            kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())

            loss=reconstruction_loss+0.1*kl_divergence

            loss.backward()
            optimizer.step()

            total_loss+=loss.item()

        avg_loss=total_loss/len(train_loader)
        epoch_loss.append(avg_loss)

        print("Epoch",epoch,"  Loss:",avg_loss)

        #生成新图像
        with torch.no_grad():
            if (epoch+1)%5==0:
                z=torch.randn(9,80)
                plt.figure(figsize=(9,9))
                for i in range(9):
                    plt.subplot(3,3,i+1)
                    plt.imshow(decoder(z[i]).view(28,28),cmap="gray")
                    plt.axis("off")
                name=f"vae_gen_img_{epoch}.jpg"
                gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)
                plt.savefig(gen_name,dpi=300)
                plt.close()

    #绘制损失函数曲线图
    plt.figure(figsize=(12,6))
    plt.plot(epoch_loss,color="tomato")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")
    plt.close()

五、运行结果

5.1 损失函数曲线图

5.2 生成的图像

这里只展示一部分

vae_gen_img_4.jpg

vae_gen_img_29.jpg

vae_gen_img_49.jpg

六、VAE的完整代码

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

#手写数字数据集
class MINISTDataset(Dataset):
    def __init__(self,files,root_dir,transform=None):
        self.files=files
        self.root_dir=root_dir
        self.transform=transform
        self.labels=[]
        for f in files:
            parts=f.split("_")
            p=parts[2].split(".")[0]
            self.labels.append(int(p))

    def __len__(self):
        return len(self.files)

    def __getitem__(self,idx):
        img_path=os.path.join(self.root_dir,self.files[idx])
        img=Image.open(img_path).convert("L")

        if self.transform:
            img=self.transform(img)

        label=self.labels[idx]
        return img,label

#编码器
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(1,10,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.conv2=nn.Sequential(
            nn.Conv2d(10,20,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc1=nn.Linear(320,160)
        self.fc21=nn.Linear(160,80)  #均值
        self.fc22=nn.Linear(160,80)  #方差
        self.relu=nn.ReLU()

    def forward(self,x):
        batch_size=x.size(0)
        x=self.conv1(x)
        x=self.conv2(x)
        x=x.view(batch_size,-1)
        h=self.relu(self.fc1(x))
        mu=self.fc21(h)
        log_var=self.fc22(h)
        return mu,log_var

#解码器
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main=nn.Sequential(
            nn.Linear(80,160),
            nn.ReLU(),
            nn.Linear(160,320),
            nn.ReLU(),
            nn.Linear(320,28*28),
            nn.Sigmoid()
        )

    def forward(self,z):
        return self.main(z)

#变分自编码器
class VAE(nn.Module):
    def __init__(self,encoder,decoder):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    #重参数化
    def reparameterize(self,mu,log_var):
        std=torch.exp(0.5*log_var)  #计算标准差
        eps=torch.randn_like(std)   #从标准正态分布中采样噪声
        z=mu+eps*std  #重参数化
        return z

    def forward(self,x):
        mu,log_var=self.encoder(x)
        z=self.reparameterize(mu,log_var)
        return self.decoder(z),mu,log_var

if __name__=="__main__":

    #对数据做归一化处理
    transforms=transforms.Compose([
        transforms.Resize((28,28)),
        transforms.ToTensor()
    ])

    #路径
    base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir=os.path.join(base_dir,"minist_train")

    #获取文件夹里图像的名称
    train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    #创建数据集和数据加载器
    train_dataset=MINISTDataset(train_files,train_dir,transform=transforms)
    train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)

    #参数
    num_epochs=50
    lr=0.001

    #模型初始化
    encoder=Encoder()
    decoder=Decoder()
    vae=VAE(encoder,decoder)
    criterion=nn.BCELoss()
    optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))

    #记录损失函数值
    epoch_loss=[]

    for epoch in range(num_epochs):
        total_loss=0.0

        for data in train_loader:
            images,_=data
            #images=images.view(images.size(0),-1)

            optimizer.zero_grad()

            outputs,mu,logvar=vae(images)

            #计算重构损失和KL散度
            reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))
            kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())

            loss=reconstruction_loss+0.1*kl_divergence

            loss.backward()
            optimizer.step()

            total_loss+=loss.item()

        avg_loss=total_loss/len(train_loader)
        epoch_loss.append(avg_loss)

        print("Epoch",epoch,"  Loss:",avg_loss)

        #生成新图像
        with torch.no_grad():
            if (epoch+1)%5==0:
                z=torch.randn(9,80)
                plt.figure(figsize=(9,9))
                for i in range(9):
                    plt.subplot(3,3,i+1)
                    plt.imshow(decoder(z[i]).view(28,28),cmap="gray")
                    plt.axis("off")
                name=f"vae_gen_img_{epoch}.jpg"
                gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)
                plt.savefig(gen_name,dpi=300)
                plt.close()

    #绘制损失函数曲线图
    plt.figure(figsize=(12,6))
    plt.plot(epoch_loss,color="tomato")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")
    plt.close()
相关推荐
MediaTea14 分钟前
Python 第三方库:Requests(HTTP 客户端)
开发语言·网络·python·网络协议·http
AI大法师20 分钟前
Python:PyQt5 全栈开发教程,构建跨平台桌面应用
python·pyqt
华科云商xiao徐25 分钟前
分布式爬虫双核引擎:Java大脑+Python触手的完美协同
java·爬虫·python
LLM精进之路1 小时前
RCL 2025 | LLM采样机制的新视角:来自处方性偏移的解释
人工智能·深度学习·机器学习·语言模型·transformer
计算机毕业设计木哥1 小时前
计算机毕设大数据选题推荐 基于spark+Hadoop+python的贵州茅台股票数据分析系统【源码+文档+调试】
大数据·hadoop·python·计算机网络·spark·课程设计
Re_draw_debubu1 小时前
torchvision中数据集的使用与DataLoader 小土堆pytorch记录
pytorch·python·小土堆
猫先生OVO1 小时前
shellgpt
python
数据智能老司机1 小时前
GPU 编程实战——使用 PyCUDA 与 CuPy 功能
人工智能·python·gpu
Cl_rown去掉l变成C2 小时前
第R5周:天气预测
人工智能·python·深度学习·算法·tensorflow2