DL:生成对抗网络的基本原理与 PyTorch 实现

生成对抗网络(Generative Adversarial Network,GAN)是深度学习中非常重要的一类生成模型。与分类模型、回归模型不同,GAN 的目标不是根据输入判断类别,也不是预测一个连续数值,而是学习真实数据的分布,并生成看起来像真实数据的新样本。

例如:

• 生成一张手写数字图片

• 生成一张看起来真实的人脸图像

• 修复图像缺失区域

• 提升图像分辨率

• 把一种图像风格转换为另一种风格

• 根据条件信息生成指定类型的样本

GAN 的核心思想可以概括为:让两个神经网络相互竞争,一个负责生成假样本,另一个负责判断样本真假。通过这种对抗过程,生成器逐渐学会生成越来越接近真实数据的样本。

一、为什么需要生成对抗网络

图 1:从判别模型到生成模型

在很多深度学习任务中,我们训练的是判别模型(Discriminative Model)。判别模型的目标是根据输入判断结果。

例如:

• 输入图像 → 判断是猫还是狗

• 输入评论 → 判断是正面还是负面

• 输入房屋信息 → 预测房价

这类模型关注的是:给定输入 x,预测目标 y。

可以写成:

其中:

• x 表示输入数据

• y 表示目标标签

• p(y|x) 表示在给定 x 的条件下,y 出现的概率

但是,生成模型(Generative Model)关注的是另一个问题:数据本身是如何产生的?

它希望学习真实数据的分布,并从这个分布中生成新样本。

可以简单写为:

其中:

• x 表示数据样本

• p(x) 表示数据样本出现的概率分布

例如,如果模型学习的是手写数字图像分布,那么它应该能够生成新的手写数字图片;如果模型学习的是人脸图像分布,那么它应该能够生成新的人脸图像。

GAN 的特别之处在于:它不直接写出一个明确的数据分布公式,而是通过两个网络的对抗训练,让生成器逐渐逼近真实数据分布。

可以简单理解为:

• 判别模型:学习如何判断

• 生成模型:学习如何创造

GAN 通过"生成---辨别"的对抗过程学习生成。

二、GAN 的基本结构

GAN 通常由两个神经网络组成:

• 生成器

• 判别器

生成器(Generator)负责"生成假样本",判别器(Discriminator)负责"判断真假"。二者在训练过程中相互竞争、共同变化。

图 2:GAN 的基本结构

1、生成器:从随机噪声生成样本

生成器的输入通常是一个随机噪声向量 z。这个 z 可以来自正态分布或均匀分布。

生成器把 z 映射为一个假样本:

其中:

• z 表示随机噪声向量

• G 表示生成器

• G(z) 表示生成器输出的假样本

• x̃ 表示生成样本

如果任务是生成手写数字图像,那么 G(z) 就是一张模型生成的手写数字图片。

生成器的目标是:让生成样本尽可能像真实样本,使判别器难以分辨真假。

2、判别器:判断样本是真是假

判别器接收一个样本 x,并输出它是真实样本的概率:

其中:

• D 表示判别器

• x 表示输入样本

• D(x) 表示判别器认为 x 来自真实数据的概率

如果 D(x) 接近 1,表示判别器认为样本很可能是真实样本。

如果 D(x) 接近 0,表示判别器认为样本很可能是生成器伪造的样本。

判别器的目标是:尽可能把真实样本判断为真,把生成样本判断为假。

3、生成器与判别器的对抗关系

GAN 的训练过程类似一个"生成者"和"鉴别者"的博弈:

• 生成器 G:尽量生成更逼真的假样本

• 判别器 D:尽量分辨真实样本和生成样本

随着训练进行:

• 判别器会越来越擅长识别真假

• 生成器会根据判别器反馈不断改进

• 当生成器足够强时,判别器很难区分真假样本

理想情况下,生成器学到的数据分布会逐渐接近真实数据分布。

三、GAN 的对抗训练目标

