生成对抗网络 (GAN) 详解

目录

  1. 概述
  2. 生成模型的发展
  3. GAN的核心思想
  4. GAN的数学原理
  5. GAN的训练过程详解
  6. GAN训练中的问题与解决方案
  7. GAN的变体
  8. 条件GAN
  9. 完整代码实现
  10. 参考资料

1. 概述

1.1 什么是生成对抗网络

生成对抗网络(Generative Adversarial Network, GAN)是Ian Goodfellow等人于2014年提出的一种生成模型框架。它的核心思想是通过两个神经网络的对抗训练来学习数据分布,从而生成逼真的数据样本。

GAN的提出被认为是深度学习领域最重要的突破之一。Yann LeCun(深度学习三巨头之一)曾经评价GAN是"过去十年机器学习领域最有趣的想法"。

1.2 为什么需要生成模型

在机器学习中,我们通常关注判别模型和生成模型两类:

判别模型学习条件概率P(Y|X),即给定输入X预测输出Y。例如,给定一张图片,判断它是猫还是狗。判别模型关注的是"决策边界",即如何区分不同类别的数据。

生成模型学习数据的联合概率P(X,Y)或边缘概率P(X)。生成模型关注的是"数据分布",即数据是如何生成的。它能够生成新的、与训练数据相似的样本。

生成模型的重要性体现在多个方面:

数据增强:当训练数据不足时,生成模型可以生成合成数据来扩充训练集。这在医学影像、罕见事件等领域尤为重要。

无监督学习:生成模型可以在没有标签的情况下学习数据的结构和模式,这对于理解数据的本质非常重要。

密度估计:生成模型可以估计数据的概率密度,这对于异常检测、数据压缩等任务很有用。

创意应用:生成模型可以用于艺术创作、图像编辑、风格迁移等创意应用。

1.3 GAN的核心组成

GAN由两个核心组件组成:

生成器(Generator):接收一个随机噪声向量作为输入,输出一个与训练数据相似的样本。生成器的目标是"欺骗"判别器,让它无法区分生成的样本和真实样本。

生成器可以看作一个从潜在空间到数据空间的映射函数G: Z → X,其中Z是潜在空间(通常是高斯分布),X是数据空间。

判别器(Discriminator):接收一个样本(可能是真实的也可能是生成的)作为输入,输出该样本是真实样本的概率。判别器的目标是正确区分真实样本和生成样本。

判别器可以看作一个二分类器D: X → [0,1],输出值越接近1表示样本越可能是真实的。

1.4 GAN的影响

GAN的提出开启了生成模型的新时代,催生了大量的变体和应用:

图像生成:从低分辨率到高分辨率(1024×1024甚至更高),从随机生成到条件生成。

图像编辑:风格迁移、图像修复、超分辨率、图像到图像翻译。

数据增强:为训练数据不足的任务生成合成数据。

人脸生成:生成逼真的人脸图像,甚至可以控制属性(年龄、表情、性别等)。

视频生成:从静态图像生成视频,预测未来帧。


2. 生成模型的发展

2.1 显式密度模型

显式密度模型明确地定义数据的概率密度函数P(X),然后通过最大化似然来学习模型参数。

自回归模型:将数据的联合概率分解为条件概率的乘积:

P ( X ) = ∏ i = 1 n P ( x i ∣ x 1 , . . . , x i − 1 ) P(X) = \prod_{i=1}^{n} P(x_i | x_1, ..., x_{i-1}) P(X)=i=1∏nP(xi∣x1,...,xi−1)

代表模型包括PixelCNN、WaveNet、GPT等。优点是可以精确计算似然,缺点是生成速度慢,因为需要逐维度生成。

变分自编码器(VAE):假设数据是由潜在变量z生成的,通过编码器将数据映射到潜在空间,再通过解码器从潜在空间重建数据。VAE的优化目标是证据下界(ELBO):

L = E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − D K L ( q ( z ∣ x ) ∥ p ( z ) ) \mathcal{L} = \mathbb{E}{q(z|x)}[\log p(x|z)] - D{KL}(q(z|x) \| p(z)) L=Eq(z∣x)[logp(x∣z)]−DKL(q(z∣x)∥p(z))

