【生成对抗网络】

生成对抗网络 (Generative Adversarial Networks,简称GANs)是深度学习领域的一种创新结构,由Ian Goodfellow在2014年首次提出。GANs包括两个深度神经网络------一个生成器和一个判别器,它们通常以对抗的方式进行训练。

以下是GANs的基本概念:
生成器 (Generator): 生成器的任务是从随机噪声中生成尽可能真实的数据。换句话说,生成器尝试产生与真实数据集非常相似的新数据。
判别器 (Discriminator): 判别器的任务是区分其接收到的数据是来自真实数据集还是生成器。换句话说,它尝试判断输入数据是真实的还是假的。
训练过程 : 在GANs的训练过程中,生成器和判别器在一个对抗游戏中进行竞争。生成器努力产生越来越真实的数据,以欺骗判别器,而判别器努力变得越来越好地识别真实与生成的数据。这种对抗过程导致生成器产生高质量的数据。
损失函数 : GANs通常使用两个损失函数来评估生成器和判别器的性能。判别器的损失函数评估其正确区分真实数据和生成数据的能力,而生成器的损失函数评估其生成数据欺骗判别器的能力。

GANs在许多应用中都非常有用,如图像生成、图像超分辨率、风格迁移、数据增强和更多其他应用。然而,它们也有一些挑战,例如训练不稳定性、模式崩溃等问题,但随着研究的深入,已经发展了许多变种和技术来解决这些问题。

使用Pytorch在MNIST数据集上写了一个最简单的生成对抗网络,整个训练过程的目的是使生成器和判别器在一个动态的竞技环境中进化。判别器的目标是尽可能地区分真实图像和假图像,而生成器的目标是尽可能地生成能够骗过判别器的图像。
训练判别器 :它在真实图像上进行训练,试图将其标记为真实的,同时在假图像上进行训练,试图将其标记为假的。
训练生成器 :它尝试生成图像,这些图像应当被判别器标记为真实的。

这种竞争性的环境确保了两者都在进步。当GAN训练得当时,生成器应当能够生成非常逼真的图像,而判别器则变得非常擅长于区分真假。

最后,值得注意的是,在训练过程中,GAN可能会遇到一些问题,如模式崩溃,其中生成器可能会始终生成相同的图像。有许多技术和变体,

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd.variable import Variable

# 定义超参数
batch_size = 128
learning_rate = 0.0002
epochs = 10
latent_dim = 100

# 数据加载器
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True)

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), 784)
        return self.model(x)

# 生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

D = Discriminator()
G = Generator()

criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)

for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        batch_size = real_images.size(0)
        
        # 训练判别器
        D.zero_grad()
        
        labels = torch.ones(batch_size, 1)
        outputs = D(real_images)
        real_loss = criterion(outputs, labels)
        real_loss.backward()

        z = torch.randn(batch_size, latent_dim)
        fake_images = G(z)
        labels = torch.zeros(batch_size, 1)
        outputs = D(fake_images.detach())
        fake_loss = criterion(outputs, labels)
        fake_loss.backward()
        
        d_optimizer.step()

        # 训练生成器
        G.zero_grad()

        labels

针对MNIST数据集,生成对抗网络(GAN)中的生成器和判别器具有特定的目标和功能。
生成器:

  • 目标:生成器的目标是创建与MNIST数据集中的真实手写数字图像尽可能相似的图像。
  • 功能:它接受一个随机噪声向量(通常是从某种分布,如正态分布中抽取的)作为输入,并将其映射到一个图像空间,生成一个手写数字图像。
  • 训练过程中的角色 :在GAN的训练过程中,生成器试图不断改进其生成的图像,使其更难以被判别器区分为假的。
    判别器:
  • 目标 :判别器的目标是区分接收到的图像是真实的MNIST手写数字图像还是生成器生成的假图像。
    功能 :它接受一个图像作为输入,并输出一个概率值,表示该图像是真实的MNIST图像的概率。
    训练过程中的角色 :在GAN的训练过程中,判别器不断地被训练,以更好地区分真实图像和生成图像。
    总之,针对MNIST数据集,生成器尝试生成与真实MNIST手写数字尽可能相似的图像,而判别器的任务是判断图像是否是真实的MNIST手写数字图像。在这个"游戏"中,生成器和判别器相互对抗,随着训练的进行,两者都变得更加熟练,直到生成器生成的图像与真实图像几乎无法区分。
相关推荐
珠海新立电子科技有限公司15 分钟前
FPC柔性线路板与智能生活的融合
人工智能·生活·制造
IT古董29 分钟前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论
曼城周杰伦39 分钟前
自然语言处理:第六十三章 阿里Qwen2 & 2.5系列
人工智能·阿里云·语言模型·自然语言处理·chatgpt·nlp·gpt-3
余炜yw1 小时前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐2 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
如若1232 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr2 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络
ChaseDreamRunner2 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习
Guofu_Liao2 小时前
大语言模型---梯度的简单介绍;梯度的定义;梯度计算的方法
人工智能·语言模型·矩阵·llama
我爱学Python!2 小时前
大语言模型与图结构的融合: 推荐系统中的新兴范式
人工智能·语言模型·自然语言处理·langchain·llm·大语言模型·推荐系统