GAN 的核心是对抗训练。它不是训练一个网络,而是同时训练生成器 G 和判别器 D。

判别器希望真实样本被判断为真,生成样本被判断为假;生成器则希望生成样本被判别器判断为真。

图 3:GAN 的对抗训练目标

1、判别器的目标

对于真实样本 x,判别器希望:

对于生成样本 G(z),判别器希望:

因此,判别器希望最大化:

其中:

• D(x) 表示判别器认为真实样本为真的概率

• D(G(z)) 表示判别器认为生成样本为真的概率

• log D(x) 鼓励真实样本被判断为真

• log(1 − D(G(z))) 鼓励生成样本被判断为假

从直观角度看,判别器在学习:

• 真实样本 → 1

• 生成样本 → 0

2、生成器的目标

生成器希望自己的输出 G(z) 被判别器判断为真,也就是希望:

在原始 GAN 目标中,生成器试图最小化:

但在实际训练中,常用非饱和形式,让生成器最大化:

等价地,可以最小化:

其中:

• G(z) 表示生成器生成的假样本

• D(G(z)) 表示判别器认为该假样本为真的概率

• −log D(G(z)) 越小,说明生成器越容易骗过判别器

这种写法在训练早期通常能提供更强的梯度信号。

3、GAN 的极小极大目标

原始 GAN 的总体目标可以写为:

其中:

• G 表示生成器

• D 表示判别器

• p_data(x) 表示真实数据分布

• p_z(z) 表示噪声分布

• x ∼ p_data(x) 表示真实样本来自真实数据分布

• z ∼ p_z(z) 表示噪声来自预设噪声分布

• E 表示期望

这个目标的含义是:

• 判别器 D 尽量最大化真假区分能力

• 生成器 G 尽量最小化判别器对生成样本的识别能力

这也是 GAN 名称中"对抗"的来源。

四、GAN 的训练过程

GAN 的训练通常不是一次性同时更新两个网络,而是交替更新判别器和生成器。

一个典型训练流程如下:

  1. 从真实数据集中取一批真实样本

  2. 从噪声分布中采样一批随机向量

  3. 生成器根据噪声生成一批假样本

  4. 用真实样本和假样本训练判别器

  5. 再采样一批噪声,生成假样本

  6. 固定判别器,用判别器反馈训练生成器

  7. 重复多轮训练

读取中... 读取中...

图 4:GAN 的训练闭环

1、训练判别器

训练判别器时,需要同时使用真实样本和生成样本。

真实样本的标签设为 1:

go 复制代码
真实样本 → 标签 1

生成样本的标签设为 0:

go 复制代码
生成样本 → 标签 0

判别器损失可以写为:

其中:

• L_D 表示判别器损失

• m 表示批量大小

• xᵢ 表示第 i 个真实样本

• zᵢ 表示第 i 个噪声向量

• G(zᵢ) 表示第 i 个生成样本

• D(xᵢ) 表示判别器认为真实样本为真的概率

• D(G(zᵢ)) 表示判别器认为生成样本为真的概率

训练判别器时,生成器通常不更新。

在 PyTorch 中,常用 .detach() 阻断生成样本到生成器的梯度传播:

ini 复制代码
fake_images = generator(z).detach()

这样判别器训练时只更新判别器参数,不会更新生成器参数。

2、训练生成器

训练生成器时,生成器希望判别器把生成样本判断为真。

生成器损失常写为:

其中:

• L_G 表示生成器损失

• zᵢ 表示第 i 个噪声向量

• G(zᵢ) 表示生成器生成的假样本

• D(G(zᵢ)) 表示判别器认为假样本为真的概率

训练生成器时,判别器参与前向计算,但判别器参数不更新;它主要为生成器提供梯度信号,告诉生成器如何调整输出,使生成样本更容易被判别为真。

从直观角度看:

• 判别器训练:提高辨别真假能力

• 生成器训练:提高欺骗判别器能力

这两个过程交替进行,就形成了 GAN 的对抗训练。

五、GAN 为什么能生成数据