优点是有清晰的概率框架,缺点是生成的图像通常比较模糊。

归一化流(Flow):通过一系列可逆变换将简单分布映射到复杂分布。如果变换是可逆的,我们可以精确计算变换后的概率密度。优点是可以精确计算似然和采样,缺点是模型架构受限。

2.2 隐式密度模型

GAN属于隐式密度模型,它不显式地定义数据的概率密度函数,而是通过采样来隐式地表示数据分布。

隐式密度模型的优势在于模型更加灵活,可以生成非常复杂的分布,不需要对数据分布做任何假设。

2.3 GAN的优势

与传统生成模型相比,GAN具有以下优势:

生成质量高:GAN生成的样本通常比VAE更清晰、更逼真。这是因为GAN通过对抗训练直接优化生成质量,而VAE优化的是重建损失。

生成速度快:一旦训练完成,生成只需要一次前向传播,不需要像自回归模型那样逐维度生成。

灵活性强:可以生成各种类型的数据(图像、音频、文本等),不需要对数据分布做任何假设。

无需显式建模:不需要定义数据的概率密度函数,避免了复杂的数学推导。


3. GAN的核心思想

3.1 对抗训练的思想

GAN的核心思想来自于博弈论中的零和博弈。在GAN的框架中,生成器和判别器是两个玩家,它们进行一场博弈:

生成器的目标是最小化判别器的准确率,即让判别器无法区分真假样本。

判别器的目标是最大化自己的准确率,即正确区分真假样本。

这场博弈的目标是达到纳什均衡,此时生成器生成的样本与真实样本无法区分,判别器对任何样本的输出概率都是0.5。

3.2 类比:造假者与鉴定师

一个经典的类比是造假者和鉴定师的关系:

造假者(生成器)不断改进伪造技术,试图制作出以假乱真的作品。

鉴定师(判别器)不断提高鉴别能力,试图发现作品的真伪。

随着双方不断进步,造假者最终能够制作出鉴定师也无法识别的赝品。在GAN中,这意味着生成器已经学会了真实数据的分布。

3.3 博弈论视角

从博弈论的角度,GAN可以看作一个极小极大博弈(minimax game):

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]

其中:

  • G是生成器
  • D是判别器
  • p d a t a p_{data} pdata 是真实数据分布
  • p z p_z pz 是噪声分布(通常是高斯分布)
  • V是价值函数

价值函数的直观理解

第一项 E x ∼ p d a t a [ log ⁡ D ( x ) ] \mathbb{E}{x \sim p{data}}[\log D(x)] Ex∼pdata[logD(x)] 表示判别器对真实样本的判断能力。D(x)越接近1,这项越大。

第二项 E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] Ez∼pz[log(1−D(G(z)))] 表示判别器对生成样本的判断能力。D(G(z))越接近0,这项越大。

判别器试图最大化V,即正确区分真假样本。

生成器试图最小化V,即让判别器无法区分真假样本。

3.4 纳什均衡

在理想的纳什均衡点:

  • 生成器生成的数据分布与真实数据分布完全一致: p g = p d a t a p_g = p_{data} pg=pdata
  • 判别器对任何输入都输出0.5: D ( x ) = 1 2 D(x) = \frac{1}{2} D(x)=21

此时,生成器无法进一步改进,判别器也无法进一步改进,达到平衡状态。


4. GAN的数学原理

4.1 目标函数详解

GAN的目标函数是一个极小极大博弈:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]

这个目标函数可以分解为两部分:

判别器的损失
L D = − E x ∼ p d a t a [ log ⁡ D ( x ) ] − E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathcal{L}D = -\mathbb{E}{x \sim p_{data}}[\log D(x)] - \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] LD=−Ex∼pdata[logD(x)]−Ez∼pz[log(1−D(G(z)))]

判别器试图最小化这个损失,即最大化对真实样本的预测概率,最小化对生成样本的预测概率。

生成器的损失
L G = E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathcal{L}G = \mathbb{E}{z \sim p_z}[\log(1 - D(G(z)))] LG=Ez∼pz[log(1−D(G(z)))]

