【笔记】对抗训练-GAN

对抗训练-GAN


深度学习中 GAN 的对抗目标函数详解与最优解推导

生成对抗网络(GAN)是深度生成模型中的经典方法,其核心思想是两个网络之间的博弈:生成器 G G G 试图"伪造"样本,而判别器 D D D 尽力分辨真伪。本篇博客将从 GAN 的基本目标函数出发,逐步推导出判别器的最优形式,并分析其背后的数学含义。


一、GAN 的基本对抗目标函数

GAN 的原始目标是一个 min-max 游戏

min ⁡ G max ⁡ D ( E x ∼ P r [ log ⁡ D ( x ) ] + E z ∼ P z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ) \min_G \max_D \left( \mathbb{E}{x \sim P_r}[\log D(x)] + \mathbb{E}{z \sim P_z}[\log(1 - D(G(z)))] \right) GminDmax(Ex∼Pr[logD(x)]+Ez∼Pz[log(1−D(G(z)))])

其中:

  • P r ( x ) P_r(x) Pr(x) 表示真实数据的分布;
  • P z ( z ) P_z(z) Pz(z) 是先验噪声分布(如高斯);
  • G ( z ) G(z) G(z) 是生成器生成的假样本;
  • D ( x ) D(x) D(x) 是判别器输出 x x x 为真实样本的概率。

二、判别器与生成器的博弈目标

  • 判别器 D 的目标 :让 D ( x ) D(x) D(x) 趋近于 1, D ( G ( z ) ) D(G(z)) D(G(z)) 趋近于 0,即正确分辨真实与生成样本。

    对应目标函数为最大化:

    E x ∼ P r [ log ⁡ D ( x ) ] + E z ∼ P z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathbb{E}{x \sim P_r}[\log D(x)] + \mathbb{E}{z \sim P_z}[\log(1 - D(G(z)))] Ex∼Pr[logD(x)]+Ez∼Pz[log(1−D(G(z)))]

  • 生成器 G 的目标 :生成样本让 D ( G ( z ) ) D(G(z)) D(G(z)) 尽量大,即"骗过"判别器。

    对应目标函数为最小化:

    E z ∼ P z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{z \sim P_z}[\log(1 - D(G(z)))] Ez∼Pz[log(1−D(G(z)))]

这是一个典型的零和对抗过程。


三、判别器的最优解推导

我们接下来推导:在固定生成器 G G G 的前提下,判别器 D D D 的最优形式是怎样的?

令目标函数为:

V ( D ) = ∫ x P r ( x ) log ⁡ D ( x ) + P g ( x ) log ⁡ ( 1 − D ( x ) )   d x V(D) = \int_x P_r(x) \log D(x) + P_g(x) \log(1 - D(x)) \, dx V(D)=∫xPr(x)logD(x)+Pg(x)log(1−D(x))dx

对每个 x x x,令:

f ( D ( x ) ) = P r ( x ) log ⁡ D ( x ) + P g ( x ) log ⁡ ( 1 − D ( x ) ) f(D(x)) = P_r(x) \log D(x) + P_g(x) \log(1 - D(x)) f(D(x))=Pr(x)logD(x)+Pg(x)log(1−D(x))

对 D ( x ) D(x) D(x) 求导并令导数为 0:

d f d D ( x ) = P r ( x ) D ( x ) − P g ( x ) 1 − D ( x ) = 0 \frac{d f}{d D(x)} = \frac{P_r(x)}{D(x)} - \frac{P_g(x)}{1 - D(x)} = 0 dD(x)df=D(x)Pr(x)−1−D(x)Pg(x)=0

解得最优判别器为:

D ∗ ( x ) = P r ( x ) P r ( x ) + P g ( x ) D^*(x) = \frac{P_r(x)}{P_r(x) + P_g(x)} D∗(x)=Pr(x)+Pg(x)Pr(x)


