一、导入第三方库
python
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader
二、手写数字数据集准备

python
# Handwritten-digit dataset: loads grayscale images from a folder and
# parses each label out of the file name ("<prefix>_<index>_<label>.<ext>").
class MINISTDataset(Dataset):
    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Third underscore-separated field (with the extension stripped)
        # is the digit label encoded in the file name.
        self.labels = [int(name.split("_")[2].split(".")[0]) for name in files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.files[idx])
        image = Image.open(path).convert("L")  # force single-channel grayscale
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]
三、VAE模型的pytorch代码
python
# Encoder: maps a 1x28x28 image to the mean and log-variance of an
# 80-dimensional latent Gaussian.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # 28x28 -> conv(5) -> 24x24 -> maxpool(2) -> 12x12
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # 12x12 -> conv(5) -> 8x8 -> maxpool(2) -> 4x4, i.e. 20*4*4 = 320 features
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc1 = nn.Linear(320, 160)
        self.fc21 = nn.Linear(160, 80)  # latent mean
        self.fc22 = nn.Linear(160, 80)  # latent log-variance
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        flat = features.view(x.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc21(hidden), self.fc22(hidden)
# Decoder: maps an 80-dimensional latent vector back to a flattened
# 28*28 image with pixel values in (0, 1).
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(80, 160),
            nn.ReLU(),
            nn.Linear(160, 320),
            nn.ReLU(),
            nn.Linear(320, 28 * 28),
            nn.Sigmoid(),  # squash outputs into (0, 1) to match the BCE targets
        )

    def forward(self, z):
        return self.main(z)
# Variational autoencoder: composes an encoder and a decoder and samples
# the latent code via the reparameterization trick.
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) as mu + sigma * eps, eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * log_var)   # log-variance -> standard deviation
        noise = torch.randn_like(sigma)    # standard-normal noise sample
        return mu + noise * sigma

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var
四、主程序
python
if __name__ == "__main__":
    # Resize to 28x28 and scale pixels to [0, 1] (required by the BCE loss).
    # NOTE: named `transform` (not `transforms`) so the torchvision module
    # is not shadowed.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor()
    ])
    # Data locations.
    base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir = os.path.join(base_dir, "minist_train")
    gen_dir = os.path.join(base_dir, "vae_gen_img")
    os.makedirs(gen_dir, exist_ok=True)  # savefig fails if the folder is missing
    # Collect the image file names in the training folder.
    train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]
    # Dataset and loader.
    train_dataset = MINISTDataset(train_files, train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    # Hyper-parameters.
    num_epochs = 50
    lr = 0.001
    # Model, loss and optimizer.
    encoder = Encoder()
    decoder = Decoder()
    vae = VAE(encoder, decoder)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999))
    # Per-epoch average losses for the final plot.
    epoch_loss = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, _ in train_loader:
            optimizer.zero_grad()
            outputs, mu, logvar = vae(images)
            # Reconstruction term: pixel-wise BCE against the flattened input.
            reconstruction_loss = criterion(outputs, images.view(images.size(0), -1))
            # KL divergence to the standard-normal prior, down-weighted by 0.1.
            kl_divergence = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = reconstruction_loss + 0.1 * kl_divergence
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        epoch_loss.append(avg_loss)
        print("Epoch", epoch, " Loss:", avg_loss)
        # Every 5 epochs, sample 9 latent vectors and save the decoded digits.
        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                z = torch.randn(9, 80)
                plt.figure(figsize=(9, 9))
                for i in range(9):
                    plt.subplot(3, 3, i + 1)
                    plt.imshow(decoder(z[i]).view(28, 28), cmap="gray")
                    plt.axis("off")
                name = f"vae_gen_img_{epoch}.jpg"
                plt.savefig(os.path.join(gen_dir, name), dpi=300)
                plt.close()
    # Plot the training-loss curve.
    plt.figure(figsize=(12, 6))
    # A label is required, otherwise plt.legend() has no artists to show.
    plt.plot(epoch_loss, color="tomato", label="train loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(base_dir, "vae_gen_loss.jpg"))
    plt.close()
五、运行结果
5.1 损失函数曲线图

5.2 生成的图像
这里只展示一部分

vae_gen_img_4.jpg

vae_gen_img_29.jpg

vae_gen_img_49.jpg
六、VAE的完整代码
python
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader
# Handwritten-digit dataset: loads grayscale images from a folder and
# parses each label out of the file name ("<prefix>_<index>_<label>.<ext>").
class MINISTDataset(Dataset):
    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Third underscore-separated field (with the extension stripped)
        # is the digit label encoded in the file name.
        self.labels = [int(name.split("_")[2].split(".")[0]) for name in files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.files[idx])
        image = Image.open(path).convert("L")  # force single-channel grayscale
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]
# Encoder: maps a 1x28x28 image to the mean and log-variance of an
# 80-dimensional latent Gaussian.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # 28x28 -> conv(5) -> 24x24 -> maxpool(2) -> 12x12
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # 12x12 -> conv(5) -> 8x8 -> maxpool(2) -> 4x4, i.e. 20*4*4 = 320 features
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc1 = nn.Linear(320, 160)
        self.fc21 = nn.Linear(160, 80)  # latent mean
        self.fc22 = nn.Linear(160, 80)  # latent log-variance
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        flat = features.view(x.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc21(hidden), self.fc22(hidden)
# Decoder: maps an 80-dimensional latent vector back to a flattened
# 28*28 image with pixel values in (0, 1).
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(80, 160),
            nn.ReLU(),
            nn.Linear(160, 320),
            nn.ReLU(),
            nn.Linear(320, 28 * 28),
            nn.Sigmoid(),  # squash outputs into (0, 1) to match the BCE targets
        )

    def forward(self, z):
        return self.main(z)