GAN 能生成数据的关键,在于生成器不是直接复制训练样本,而是学习把随机噪声映射到数据空间。

图 5:从噪声空间到数据空间的映射

可以把生成器理解为一个函数:

其中:

• z 表示低维随机噪声

• x̃ 表示生成样本

• G 表示从噪声空间到数据空间的映射

训练开始时,G(z) 通常像随机噪声,没有明显结构。随着训练进行,判别器不断指出生成样本与真实样本之间的差异,生成器则通过梯度更新逐渐修正自己的输出。

在理想情况下:

其中:

• p_g(x) 表示生成器学到的生成分布

• p_data(x) 表示真实数据分布

• ≈ 表示两者逐渐接近

此时,从噪声 z 中采样,再输入生成器,就可以得到看起来像真实数据的新样本。

六、GAN 的主要问题

GAN 的思想非常优雅,但训练并不容易。相比普通分类网络,GAN 更容易出现不稳定现象。

图 6:GAN 的主要问题:训练不稳定与模式崩塌

1、训练不稳定

GAN 中有两个网络同时博弈。如果判别器太强,生成器可能得不到有效梯度;如果生成器变化太快,判别器又可能跟不上。

这会导致训练过程震荡,很难像普通监督学习那样稳定下降。

2、模式崩塌

模式崩塌(Mode Collapse)是 GAN 中非常经典的问题。它指的是生成器只学会生成少数几种样本,而没有覆盖真实数据分布中的多样性。

例如,在手写数字生成任务中,生成器可能只生成类似数字 1 或 7 的图像,而很少生成其他数字。

从直观角度看:真实数据有很多种模式,生成器只学会了其中少数模式。

这会导致生成结果看似逼真,但多样性不足。

3、评价困难

分类模型可以用准确率、精确率、召回率等指标评价;回归模型可以用 MSE、MAE、R² 等指标评价。

但生成模型的评价更复杂,因为我们不仅关心生成样本是否清晰,还关心:

• 是否真实

• 是否多样

• 是否覆盖真实数据分布

• 是否与条件输入一致

• 是否具有语义合理性

因此,GAN 的评价通常比普通监督学习任务更困难。

4、对超参数敏感

GAN 对学习率、网络结构、优化器、批量大小、归一化方法等都比较敏感。不同设置可能导致训练效果差异很大。

常见改进方法包括:

• 使用更稳定的损失函数

• 使用归一化技巧

• 调整生成器和判别器的更新频率

• 使用梯度惩罚

• 使用更合理的网络结构

七、PyTorch 实现:使用 GAN 生成手写数字

下面使用 PyTorch 构建一个简单 GAN,用于生成 MNIST 风格的手写数字图像。

图 7:GAN 生成手写数字的训练与输出流程

为了突出 GAN 的基本训练流程,这里使用全连接网络实现生成器和判别器。真实图像生成任务中,通常会使用卷积结构,例如 DCGAN。

1、导入库

python 复制代码
# 导入 PyTorch 核心模块import torchimport torch.nn as nn             # 神经网络层和损失函数import torch.optim as optim       # 优化器
import matplotlib.pyplot as plt   # 可视化生成图像
from torch.utils.data import DataLoader          # 批量数据加载from torchvision import datasets, transforms    # 标准数据集和图像预处理

这里使用:

• DataLoader 按批量加载数据

• torchvision.datasets 加载 MNIST 数据集

• torchvision.transforms 进行图像预处理

2、设置超参数

ini 复制代码
# GAN 超参数设置latent_dim = 100          # 噪声向量维度(生成器输入)image_size = 28 * 28      # MNIST 图像展平后的像素数(28x28=784)batch_size = 128          # 每批处理的样本数num_epochs = 20           # 训练轮数learning_rate = 0.0002    # Adam优化器学习率(常见于GAN训练)

MNIST 图像大小为 28 × 28,因此展平后大小为:

3、准备 MNIST 数据集