四、最优判别器的含义

  1. D ∗ ( x ) D^*(x) D∗(x) 的输出值反映了 样本 x x x 来自真实分布的概率

    • 如果 P r ( x ) = P g ( x ) P_r(x) = P_g(x) Pr(x)=Pg(x),则 D ∗ ( x ) = 1 2 D^*(x) = \frac{1}{2} D∗(x)=21;
    • 如果 P r ( x ) ≫ P g ( x ) P_r(x) \gg P_g(x) Pr(x)≫Pg(x),则 D ∗ ( x ) ≈ 1 D^*(x) \approx 1 D∗(x)≈1;
    • 如果 P g ( x ) ≫ P r ( x ) P_g(x) \gg P_r(x) Pg(x)≫Pr(x),则 D ∗ ( x ) ≈ 0 D^*(x) \approx 0 D∗(x)≈0。
  2. 将 D ∗ D^* D∗ 代入 GAN 原始目标函数:

    V ( D ∗ ) = E x ∼ P r [ log ⁡ D ∗ ( x ) ] + E x ∼ P g [ log ⁡ ( 1 − D ∗ ( x ) ) ] V(D^*) = \mathbb{E}{x \sim P_r}[\log D^*(x)] + \mathbb{E}{x \sim P_g}[\log(1 - D^*(x))] V(D∗)=Ex∼Pr[logD∗(x)]+Ex∼Pg[log(1−D∗(x))]

    可推导出最终目标:

    min ⁡ G V ( D ∗ ) = − log ⁡ 4 + 2 ⋅ JS ( P r ∥ P g ) \min_G V(D^*) = -\log 4 + 2 \cdot \text{JS}(P_r \parallel P_g) GminV(D∗)=−log4+2⋅JS(Pr∥Pg)

    即:GAN 实质上是在最小化真实分布 P r P_r Pr 与生成分布 P g P_g Pg 之间的 Jensen-Shannon 散度


五、总结

内容 含义
D ∗ ( x ) = P r ( x ) P r ( x ) + P g ( x ) D^*(x) = \frac{P_r(x)}{P_r(x) + P_g(x)} D∗(x)=Pr(x)+Pg(x)Pr(x) 判别器在每个样本点处的最优输出
GAN 的优化目标 最小化 JS 散度
最优时的结果 当 P r = P g P_r = P_g Pr=Pg 时,GAN 达到最优, D ( x ) = 0.5 D(x)=0.5 D(x)=0.5,分不出真假

六、WGAN 的动机(为后续铺垫)

由于 Jensen-Shannon 散度在 P r P_r Pr 与 P g P_g Pg 没有交集时不连续(导致梯度消失),Wasserstein GAN(WGAN)改用 Wasserstein 距离替代 JS 散度,并要求判别器满足 1-Lipschitz 条件,这会在后续单独展开讲解。

相关推荐
美味的大香蕉3 小时前
Spark-SQL
笔记
踢足球的程序员·4 小时前
OpenGL学习笔记(立方体贴图、高级数据、高级GLSL)
笔记·学习·图形渲染
WarPigs4 小时前
VRoid-Blender-Unity个人工作流笔记
笔记·blender
米小葱5 小时前
【图解】系统设计学习笔记
笔记·学习
pumpkin845145 小时前
学习笔记十二——Rust 高阶函数彻底入门(超详细过程解析 + 每步数值追踪)
笔记·学习·rust
OKay_J6 小时前
蓝桥杯备赛笔记(嵌入式)
笔记·stm32·学习·蓝桥杯
牧木江7 小时前
【从C到C++的算法竞赛迁移指南】第二篇:动态数组与字符串完全攻略 —— 写给C程序员的全新世界
c语言·c++·经验分享·笔记·算法
初九之潜龙勿用7 小时前
技术与情感交织的一生 (六)
笔记
stanleyrain8 小时前
VIM学习笔记
笔记·学习·vim
美味的大香蕉15 小时前
Spark SQL
笔记