5 分钟内构建一个简单的基于 Python 的 GAN

文章目录

一、说明

生成对抗网络(GAN)因其能力而在学术界引起轩然大波。机器能够创作出新颖、富有灵感的作品,这让每个人都感到敬畏和恐惧。因此,人们开始好奇,如何构建一个这样的网络?

生成对抗网络 (GAN) 是一种深度学习模型,可生成与某些输入数据相似的新合成数据。GAN 由两个神经网络组成:生成器和鉴别器。生成器经过训练可生成与输入数据相同的合成数据,而鉴别器经过训练可区分合成数据和真实数据。

生成模型学习输入数据 f (x)的内在分布函数,使其能够生成合成输入x'和输出y',通常给定一些隐藏参数。GAN 的优势在于它们能够生成最清晰的图像,并且易于训练。

二、代码

此代码会训练 GAN 一定数量的周期,其中周期定义为对整个数据集的一次遍历。在每个周期中,代码会迭代数据加载器(应该是包装数据集的 PyTorch DataLoader 对象)中的数据,并在每个批次上训练鉴别器和生成器。

生成器的训练方式是试图欺骗鉴别器,而鉴别器则被训练来区分真实图像和假图像。这里使用的损失函数是二元交叉熵损失,这是 GAN 的常见选择。使用的优化器是 Adam,它是一种随机梯度下降优化器。

首先,导入必要的库并定义生成器和鉴别器模型。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

生成器应该是一个神经网络,它接受随机噪声向量并生成合成数据。同时,鉴别器应该是一个神经网络,它接受真实数据或合成数据并输出输入数据为真实的概率。

类 生成器(nn.Module):

python 复制代码
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return x
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x
  1. 在下面的代码块中,我们设置了 GAN 的环境。这包括:

设置鉴别器和生成器网络的输入层、隐藏层和输出层的大小。

创建 Generator 和 Discriminator 类的实例

设置损失函数和优化器

python 复制代码
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the input and output sizes
input_size = 784
hidden_size = 256
output_size = 1

# Create the discriminator and generator
discriminator = Discriminator(input_size, hidden_size, output_size).to(device)
generator = Generator(input_size, hidden_size, output_size).to(device)

# Set the loss function and optimizers
loss_fn = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)

# Set the number of epochs and the noise size
num_epochs = 200
noise_size = 100

# Training loop
for epoch in range(num_epochs):
  for i, (real_images, _) in enumerate(dataloader):
    # Get the batch size
    batch_size = real_images.size(0)

三、训练

  1. 在下面的代码中,生成器通过尝试欺骗鉴别器来训练,而鉴别器经过训练可以区分真假图像。为此,

我们给生成器一批噪声样本作为输入,并生成一批假图像。然后这些假图像通过鉴别器,鉴别器对批次中的每幅图像产生预测。

然后计算生成器的损失,代码通过生成器反向传播损失,并使用 Adam 优化器优化生成器的参数。此过程会以减少损失和提高生成器欺骗鉴别器的能力的方向更新生成器的参数。

python 复制代码
 # Generate fake images
  noise = torch.randn(batch_size, noise_size).to(device)
  fake_images = generator(noise)
  
  # Train the discriminator on real and fake images
  d_real = discriminator(real_images)
  d_fake = discriminator(fake_images)
  
  # Calculate the loss
  real_loss = loss_fn(d_real, torch.ones_like(d_real))
  fake_loss = loss_fn(d_fake, torch.zeros_like(d_fake))
  d_loss = real_loss + fake_loss
  
  # Backpropagate and optimize
  d_optimizer.zero_grad()
  d_loss.backward()
  d_optimizer.step()
  
  # Train the generator
  d_fake = discriminator(fake_images)
  g_loss = loss_fn(d_fake, torch.ones_like(d_fake))
  
  # Backpropagate and optimize
  g_optimizer.zero_grad()
  g_loss.backward()
  g_optimizer.step()
  
  # Print the loss every 50 batches
  if (i+1) % 50 == 0:
    print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}' 
          .format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))

就这样......一个可以快速使用的 GAN 模型就完成了。

四、后记

关于成对抗网络(GAN)由两部分组成:

  • 生成器学习生成可信的数据。生成的实例将成为鉴别器的反面训练示例。
  • 鉴别器学会区分生成器的虚假数据和真实数据。鉴别器会惩罚产生不合理结果的生成器。
    当训练开始时,生成器会生成明显是假的数据,而鉴别器很快就能分辨出这是假的。
    更多的阐述将在本系列文章中展现。
相关推荐
q567315233 分钟前
Python/Django 服务器升级脚本
服务器·开发语言·python·游戏·django
hakesashou13 分钟前
python怎么样将一段程序无效掉
python
Michael Lee.20 分钟前
Python学习篇:Python基础知识(三)
开发语言·python·学习·pycharm
zhangbin_23720 分钟前
【Python机器学习】处理文本数据——将文本数据表示为词袋
人工智能·python·算法·机器学习·分类
martian66522 分钟前
学懂C#编程:属性(Property)的概念定义及使用详解
java·开发语言·c#·属性·property
孑渡24 分钟前
【LeetCode】每日一题:跳跃游戏
python·算法·leetcode·游戏·职场和发展
逸群不凡31 分钟前
C++11|lambda语法与使用
开发语言·c++
dc爱傲雪和技术36 分钟前
卡尔曼滤波Q和R怎么调
python·算法·r语言
DieSnowK39 分钟前
[C++][CMake][CMake基础]详细讲解
开发语言·c++·makefile·make·cmake·新手向·详细讲解
时间瑾42 分钟前
线程池实践篇
java·开发语言