makefile 复制代码
# 图像预处理:将图像转为张量并标准化到 [-1, 1] 范围(因为 tanh 输出在 -1 到 1)transform = transforms.Compose([    transforms.ToTensor(),                     # PIL/NumPy (H,W) → (1,28,28),值域 [0,1]    transforms.Normalize((0.5,), (0.5,))       # 标准化: (x - 0.5) / 0.5 → 值域 [-1,1]])
# 加载 MNIST 训练集(60000张手写数字)train_dataset = datasets.MNIST(    root="./data",    train=True,    download=True,    transform=transform)
# 数据加载器:批量加载、打乱顺序train_loader = DataLoader(    train_dataset,    batch_size=batch_size,    shuffle=True)

这里将图像标准化到大致 −1 到 1 的范围。后面生成器最后使用 Tanh(),输出范围也是 −1 到 1,这样输入输出尺度更匹配。

4、定义生成器

生成器接收随机噪声 z,输出一张展平后的图像。

ruby 复制代码
# 生成器:将随机噪声向量转换为伪造图像(784维像素值)class Generator(nn.Module):    def __init__(self, latent_dim, image_size):        super().__init__()        # 全连接网络:噪声向量 → 逐层升维 → 最终输出图像像素(值域 -1 到 1)        self.net = nn.Sequential(            nn.Linear(latent_dim, 256),   # 100 → 256            nn.ReLU(),            nn.Linear(256, 512),          # 256 → 512            nn.ReLU(),            nn.Linear(512, image_size),   # 512 → 784            nn.Tanh()                     # 输出范围 (-1, 1),匹配标准化后的真实图像        )
    def forward(self, z):        return self.net(z)

生成器结构可以概括为:

随机噪声 z → 全连接层 → ReLU → 全连接层 → ReLU → 全连接层 → Tanh → 生成图像

其中:

• 输入是长度为 latent_dim 的随机噪声

• 输出是长度为 784 的向量

• Tanh 使输出范围接近 −1 到 1

• 输出向量可以 reshape 为 1 × 28 × 28 的图像

5、定义判别器

判别器接收一张图像,并输出它是真实图像的概率。

ruby 复制代码
# 判别器:接收图像(784维),输出该图像为真实图像的概率class Discriminator(nn.Module):    def __init__(self, image_size):        super().__init__()        # 全连接网络:逐层降维,最终输出一个概率(0~1)        self.net = nn.Sequential(            nn.Linear(image_size, 512),      # 784 → 512            nn.LeakyReLU(0.2),               # LeakyReLU 负斜率0.2,避免梯度饱和            nn.Linear(512, 256),             # 512 → 256            nn.LeakyReLU(0.2),            nn.Linear(256, 1),               # 256 → 1            nn.Sigmoid()                     # 压缩到 (0,1) 表示真实概率        )
    def forward(self, x):        return self.net(x)

判别器结构可以概括为:

图像向量 → 全连接层 → LeakyReLU → 全连接层 → LeakyReLU → 全连接层 → Sigmoid → 真假概率

其中:

• 输入是长度为 784 的图像向量

• 输出是 0 到 1 之间的概率

• 越接近 1,表示越像真实图像

• 越接近 0,表示越像生成图像

这里为了便于初学者理解,判别器最后显式使用 Sigmoid(),损失函数使用 BCELoss()。在更稳定的工程写法中,也可以让判别器输出 logits,并使用 BCEWithLogitsLoss()。

6、创建模型、损失函数和优化器

makefile 复制代码
# 选择训练设备(GPU优先)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 实例化生成器和判别器,并移动到设备generator = Generator(latent_dim, image_size).to(device)discriminator = Discriminator(image_size).to(device)
# 损失函数:二分类交叉熵(适合判别器输出0/1概率)criterion = nn.BCELoss()
# 生成器优化器:Adam,学习率0.0002,beta1=0.5(GAN常用,避免震荡)optimizer_G = optim.Adam(    generator.parameters(),    lr=learning_rate,    betas=(0.5, 0.999))
# 判别器优化器:相同配置optimizer_D = optim.Adam(    discriminator.parameters(),    lr=learning_rate,    betas=(0.5, 0.999))