# Variational autoencoder: composes an encoder and a decoder and samples
# the latent code via the reparameterization trick.
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) as mu + sigma * eps, eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * log_var)   # log-variance -> standard deviation
        noise = torch.randn_like(sigma)    # standard-normal noise sample
        return mu + noise * sigma

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var
if __name__ == "__main__":
    # Resize to 28x28 and scale pixels to [0, 1] (required by the BCE loss).
    # NOTE: named `transform` (not `transforms`) so the torchvision module
    # is not shadowed.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor()
    ])
    # Data locations.
    base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir = os.path.join(base_dir, "minist_train")
    gen_dir = os.path.join(base_dir, "vae_gen_img")
    os.makedirs(gen_dir, exist_ok=True)  # savefig fails if the folder is missing
    # Collect the image file names in the training folder.
    train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]
    # Dataset and loader.
    train_dataset = MINISTDataset(train_files, train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    # Hyper-parameters.
    num_epochs = 50
    lr = 0.001
    # Model, loss and optimizer.
    encoder = Encoder()
    decoder = Decoder()
    vae = VAE(encoder, decoder)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999))
    # Per-epoch average losses for the final plot.
    epoch_loss = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, _ in train_loader:
            optimizer.zero_grad()
            outputs, mu, logvar = vae(images)
            # Reconstruction term: pixel-wise BCE against the flattened input.
            reconstruction_loss = criterion(outputs, images.view(images.size(0), -1))
            # KL divergence to the standard-normal prior, down-weighted by 0.1.
            kl_divergence = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = reconstruction_loss + 0.1 * kl_divergence
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        epoch_loss.append(avg_loss)
        print("Epoch", epoch, " Loss:", avg_loss)
        # Every 5 epochs, sample 9 latent vectors and save the decoded digits.
        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                z = torch.randn(9, 80)
                plt.figure(figsize=(9, 9))
                for i in range(9):
                    plt.subplot(3, 3, i + 1)
                    plt.imshow(decoder(z[i]).view(28, 28), cmap="gray")
                    plt.axis("off")
                name = f"vae_gen_img_{epoch}.jpg"
                plt.savefig(os.path.join(gen_dir, name), dpi=300)
                plt.close()
    # Plot the training-loss curve.
    plt.figure(figsize=(12, 6))
    # A label is required, otherwise plt.legend() has no artists to show.
    plt.plot(epoch_loss, color="tomato", label="train loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(base_dir, "vae_gen_loss.jpg"))
    plt.close()