变分自编码器VAE的Pytorch实现

一、导入第三方库

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

二、手写数字数据集准备

python 复制代码
#手写数字数据集
class MINISTDataset(Dataset):
    def __init__(self,files,root_dir,transform=None):
        self.files=files
        self.root_dir=root_dir
        self.transform=transform
        self.labels=[]
        for f in files:
            parts=f.split("_")
            p=parts[2].split(".")[0]
            self.labels.append(int(p))

    def __len__(self):
        return len(self.files)

    def __getitem__(self,idx):
        img_path=os.path.join(self.root_dir,self.files[idx])
        img=Image.open(img_path).convert("L")

        if self.transform:
            img=self.transform(img)

        label=self.labels[idx]
        return img,label

三、VAE模型的pytorch代码

python 复制代码
#编码器
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(1,10,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.conv2=nn.Sequential(
            nn.Conv2d(10,20,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc1=nn.Linear(320,160)
        self.fc21=nn.Linear(160,80)  #均值
        self.fc22=nn.Linear(160,80)  #方差
        self.relu=nn.ReLU()

    def forward(self,x):
        batch_size=x.size(0)
        x=self.conv1(x)
        x=self.conv2(x)
        x=x.view(batch_size,-1)
        h=self.relu(self.fc1(x))
        mu=self.fc21(h)
        log_var=self.fc22(h)
        return mu,log_var

#解码器
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main=nn.Sequential(
            nn.Linear(80,160),
            nn.ReLU(),
            nn.Linear(160,320),
            nn.ReLU(),
            nn.Linear(320,28*28),
            nn.Sigmoid()
        )

    def forward(self,z):
        return self.main(z)

#变分自编码器
class VAE(nn.Module):
    def __init__(self,encoder,decoder):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    #重参数化
    def reparameterize(self,mu,log_var):
        std=torch.exp(0.5*log_var)  #计算标准差
        eps=torch.randn_like(std)   #从标准正态分布中采样噪声
        z=mu+eps*std  #重参数化
        return z

    def forward(self,x):
        mu,log_var=self.encoder(x)
        z=self.reparameterize(mu,log_var)
        return self.decoder(z),mu,log_var

四、主程序

python 复制代码
if __name__=="__main__":

    #对数据做归一化处理
    transforms=transforms.Compose([
        transforms.Resize((28,28)),
        transforms.ToTensor()
    ])

    #路径
    base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir=os.path.join(base_dir,"minist_train")

    #获取文件夹里图像的名称
    train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    #创建数据集和数据加载器
    train_dataset=MINISTDataset(train_files,train_dir,transform=transforms)
    train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)

    #参数
    num_epochs=50
    lr=0.001

    #模型初始化
    encoder=Encoder()
    decoder=Decoder()
    vae=VAE(encoder,decoder)
    criterion=nn.BCELoss()
    optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))

    #记录损失函数值
    epoch_loss=[]

    for epoch in range(num_epochs):
        total_loss=0.0

        for data in train_loader:
            images,_=data
            #images=images.view(images.size(0),-1)

            optimizer.zero_grad()

            outputs,mu,logvar=vae(images)

            #计算重构损失和KL散度
            reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))
            kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())

            loss=reconstruction_loss+0.1*kl_divergence

            loss.backward()
            optimizer.step()

            total_loss+=loss.item()

        avg_loss=total_loss/len(train_loader)
        epoch_loss.append(avg_loss)

        print("Epoch",epoch,"  Loss:",avg_loss)

        #生成新图像
        with torch.no_grad():
            if (epoch+1)%5==0:
                z=torch.randn(9,80)
                plt.figure(figsize=(9,9))
                for i in range(9):
                    plt.subplot(3,3,i+1)
                    plt.imshow(decoder(z[i]).view(28,28),cmap="gray")
                    plt.axis("off")
                name=f"vae_gen_img_{epoch}.jpg"
                gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)
                plt.savefig(gen_name,dpi=300)
                plt.close()

    #绘制损失函数曲线图
    plt.figure(figsize=(12,6))
    plt.plot(epoch_loss,color="tomato")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")
    plt.close()

五、运行结果

5.1 损失函数曲线图

5.2 生成的图像

这里只展示一部分

vae_gen_img_4.jpg

vae_gen_img_29.jpg

vae_gen_img_49.jpg

六、VAE的完整代码

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

#手写数字数据集
class MINISTDataset(Dataset):
    def __init__(self,files,root_dir,transform=None):
        self.files=files
        self.root_dir=root_dir
        self.transform=transform
        self.labels=[]
        for f in files:
            parts=f.split("_")
            p=parts[2].split(".")[0]
            self.labels.append(int(p))

    def __len__(self):
        return len(self.files)

    def __getitem__(self,idx):
        img_path=os.path.join(self.root_dir,self.files[idx])
        img=Image.open(img_path).convert("L")

        if self.transform:
            img=self.transform(img)

        label=self.labels[idx]
        return img,label