生成器试图最小化这个损失,即让判别器对生成样本的预测概率最大化。

4.2 最优判别器的推导

对于固定的生成器G,最优判别器D*为:

D ∗ ( x ) = p d a t a ( x ) p d a t a ( x ) + p g ( x ) D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} D∗(x)=pdata(x)+pg(x)pdata(x)

详细推导过程

判别器的目标是最大化:
V ( D , G ) = ∫ x p d a t a ( x ) log ⁡ D ( x ) d x + ∫ z p z ( z ) log ⁡ ( 1 − D ( G ( z ) ) ) d z V(D, G) = \int_x p_{data}(x) \log D(x) dx + \int_z p_z(z) \log(1 - D(G(z))) dz V(D,G)=∫xpdata(x)logD(x)dx+∫zpz(z)log(1−D(G(z)))dz

由于 G ( z ) G(z) G(z) 的分布是 p g p_g pg,我们可以将第二项改写为:
V ( D , G ) = ∫ x p d a t a ( x ) log ⁡ D ( x ) d x + ∫ x p g ( x ) log ⁡ ( 1 − D ( x ) ) d x V(D, G) = \int_x p_{data}(x) \log D(x) dx + \int_x p_g(x) \log(1 - D(x)) dx V(D,G)=∫xpdata(x)logD(x)dx+∫xpg(x)log(1−D(x))dx

= ∫ x [ p d a t a ( x ) log ⁡ D ( x ) + p g ( x ) log ⁡ ( 1 − D ( x ) ) ] d x = \int_x \left[ p_{data}(x) \log D(x) + p_g(x) \log(1 - D(x)) \right] dx =∫x[pdata(x)logD(x)+pg(x)log(1−D(x))]dx

对于固定的x,要最大化 f ( D ) = a log ⁡ D + b log ⁡ ( 1 − D ) f(D) = a \log D + b \log(1 - D) f(D)=alogD+blog(1−D),其中 a = p d a t a ( x ) a = p_{data}(x) a=pdata(x), b = p g ( x ) b = p_g(x) b=pg(x)。

令导数为0:
d f d D = a D − b 1 − D = 0 \frac{df}{dD} = \frac{a}{D} - \frac{b}{1 - D} = 0 dDdf=Da−1−Db=0

解得:
D ∗ = a a + b = p d a t a ( x ) p d a t a ( x ) + p g ( x ) D^* = \frac{a}{a + b} = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} D∗=a+ba=pdata(x)+pg(x)pdata(x)

4.3 最优生成器的推导

将最优判别器代入目标函数,得到:

C ( G ) = max ⁡ D V ( D , G ) = E x ∼ p d a t a [ log ⁡ p d a t a ( x ) p d a t a ( x ) + p g ( x ) ] + E x ∼ p g [ log ⁡ p g ( x ) p d a t a ( x ) + p g ( x ) ] C(G) = \max_D V(D, G) = \mathbb{E}{x \sim p{data}}\left[\log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\right] + \mathbb{E}{x \sim p_g}\left[\log \frac{p_g(x)}{p{data}(x) + p_g(x)}\right] C(G)=DmaxV(D,G)=Ex∼pdata[logpdata(x)+pg(x)pdata(x)]+Ex∼pg[logpdata(x)+pg(x)pg(x)]

= ∫ x p d a t a ( x ) log ⁡ p d a t a ( x ) p d a t a ( x ) + p g ( x ) d x + ∫ x p g ( x ) log ⁡ p g ( x ) p d a t a ( x ) + p g ( x ) d x = \int_x p_{data}(x) \log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} dx + \int_x p_g(x) \log \frac{p_g(x)}{p_{data}(x) + p_g(x)} dx =∫xpdata(x)logpdata(x)+pg(x)pdata(x)dx+∫xpg(x)logpdata(x)+pg(x)pg(x)dx

