人脸图像生成(DCGAN)

深度卷积对抗网络(Deep Convolutional Generative Adversarial Networks)是一种深度学习模型,由生成器(Generator)和判别器(Discriminator)两个神经网络组成。DCGAN 结合了卷积神经网络和生成对抗网络的思想,用于生成逼真的图像。

一. 理论基础

1.DCGAN原理

深度卷积对抗网络是生成对抗网络的一种模型改进,其将卷积运算的思想引入到生成式模型当中来做无监督的训练,利用卷积网络强大的特征提取能力来提高生成网络的学习效果。DCGAN模型有以下特点:

  • 判别器模型使用了卷积步长取代了空间池化,生成器模型中使用了反卷积操作扩大数据维度。
  • 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其他层上都使用了Batch Normalization, 原因是Batch Normalization 可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  • 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  • 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其他层中均使用了ReLU激活函数;在判别器上使用了Leaky ReLU激活函数。

图中所示了一种常见的DCGAN结构。主要包含了一个生成网络G 和一个判别网络 D,生成网络G 负责生成图像,它接受一个随机的噪声z,通过该噪声生成图像,将生成的图像记为G(z),判别网络D 负责判断一张图是否为真实,它的输入是x,代表一张图像,输出D(x)表示x为真实图像的概率。

实际上判别网络D是对数据的来源进行一个判别:究竟这个数据是来自真是的数据分布Pd(x)判别为"1",还是来自于一个生成网络G所产生的一个数据分布Pg(z)(判别为"0")。所以在整个训练过程中,生成网络G的目标是生成可以以假乱真的图像G(z),当判别网络D无法区分,即D(G(z))=0.5时,便得到了一个生成网络G用来生产图像扩充数据集。

二.前期准备

1.导入第三方库

python 复制代码
import torch,random,os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torch.autograd import Variable

manualSeed = 999
print("random seed:",manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True)

2.设置超参数

python 复制代码
dataroot = "/content/drive/MyDrive/GAN_Dataset"
batch_size = 128 #训练过程中的批次大小
image_size = 64 #图像的尺寸(宽度和高度)
nz = 100 # z潜在的向量大小(生成器输入的尺寸)
ngf = 64 # 生成器中的特征图大小
ndf = 64 #判别器中的特征图大小
num_epochs = 50 #训练的总论数
lr = 0.0002 #学习率
beta1=0.5 #adam 优化器的beta1超参数

3.导入数据

python 复制代码
dataset = dset.ImageFolder(root=dataroot,transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]))
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=5)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],padding=2,normalize=True).cpu(),(1,2,0)))

三.定义模型

1.初始化权重

python 复制代码
def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    nn.init.normal_(m.weight.data,0.0,0.02)
  elif classname.find('BatchNorm')!=-1:
    nn.init.normal_(m.weight.data,1.0,0.02)
    nn.init.constant_(m.bias.data,0)

2.定义生成器

python 复制代码
class Generator(nn.Module):
  def __init__(self):
    super(Generator,self).__init__()
    self.main = nn.Sequential(
        nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),
        nn.BatchNorm2d(ngf*8),
        nn.ReLU(True),
        #输出尺寸:(ngf*8)x4x4
        nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),
        nn.BatchNorm2d(ngf*4),
        nn.ReLU(True),
        #输出尺寸:(ngf*4)x8x8
        nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
        nn.BatchNorm2d(ngf*2),
        nn.ReLU(True),
        #输出尺寸:(ngf*2)x16x16
        nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(True),
        #输出尺寸:(ngf)x32x32
        nn.ConvTranspose2d(ngf,3,4,2,1,bias=False),
        nn.Tanh()
        #输出尺寸:3x64x64
    )
  
  def forward(self,input):
    return self.main(input)
python 复制代码
#创建生成器
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)

3.定义鉴别器

python 复制代码
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator,self).__init__()
    self.main = nn.Sequential(
        nn.Conv2d(3,ndf,4,2,1,bias=False),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf)x32x32
        nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*2),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf*2)x16x16
        nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*4),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf*4)x8x8
        nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*8),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf*8)x4x4
        nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        nn.Sigmoid()
    )

  def forward(self,input):
    return self.main(input)
python 复制代码
#创建判别器模型
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)

四:训练模型

1.定义训练参数

python 复制代码
criterion = nn.BCELoss()
fixed_noise = torch.randn(64,nz,1,1,device = device)

real_label =1.
fake_label =0.

optimizerD = optim.Adam(netD.parameters(),lr=lr,betas=(beta1,0.999))
optimizerG = optim.Adam(netG.parameters(),lr=lr,betas=(beta1,0.999))

2.训练模型