其中:

• generator 表示生成器

• discriminator 表示判别器

• BCELoss 表示二元交叉熵损失

• optimizer_G 用于更新生成器

• optimizer_D 用于更新判别器

• betas=(0.5, 0.999) 是 GAN 中常见的 Adam 参数设置

7、训练 GAN

GAN 的训练通常分为两步:

• 先训练判别器

• 再训练生成器

训练代码如下:

css 复制代码
# 训练循环for epoch in range(num_epochs):    for real_images, _ in train_loader:        batch_size_current = real_images.size(0)
        # 将真实图像展平为一维向量(batch, 784)并移至设备        real_images = real_images.view(batch_size_current, -1).to(device)
        # 定义标签:真实图像标签为1,生成图像标签为0        real_labels = torch.ones(batch_size_current, 1).to(device)        fake_labels = torch.zeros(batch_size_current, 1).to(device)
        # =========================        # 1. 训练判别器(最大化 log D(real) + log(1-D(fake)))        # =========================        # 生成随机噪声向量        z = torch.randn(batch_size_current, latent_dim).to(device)        fake_images = generator(z)                     # 生成假图像
        # 判别器对真实图像和假图像的预测        real_outputs = discriminator(real_images)        fake_outputs = discriminator(fake_images.detach())  # detach阻断梯度回传至生成器
        loss_real = criterion(real_outputs, real_labels)  # 真实图像损失        loss_fake = criterion(fake_outputs, fake_labels)  # 假图像损失        loss_D = loss_real + loss_fake                    # 判别器总损失
        optimizer_D.zero_grad()        loss_D.backward()        optimizer_D.step()
        # =========================        # 2. 训练生成器(最大化 log D(fake))        # =========================        z = torch.randn(batch_size_current, latent_dim).to(device)        fake_images = generator(z)        outputs = discriminator(fake_images)               # 判别器对假图像输出
        loss_G = criterion(outputs, real_labels)           # 生成器试图让假图像被判别为真
        optimizer_G.zero_grad()        loss_G.backward()        optimizer_G.step()
    # 每个epoch结束打印损失    print(        f"Epoch [{epoch + 1}/{num_epochs}], "        f"Loss_D: {loss_D.item():.4f}, "        f"Loss_G: {loss_G.item():.4f}"    )

这段代码体现了 GAN 的核心训练闭环。

训练判别器时:

• 真实图像希望被判为 1

• 生成图像希望被判为 0

• 使用 fake_images.detach() 避免更新生成器

训练生成器时:

• 生成器希望生成图像被判别器判为 1

• 判别器参与前向计算,但目标是更新生成器参数

• 生成器通过判别器反馈改进生成图像

8、生成并查看图像

训练完成后,可以从随机噪声生成图像:

apache 复制代码
import matplotlib.pyplot as plt
# 切换生成器到评估模式(关闭Dropout/BatchNorm等训练行为)generator.eval()
# 禁用梯度计算,节省内存with torch.no_grad():    # 生成16个随机噪声向量    z = torch.randn(16, latent_dim).to(device)    # 生成假图像(形状: 16, 784)    fake_images = generator(z)    # 重塑为图像格式:16张,1通道,28x28像素    fake_images = fake_images.view(-1, 1, 28, 28)    # 将像素范围从 [-1,1] 还原到 [0,1](便于matplotlib显示)    fake_images = (fake_images + 1) / 2
# 创建4x4子图网格fig, axes = plt.subplots(4, 4, figsize=(6, 6))
# 遍历子图,显示生成的图像for i, ax in enumerate(axes.flat):    # 移除通道维度(单通道灰度图),转换为numpy,显示灰度图像    ax.imshow(fake_images[i].cpu().squeeze(), cmap="gray")    ax.axis("off")      # 隐藏坐标轴
plt.show()              # 展示生成的图像

其中:

• z 是随机噪声