#编码器
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(1,10,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.conv2=nn.Sequential(
            nn.Conv2d(10,20,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc1=nn.Linear(320,160)
        self.fc21=nn.Linear(160,80)  #均值
        self.fc22=nn.Linear(160,80)  #方差
        self.relu=nn.ReLU()

    def forward(self,x):
        batch_size=x.size(0)
        x=self.conv1(x)
        x=self.conv2(x)
        x=x.view(batch_size,-1)
        h=self.relu(self.fc1(x))
        mu=self.fc21(h)
        log_var=self.fc22(h)
        return mu,log_var

#解码器
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main=nn.Sequential(
            nn.Linear(80,160),
            nn.ReLU(),
            nn.Linear(160,320),
            nn.ReLU(),
            nn.Linear(320,28*28),
            nn.Sigmoid()
        )

    def forward(self,z):
        return self.main(z)

#变分自编码器
class VAE(nn.Module):
    def __init__(self,encoder,decoder):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    #重参数化
    def reparameterize(self,mu,log_var):
        std=torch.exp(0.5*log_var)  #计算标准差
        eps=torch.randn_like(std)   #从标准正态分布中采样噪声
        z=mu+eps*std  #重参数化
        return z

    def forward(self,x):
        mu,log_var=self.encoder(x)
        z=self.reparameterize(mu,log_var)
        return self.decoder(z),mu,log_var

if __name__=="__main__":

    #对数据做归一化处理
    transforms=transforms.Compose([
        transforms.Resize((28,28)),
        transforms.ToTensor()
    ])

    #路径
    base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir=os.path.join(base_dir,"minist_train")

    #获取文件夹里图像的名称
    train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    #创建数据集和数据加载器
    train_dataset=MINISTDataset(train_files,train_dir,transform=transforms)
    train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)

    #参数
    num_epochs=50
    lr=0.001

    #模型初始化
    encoder=Encoder()
    decoder=Decoder()
    vae=VAE(encoder,decoder)
    criterion=nn.BCELoss()
    optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))

    #记录损失函数值
    epoch_loss=[]

    for epoch in range(num_epochs):
        total_loss=0.0

        for data in train_loader:
            images,_=data
            #images=images.view(images.size(0),-1)

            optimizer.zero_grad()

            outputs,mu,logvar=vae(images)

            #计算重构损失和KL散度
            reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))
            kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())

            loss=reconstruction_loss+0.1*kl_divergence

            loss.backward()
            optimizer.step()

            total_loss+=loss.item()

        avg_loss=total_loss/len(train_loader)
        epoch_loss.append(avg_loss)

        print("Epoch",epoch,"  Loss:",avg_loss)

        #生成新图像
        with torch.no_grad():
            if (epoch+1)%5==0:
                z=torch.randn(9,80)
                plt.figure(figsize=(9,9))
                for i in range(9):
                    plt.subplot(3,3,i+1)
                    plt.imshow(decoder(z[i]).view(28,28),cmap="gray")
                    plt.axis("off")
                name=f"vae_gen_img_{epoch}.jpg"
                gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)
                plt.savefig(gen_name,dpi=300)
                plt.close()

    #绘制损失函数曲线图
    plt.figure(figsize=(12,6))
    plt.plot(epoch_loss,color="tomato")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")
    plt.close()
相关推荐
玉木子25 分钟前
机器学习(六)朴素贝叶斯分类
开发语言·人工智能·python·算法·机器学习·分类
Hi202402172 小时前
使用 darkSCNN 和 Caffe 进行车道线检测
人工智能·深度学习·opencv·自动驾驶·caffe·车道线检测
Dxy12393102162 小时前
Python如何处理非标准JSON
开发语言·python·json
q567315232 小时前
从开发到部署深度解析Go与Python爬虫利弊
爬虫·python·golang
996终结者3 小时前
Python数据分析与处理(二):将数据写回.mat文件的不同方法【超详细】
python·matlab·数据分析
MediaTea4 小时前
Python:正则表达式
开发语言·c++·python·正则表达式
siliconstorm.ai4 小时前
开源与闭源的再对决:从Grok到中国力量,AI生态走向何方?
大数据·图像处理·人工智能·语言模型·ai作画·云计算·机器翻译
zhong liu bin6 小时前
maven【maven】技术详解
java·ide·python·spring·maven·intellij-idea
IAM四十二6 小时前
基于 Embedding 实现一个本地相册搜索功能
人工智能·python·llm