下面的训练代码是一个典型的GAN训练循环。在训练过程中,首先更新判别器网络,然后更新生成器网络。在每个epoch的每个batch中,会进行以下操作:

  • 更新判别器网络:通过训练真实图像样本和生成图像样本,最大化判别器的损失。具体步骤如下:

    • 对于真实图像样本,计算判别器对真实图像样本的输出和真实标签之间的损失,然后进行反向传播计算梯度。
    • 对于生成的图像样本,计算判别器对生成图像样本的输出和假标签之间的损失,然后进行反向传播计算梯度。
    • 将真实图像样本的损失和生成图像样本的损失相加得到判别器的总损失,并更新判别器的参数。
  • 更新生成器网络:通过最大化生成器的损失,迫使生成器产生更逼真的图像样本。具体步骤如下:

    • 使用生成器生成一批假图像样本。
    • 将生成图像样本输入判别器,计算判别器对生成图像样本的输出和真实标签之间的损失,并进行反向传播计算生成器的梯度。
    • 更新生成器的参数。
  • 输出训练统计信息:每隔一定的步数,输出当前训练的epoch、batch以及判别器和生成器的损失值等信息。

  • 保存损失值:将生成器和判别器的损失值存储到相应的列表中,以便后续绘图和分析。

  • 检查生成器的性能:每隔一定的步数或者在训练结束时,通过将固定的噪声输入生成器,生成一批图像样本,并保存到img_list列表中。这样可以观察生成器在训练过程中生成的图像质量的变化。

  • 更新迭代次数:每完成一个batch的训练,将迭代次数iters加1。

总体来说,这段代码实现了GAN的训练过程,通过交替更新判别器和生成器的参数,目标是使生成器生成逼真的图像样本,同时判别器能够准确区分真实图像样本和生成图像样本。

python 复制代码
img_list =[]
G_losses=[]
D_losses=[]
iters=0
print("start training")

for epoch in range(num_epochs):
  for i,data in enumerate(dataloader,0):
    netD.zero_grad()
    real_cpu = data[0].to(device)
    b_size = real_cpu.size(0)
    label = torch.full((b_size,),real_label,dtype=torch.float,device=device)

    output = netD(real_cpu).view(-1)
    errD_real = criterion(output,label)
    errD_real.backward()
    D_x = output.mean().item()

    #使用生成图像样本训练
    noise = torch.randn(b_size,nz,1,1,device=device)
    fake = netG(noise)
    label.fill_(fake_label)
    output = netD(fake.detach()).view(-1)
    errD_fake = criterion(output,label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    errD = errD_real + errD_fake
    optimizerD.step()

    #更新生成器网络
    netG.zero_grad()
    label.fill_(real_label)
    output = netD(fake).view(-1)
    errG = criterion(output,label)
    errG.backward()
    D_G_z2 = output.mean().item()
    optimizerG.step()

    if i % 400 == 0:
      print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, i, len(dataloader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
    
    G_losses.append(errG.item())
    D_losses.append(errD.item())

    if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
      with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
      img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
    iters += 1

3.可视化

python 复制代码
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
python 复制代码
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
python 复制代码
real_batch = next(iter(dataloader))
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
相关推荐
归去_来兮20 小时前
生成式对抗网络(GAN)模型原理概述
人工智能·深度学习·生成对抗网络
19894 天前
【零基础学AI】第30讲:生成对抗网络(GAN)实战 - 手写数字生成
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·近邻算法
盼小辉丶8 天前
PyTorch实战(14)——条件生成对抗网络(conditional GAN,cGAN)
人工智能·pytorch·生成对抗网络
白熊18823 天前
【深度学习】生成对抗网络(GANs)深度解析:从理论到实践的革命性生成模型
人工智能·深度学习·生成对抗网络
这张生成的图像能检测吗1 个月前
生成对抗网络(GANs)入门介绍指南:让AI学会“创造“的魔法(二)【深入版】
人工智能·pytorch·深度学习·神经网络·算法·生成对抗网络·计算机视觉
ONEYAC唯样1 个月前
英飞凌亮相SEMICON China 2025:以SiC、GaN技术引领低碳化与数字化未来
人工智能·神经网络·生成对抗网络
啊哈哈哈哈哈啊哈哈1 个月前
G1周打卡——GAN入门
pytorch·深度学习·生成对抗网络
ICscholar1 个月前
生成对抗网络(GAN)损失函数解读
人工智能·机器学习·生成对抗网络
QQ676580081 个月前
基于 TensorFlow 2 的 WGAN来生成表格数据、数值数据和序列数据。 WGAN生成对抗网络。代码仅供参考
生成对抗网络·tensorflow·neo4j·表格数据·wgan·对抗网络·序列数据