= − log ⁡ 4 + ∫ x p d a t a ( x ) log ⁡ p d a t a ( x ) p d a t a ( x ) + p g ( x ) 2 d x + ∫ x p g ( x ) log ⁡ p g ( x ) p d a t a ( x ) + p g ( x ) 2 d x = -\log 4 + \int_x p_{data}(x) \log \frac{p_{data}(x)}{\frac{p_{data}(x) + p_g(x)}{2}} dx + \int_x p_g(x) \log \frac{p_g(x)}{\frac{p_{data}(x) + p_g(x)}{2}} dx =−log4+∫xpdata(x)log2pdata(x)+pg(x)pdata(x)dx+∫xpg(x)log2pdata(x)+pg(x)pg(x)dx

= − log ⁡ 4 + K L ( p d a t a ∥ p d a t a + p g 2 ) + K L ( p g ∥ p d a t a + p g 2 ) = -\log 4 + KL\left(p_{data} \| \frac{p_{data} + p_g}{2}\right) + KL\left(p_g \| \frac{p_{data} + p_g}{2}\right) =−log4+KL(pdata∥2pdata+pg)+KL(pg∥2pdata+pg)

= − log ⁡ 4 + 2 ⋅ J S D ( p d a t a ∥ p g ) = -\log 4 + 2 \cdot JSD(p_{data} \| p_g) =−log4+2⋅JSD(pdata∥pg)

其中JSD是Jensen-Shannon散度:

J S D ( P ∥ Q ) = 1 2 K L ( P ∥ M ) + 1 2 K L ( Q ∥ M ) JSD(P \| Q) = \frac{1}{2} KL(P \| M) + \frac{1}{2} KL(Q \| M) JSD(P∥Q)=21KL(P∥M)+21KL(Q∥M)

其中 M = P + Q 2 M = \frac{P + Q}{2} M=2P+Q,KL是Kullback-Leibler散度。

关键洞察

  • JSD总是非负的: J S D ( P ∥ Q ) ≥ 0 JSD(P \| Q) \geq 0 JSD(P∥Q)≥0
  • 当 p d a t a = p g p_{data} = p_g pdata=pg 时,JSD = 0,C(G) = -log4,达到全局最小值
  • 最小化C(G)等价于最小化JSD,即让生成分布接近真实分布

4.4 训练动态分析

在训练过程中,判别器和生成器交替优化:

判别器更新
θ D ← θ D + η ∇ θ D [ E x ∼ p d a t a [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ] \theta_D \leftarrow \theta_D + \eta \nabla_{\theta_D} \left[ \mathbb{E}{x \sim p{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \right] θD←θD+η∇θD[Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]]

生成器更新
θ G ← θ G − η ∇ θ G E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \theta_G \leftarrow \theta_G - \eta \nabla_{\theta_G} \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] θG←θG−η∇θGEz∼pz[log(1−D(G(z)))]

训练动态的直觉理解

在训练初期,生成器生成的样本质量很差,判别器很容易区分真假。此时判别器的损失很小,但生成器的梯度接近于0(因为log(1-D(G(z)))在D(G(z))接近0时梯度很小)。

随着训练进行,生成器逐渐改进,生成的样本越来越逼真。判别器需要更加仔细才能区分真假。

在理想情况下,最终生成器生成的样本与真实样本无法区分,判别器对任何样本都输出0.5。


5. GAN的训练过程详解

5.1 训练算法

GAN的训练采用交替优化的方式:

复制代码
for each training iteration:
    # 训练判别器 (k步)
    for k steps:
        采样真实数据 batch {x_1, ..., x_m}
        采样噪声 batch {z_1, ..., z_m}
        生成假数据 batch {G(z_1), ..., G(z_m)}
        
        更新判别器:
        θ_D ← θ_D + η∇_θD [Σ log D(x_i) + Σ log(1 - D(G(z_i)))]
    
    # 训练生成器 (1步)
    采样噪声 batch {z_1, ..., z_m}
    
    更新生成器:
    θ_G ← θ_G - η∇_θG Σ log(1 - D(G(z_i)))

5.2 为什么需要多次更新判别器

原始论文建议在训练生成器之前,先训练判别器k步(k=1是常用设置)。这是因为:

如果判别器太弱,它无法为生成器提供有用的梯度信号。生成器需要一个足够强的判别器来指导它的学习。

然而,如果判别器太强,生成器的梯度会接近于0,导致生成器无法学习。因此需要在判别器和生成器之间保持平衡。

5.3 实际训练代码

python 复制代码
def train_gan(generator, discriminator, dataloader, num_epochs):
    """
    完整的GAN训练函数
    
    Args:
        generator: 生成器模型
        discriminator: 判别器模型
        dataloader: 数据加载器
        num_epochs: 训练轮数
    """
    # 优化器 - 使用Adam,学习率较小,beta1=0.5
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # 损失函数 - 二元交叉熵
    criterion = nn.BCELoss()

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)

            # 创建标签 - 使用标签平滑
            real_labels = torch.ones(batch_size, 1) * 0.9  # 0.9而不是1.0
            fake_labels = torch.zeros(batch_size, 1)

            # ============ 训练判别器 ============
            # 真实图像
            real_outputs = discriminator(real_images)
            d_loss_real = criterion(real_outputs, real_labels)

            # 生成假图像
            z = torch.randn(batch_size, 100, 1, 1)
            fake_images = generator(z)
            fake_outputs = discriminator(fake_images.detach())  # detach阻止梯度传到生成器
            d_loss_fake = criterion(fake_outputs, fake_labels)

            # 判别器总损失
            d_loss = d_loss_real + d_loss_fake

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # ============ 训练生成器 ============
            fake_outputs = discriminator(fake_images)
            g_loss = criterion(fake_outputs, real_labels)  # 生成器希望判别器输出1

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