• generator(z) 生成图像向量

• view(-1, 1, 28, 28) 把向量还原为图像形状

• (fake_images + 1) / 2 把图像从 −1 到 1 转回 0 到 1

八、GAN 的适用场景、局限与扩展方向

GAN 是生成式深度学习的重要代表模型之一。它在图像生成、图像编辑、风格迁移等任务中具有重要影响。

图 8:GAN 的适用场景、局限与扩展方向

1、适用场景

GAN 的常见应用包括:

• 图像生成

• 图像修复

• 图像超分辨率

• 图像风格迁移

• 数据增强

• 图像到图像转换

• 人脸生成与编辑

例如,超分辨率任务可以利用 GAN 生成更清晰、更自然的细节;图像到图像转换任务可以把草图转换为真实图像,或把白天场景转换为夜晚场景。

2、主要优势

GAN 的主要优势包括:

• 生成样本通常较清晰

• 能学习复杂数据分布

• 不需要显式写出数据分布公式

• 适合图像生成和图像编辑任务

• 对抗训练思想具有很强启发性

GAN 的重要价值不仅在于某一个具体模型,也在于它提出了一种新的训练范式:通过两个网络的竞争推动生成能力提升。

3、主要局限

GAN 的主要局限包括:

• 训练不稳定

• 容易出现模式崩塌

• 评价指标不如监督学习直观

• 对超参数和网络结构敏感

• 训练过程需要平衡生成器和判别器

• 在复杂任务中调试成本较高

这些问题使 GAN 的训练通常比普通分类模型更困难。

4、扩展方向

从基础 GAN 出发,可以继续学习以下模型:

• DCGAN:使用卷积结构改进图像生成

• CGAN:加入条件信息控制生成结果

• WGAN:改进训练稳定性

• WGAN-GP:加入梯度惩罚,进一步稳定训练

• CycleGAN:用于无配对图像到图像转换

• StyleGAN:高质量人脸与图像生成的重要代表

• Pix2Pix:用于有配对图像到图像转换

近年来,扩散模型(Diffusion Model)在许多生成任务中表现非常突出,但 GAN 仍然是理解生成建模和对抗训练思想的重要基础。

📘 小结

生成对抗网络通过生成器和判别器的对抗训练学习数据分布。生成器从随机噪声生成样本,判别器判断样本真假,二者交替优化,使生成结果逐渐接近真实数据。GAN 在图像生成和图像编辑中影响深远,但也存在训练不稳定、模式崩塌和评价困难等问题。

"点赞有美意,赞赏是鼓励"

相关推荐
霍格沃兹测试学院-小舟畅学2 小时前
高质量测试 Skill 编写手册 -- 渐进式披露
人工智能
韦胖漫谈IT2 小时前
数据与模型投毒 - 大语言模型 OWASP TOP 10系列
人工智能·语言模型·自然语言处理
wuxinyan1232 小时前
工业级大模型学习之路025:问题解决-检索质量全为0
人工智能·python·学习·langchain
weixin_408099672 小时前
2026 图片高清化 API 实战:AI超分辨率重建技术详解 + Python/Java/PHP/C#代码示例
图像处理·人工智能·python·超分辨率重建·石榴智能·图片变清晰·图片高清化api
@蔓蔓喜欢你2 小时前
WebRTC 实时通信:构建音视频通话应用
人工智能·ai
上海全爱科技2 小时前
全爱科技诚邀莅临 | 2026 高等教育博览会 携摩尔线程 GPU + 昇腾 NPU全栈 AI 解决方案,共启科教数智新征程
人工智能·科技
song5012 小时前
多模态模型在昇腾上的部署架构
人工智能·分布式·深度学习·架构·transformer·交互
韦胖漫谈IT2 小时前
敏感信息泄露 - 大语言模型 OWASP TOP 10系列
人工智能·安全·语言模型·自然语言处理
财经资讯数据_灵砚智能2 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月24日
大数据·人工智能·python·信息可视化·自然语言处理