6. GAN训练中的问题与解决方案

6.1 模式坍塌

问题描述:模式坍塌(Mode Collapse)是指生成器只生成少数几种样本,无法覆盖数据分布的所有模式。例如,在生成手写数字时,生成器可能只生成某几种数字。

产生原因:生成器找到了一种能够欺骗判别器的"捷径",就不断生成类似的样本,而不会探索数据分布的其他模式。

解决方案

  • 使用WGAN替代原始GAN
  • 使用小批量判别(Mini-batch Discrimination)
  • 使用多判别器
  • 使用Unrolled GAN

6.2 训练不稳定

问题描述:GAN的训练过程不稳定,容易出现震荡、发散等问题。

产生原因:生成器和判别器的训练相互影响,可能导致训练过程不稳定。

解决方案

  • 使用WGAN-GP
  • 调整学习率
  • 使用谱归一化(Spectral Normalization)
  • 使用渐进式训练

6.3 梯度消失

问题描述:当判别器太强时,生成器的梯度接近于0,导致生成器无法学习。

产生原因:当D(G(z))接近0时,log(1-D(G(z)))的梯度很小。

解决方案

  • 使用最大化log D(G(z))替代最小化log(1-D(G(z)))
  • 使用WGAN
  • 使用标签平滑

6.4 评估困难

问题描述:GAN没有像VAE那样明确的损失函数可以监控训练进度。

解决方案

  • 使用FID(Fréchet Inception Distance)
  • 使用IS(Inception Score)
  • 使用人工评估

7. GAN的变体

7.1 DCGAN

DCGAN(Deep Convolutional GAN)是第一个使用卷积神经网络的GAN,为GAN的架构设计提供了重要指导原则:

架构指导原则

  • 使用卷积层替代池化层(判别器用步长卷积下采样,生成器用转置卷积上采样)
  • 在生成器和判别器中使用批量归一化
  • 移除全连接层,使用全卷积网络
  • 生成器使用ReLU激活(输出层用Tanh)
  • 判别器使用LeakyReLU激活

这些原则至今仍被广泛使用。

7.2 WGAN

WGAN(Wasserstein GAN)使用Wasserstein距离替代JS散度,解决了原始GAN的训练不稳定问题。

Wasserstein距离 (也称为推土机距离):
W ( p d a t a , p g ) = inf ⁡ γ ∈ Π ( p d a t a , p g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(p_{data}, p_g) = \inf_{\gamma \in \Pi(p_{data}, p_g)} \mathbb{E}_{(x,y) \sim \gamma}[\|x - y\|] W(pdata,pg)=γ∈Π(pdata,pg)infE(x,y)∼γ[∥x−y∥]

直觉上,Wasserstein距离衡量的是将一个分布"搬运"到另一个分布所需的最小"工作量"。

WGAN目标函数
min ⁡ G max ⁡ D ∈ D E x ∼ p d a t a [ D ( x ) ] − E z ∼ p z [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}{x \sim p{data}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))] GminD∈DmaxEx∼pdata[D(x)]−Ez∼pz[D(G(z))]

其中D必须满足1-Lipschitz约束。

优势

  • Wasserstein距离即使在分布不重叠时也有意义
  • 训练更稳定,损失值可以反映生成质量
  • 缓解模式坍塌问题

7.3 WGAN-GP

WGAN-GP使用梯度惩罚替代权重裁剪来确保Lipschitz约束:

L = E x ∼ p d a t a [ D ( x ) ] − E z ∼ p z [ D ( G ( z ) ) ] + λ E x ^ ∼ p x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] \mathcal{L} = \mathbb{E}{x \sim p{data}}[D(x)] - \mathbb{E}{z \sim p_z}[D(G(z))] + \lambda \mathbb{E}{\hat{x} \sim p_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] L=Ex∼pdata[D(x)]−Ez∼pz[D(G(z))]+λEx^∼px^[(∥∇x^D(x^)∥2−1)2]

7.4 StyleGAN

StyleGAN是NVIDIA提出的高质量图像生成模型,其核心创新包括:

映射网络:将潜在向量映射到中间潜在空间W,解耦不同的属性。

自适应实例归一化(AdaIN):通过风格向量控制生成图像的风格。

噪声注入:在不同分辨率注入噪声,控制细节。

渐进式增长:从低分辨率逐步增长到高分辨率。


8. 条件GAN

8.1 条件GAN的原理

条件GAN(Conditional GAN, cGAN)在生成器和判别器中都引入了条件信息(如类别标签、文本描述等),使得生成过程可以被控制。

条件GAN的目标函数
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a [ log ⁡ D ( x ∣ y ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ∣ y ) ∣ y ) ) ] \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{data}}[\log D(x|y)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z|y)|y))] GminDmaxV(D,G)=Ex∼pdata[logD(x∣y)]+Ez∼pz[log(1−D(G(z∣y)∣y))]

其中y是条件信息。

8.2 条件信息的注入方式

类别标签:通过嵌入层将类别标签转换为向量,然后与噪声向量拼接。

文本描述:使用预训练的文本编码器(如LSTM、BERT)将文本转换为向量。

图像:使用编码器提取图像特征作为条件。


9. 完整代码实现

9.1 可直接运行的GAN完整实现

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# ==================== 模型定义 ====================

class Generator(nn.Module):
    """
    生成器:将随机噪声映射为图像
    
    输入: [batch_size, latent_dim, 1, 1] 的随机噪声
    输出: [batch_size, channels, 32, 32] 的生成图像
    """
    def __init__(self, latent_dim=100, channels=1):
        super().__init__()
        
        self.main = nn.Sequential(
            # 第一层:latent_dim -> 256, 1x1 -> 4x4
            nn.ConvTranspose2d(latent_dim, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            # 第二层:256 -> 128, 4x4 -> 8x8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # 第三层:128 -> 64, 8x8 -> 16x16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # 第四层:64 -> channels, 16x16 -> 32x32
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)


class Discriminator(nn.Module):
    """
    判别器:判断图像是真实的还是生成的
    
    输入: [batch_size, channels, 32, 32] 的图像
    输出: [batch_size, 1] 的概率值(1表示真实,0表示生成)
    """
    def __init__(self, channels=1):
        super().__init__()
        
        self.main = nn.Sequential(
            # 第一层:channels -> 64, 32x32 -> 16x16
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 第二层:64 -> 128, 16x16 -> 8x8
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 第三层:128 -> 256, 8x8 -> 4x4
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 第四层:256 -> 1, 4x4 -> 1x1
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1, 1)


# ==================== 训练函数 ====================

def train_gan():
    """
    完整的GAN训练函数
    可以直接运行,在MNIST数据集上训练GAN
    """
    # 超参数
    latent_dim = 100
    batch_size = 128
    num_epochs = 20
    lr = 0.0002
    beta1 = 0.5
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    print(f"使用设备: {device}")
    
    # 数据预处理
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    # 加载MNIST数据集
    dataset = torchvision.datasets.MNIST(
        root='./data', 
        train=True, 
        download=True, 
        transform=transform
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    # 创建模型
    generator = Generator(latent_dim, 1).to(device)
    discriminator = Discriminator(1).to(device)
    
    # 损失函数
    criterion = nn.BCELoss()
    
    # 优化器
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    
    # 用于可视化的固定噪声
    fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)
    
    # 记录损失
    G_losses = []
    D_losses = []
    
    print("\n开始训练...")
    print("=" * 60)
    
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            real_images = real_images.to(device)
            
            # 创建标签
            real_labels = torch.ones(batch_size, 1, device=device) * 0.9  # 标签平滑
            fake_labels = torch.zeros(batch_size, 1, device=device)
            
            # ============ 训练判别器 ============
            # 真实图像
            real_outputs = discriminator(real_images)
            d_loss_real = criterion(real_outputs, real_labels)
            
            # 生成假图像
            noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            fake_images = generator(noise)
            fake_outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(fake_outputs, fake_labels)
            
            # 判别器总损失
            d_loss = d_loss_real + d_loss_fake
            
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()
            
            # ============ 训练生成器 ============
            fake_outputs = discriminator(fake_images)
            g_loss = criterion(fake_outputs, real_labels)
            
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()
            
            # 记录损失
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            
            # 打印训练信息
            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], '
                      f'D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}')
        
        # 每个epoch结束后生成示例图像
        with torch.no_grad():
            fake_images = generator(fixed_noise).cpu()
            torchvision.utils.save_image(
                fake_images, 
                f'generated_epoch_{epoch+1}.png', 
                normalize=True, 
                nrow=8
            )
    
    print("=" * 60)
    print("训练完成!")
    
    # 绘制损失曲线
    plt.figure(figsize=(10, 5))
    plt.plot(G_losses, label='Generator Loss')
    plt.plot(D_losses, label='Discriminator Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.title('GAN Training Loss')
    plt.legend()
    plt.savefig('training_loss.png')
    plt.show()
    
    return generator, discriminator


if __name__ == '__main__':
    generator, discriminator = train_gan()

10. 参考资料

核心论文

  1. GAN: Goodfellow et al., "Generative Adversarial Networks", 2014
  2. DCGAN: Radford et al., "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks", 2015
  3. WGAN: Arjovsky et al., "Wasserstein GAN", 2017
  4. WGAN-GP: Gulrajani et al., "Improved Training of Wasserstein GANs", 2017
  5. StyleGAN: Karras et al., "A Style-Based Generator Architecture for Generative Adversarial Networks", 2019

开源库

相关推荐
哦哦~9211 小时前
机器学习在智能水泥基复合材料中的应用与实践
人工智能·机器学习·机器人
码农杂谈00071 小时前
医药行业GEA:企业级智能体系统如何开启医药学术运营新范式
大数据·人工智能
hh.h.1 小时前
昇腾 CANN cann-samples 仓:从 HelloWorld 到 ResNet50 推理
人工智能·cann·samples
三掌柜6661 小时前
OpenClaw 部署实战:智能体 Skills 破解长视频复用难题
人工智能
Maydaycxc1 小时前
RPA 稳定性深度剖析:元素定位失效、界面更新与 AI 增强实战
人工智能·机器人·rpa
QYR-分析1 小时前
深耕智慧物流赛道:交叉带分拣机器人行业全景解析
大数据·人工智能·机器人
君为先-bey1 小时前
LeMiCa——基于扩散的高效视频生成的词典序最小最大路径缓存
人工智能·深度学习·计算机视觉·扩散模型
Days20501 小时前
AI提示词管理器:解锁大模型高效应用的核心工具
大数据·人工智能
KaMeidebaby1 小时前
卡梅德生物技术快报|抗体的制备与纯化:分子实验实操:番茄 sHSP 重组表达与抗体的制备与纯化工艺
前端·数据库·人工智能·其他·算法·百度·新浪微博