Wasserstein GAN(WGAN)

Wasserstein GAN(WGAN)

    • [0. 前言](#0. 前言)
    • [1. GAN 面临的挑战](#1. GAN 面临的挑战)
    • [2. 距离函数](#2. 距离函数)
    • [3. GAN 中的距离函数](#3. GAN 中的距离函数)
    • [4. Wasserstein 损失函数](#4. Wasserstein 损失函数)
    • [5. 使用 Keras 实现 WGAN](#5. 使用 Keras 实现 WGAN)

0. 前言

生成对抗网络 (Generative Adversarial Network, GAN)能有效生成逼真新数据,是一种实用的生成模型。后续众多深度学习研究论文针对原始 GAN 的缺陷与局限提出了大量改进方案。我们知道,GAN 存在训练难度大、易发生模式崩溃等问题。模式崩溃是指生成器在损失函数已优化的情况下仍持续产生相同输出的现象。以MNIST手写数字数据集为例,发生模式崩溃时,由于数字 49 外形相似,生成器可能仅会输出这两类数字。Wasserstein GAN (WGAN) 通过采用 Wasserstein 距离替代原始损失函数,成功解决了训练稳定性与模式崩溃问题。

1. GAN 面临的挑战

GAN 中,判别器与生成器的目标相互对立,极易导致训练失稳:判别器致力于准确区分真实数据与生成数据,而生成器则竭力欺骗判别器。若判别器的学习速度领先于生成器,生成器参数将无法有效优化;反之,若判别器学习速度滞后,梯度在传递至生成器前就可能消失。最严重的情况是,当判别器无法收敛时,生成器将无法获得任何有效反馈。
WGAN 研究指出,GAN 固有的不稳定性源于其基于 JS (Jensen-Shannon) 散度的损失函数。在 GAN 中,生成器的目标是学习从源分布(如噪声)到估计目标分布(如 MNIST 手写数字)的映射转换。原始 GAN 的损失函数实际上是在最小化目标分布与其估计值之间的距离,但问题在于某些分布之间并不存在平滑路径来最小化 JS 散度,从而导致训练无法收敛。

接下来我们将探究三种距离函数,分析哪些函数能作为 JS 散度的替代方案,从而更适用于 GAN 的优化过程。

2. 距离函数

通过分析损失函数可以理解 GAN 训练的稳定性。为深入探究 GAN 的损失函数,我们将回顾两种概率分布之间常用的距离度量与散度函数。

我们关注的是真实数据分布 p d a t a p_{data} pdata 与生成器数据分布 p g p_g pg 之间的距离。GAN 的目标是实现 p g → p d a t a p_g\rightarrow p_{data} pg→pdata。下表展示了常用的散度函数。

散度函数 公式
KL 散度 D K L ( p d a t a ∣ ∣ p g ) E x ∼ p d a t a l o g p d a t a ( x ) p g ( x ) ≠ D K L ( p g ∣ ∣ p d a t a ) E x ∼ p g l o g p g ( x ) p d a t a ( x ) D_{KL}(p_{data}||p_g)\mathbb E_{x\sim p_{data}}log\frac{p_{data}(x)}{p_g(x)}\neq D_{KL}(p_{g}||p_{data})\mathbb E_{x\sim p_{g}}log\frac{p_{g}(x)}{p_{data}(x)} DKL(pdata∣∣pg)Ex∼pdatalogpg(x)pdata(x)=DKL(pg∣∣pdata)Ex∼pglogpdata(x)pg(x)
JS 散度 D J S ( p d a t a ∣ ∣ p g ) = 1 2 E x ∼ p d a t a l o g p d a t a ( x ) p d a t a ( x ) + p g ( x ) 2 + 1 2 E x ∼ p g l o g p g ( x ) p d a t a ( x ) + p g ( x ) 2 = D J S ( p g ∣ ∣ p d a t a ) D_{JS}(p_{data}||p_g)=\frac 12\mathbb E_{x\sim p_{data}}log\frac {p_{data}(x)}{\frac{p_{data}(x)+p_g(x)}{2}}+\frac 12\mathbb E_{x\sim p_{g}}log\frac {p_{g}(x)}{\frac{p_{data}(x)+p_g(x)}{2}}=D_{JS}(p_g||p_{data}) DJS(pdata∣∣pg)=21Ex∼pdatalog2pdata(x)+pg(x)pdata(x)+21Ex∼pglog2pdata(x)+pg(x)pg(x)=DJS(pg∣∣pdata)
Wasserstein 距离 W ( p d a t a , p g ) = i n f γ ∈ ∏ ( p d a t a , p g ) E ( x , y ) ∼ γ [ ∣ ∣ x − y ∣ ∣ ] W(p_{data}, p_g)=\underset {\gamma\in\prod(p_{data},p_g)}{inf}\mathbb E_{(x,y)\sim \gamma}[||x-y||] W(pdata,pg)=γ∈∏(pdata,pg)infE(x,y)∼γ[∣∣x−y∣∣]

Wasserstein 距离的核心思想是:为了将概率分布 p d a t a p_{data} pdata 匹配至概率分布 p g p_g pg,需要度量在距离 d = ∣ ∣ x − y ∣ ∣ d=||x-y|| d=∣∣x−y∣∣ 下质量 γ ( x , y ) \gamma(x,y) γ(x,y) 的运输总量。其中 γ ( x , y ) \gamma(x,y) γ(x,y) 是所有可能联合分布空间 γ ∈ ∏ ( p d a t a , p g ) \gamma\in\prod(p_{data},p_g) γ∈∏(pdata,pg) 中的一个联合分布,也被称为传输方案------其反映了通过质量转移使两个概率分布相匹配的策略。给定两个概率分布存在多种可能的传输方案,而符号 i n f inf inf 表示的是具有最小成本的传输方案。

3. GAN 中的距离函数

首先回顾 GAN 一节中的公式:
L ( D ) = − E x ∼ p d a t a l o g D ( x ) − E z l o g ( 1 − D ( G ( z ) ) ) \mathcal L^{(D)}=-\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_zlog(1-D(G(z))) L(D)=−Ex∼pdatalogD(x)−Ezlog(1−D(G(z)))

若从生成器分布中采样,前述公式可表示为:
L ( D ) = − E x ∼ p d a t a l o g D ( x ) − E x ∼ p g l o g ( 1 − D ( x ) ) \mathcal L^{(D)}=-\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_{x\sim p_g}log(1-D(x)) L(D)=−Ex∼pdatalogD(x)−Ex∼pglog(1−D(x))

为求 L ( 𝐷 ) \mathcal L(𝐷) L(D) 的最小值:
L ( D ) = − ∫ x p d a t a ( x ) l o g D ( x ) d x − ∫ x p g ( x ) l o g ( 1 − D ( x ) ) d x L ( D ) = − ∫ x p d a t a ( x ) l o g D ( x ) + p g ( x ) l o g ( 1 − D ( x ) ) d x \mathcal L^{(D)}= -\int xp{data}(x)logD(x)dx-\int _xp_g(x)log(1-D(x))dx\\ \mathcal L^{(D)}= -\int xp{data}(x)logD(x)+p_g(x)log(1-D(x))dx L(D)=−∫xpdata(x)logD(x)dx−∫xpg(x)log(1−D(x))dxL(D)=−∫xpdata(x)logD(x)+pg(x)log(1−D(x))dx

被积函数形如 y → a l o g y + b l o g ( 1 − y ) y\rightarrow alogy + blog(1-y) y→alogy+blog(1−y),对于任意 a , b ∈ R 2 a,b \in\mathbb R^2 a,b∈R2(不包括 0 , 0 {0,0} 0,0),该函数在 y ∈ [ 0 , 1 ] y\in[0,1] y∈[0,1] 区间内取得最大值的点为 y = a a + b y=\frac a {a+b} y=a+ba。由于积分运算不改变该表达式最大值(即 L ( D ) \mathcal L^{(D)} L(D) 最小值)的位置,可得最优判别器为:
D ∗ ( x ) = p d a t a ( p d a t a + p g ) D^*(x)=\frac {p_{data}}{(p_{data}+p_g)} D∗(x)=(pdata+pg)pdata

因此,给定最优判别器时的损失函数为:
L ( D ∗ ) = − E x ∼ p d a t a l o g p d a t a ( p d a t a + p g ) − E x ∼ p g l o g ( 1 − p d a t a ( p d a t a + p g ) ) L ( D ∗ ) = − E x ∼ p d a t a l o g p d a t a ( p d a t a + p g ) − E x ∼ p g l o g ( p g ( p d a t a + p g ) ) L ( D ∗ ) = 2 l o g 2 − D K L ( p d a t a ∣ ∣ p d a t a + p g 2 ) − D K L ( p g ∣ ∣ p d a t a + p g 2 ) L ( D ∗ ) = 2 l o g 2 − 2 D J S ( p d a t a ∣ ∣ p g ) \mathcal L^{(D^*)}=-\mathbb E_{x\sim p_{data}}log{\frac {p_{data}}{(p_{data}+p_g)}}-\mathbb E_{x\sim p_g}log(1-{\frac {p_{data}}{(p_{data}+p_g)}}) \\ \mathcal L^{(D^*)}=-\mathbb E_{x\sim p_{data}}log{\frac {p_{data}}{(p_{data}+p_g)}}-\mathbb E_{x\sim p_g}log({\frac {p_{g}}{(p_{data}+p_g)}})\\ \mathcal L^{(D^*)}=2log2-D_{KL}(p_{data}||\frac {p_{data}+p_g} {2})-D_{KL}(p_{g}||\frac {p_{data}+p_g} {2})\\ \mathcal L^{(D^*)}=2log2-2D_{JS}(p_{data}||p_g )\\ L(D∗)=−Ex∼pdatalog(pdata+pg)pdata−Ex∼pglog(1−(pdata+pg)pdata)L(D∗)=−Ex∼pdatalog(pdata+pg)pdata−Ex∼pglog((pdata+pg)pg)L(D∗)=2log2−DKL(pdata∣∣2pdata+pg)−DKL(pg∣∣2pdata+pg)L(D∗)=2log2−2DJS(pdata∣∣pg)

从上式可以观察到,最优判别器的损失函数实际上是一个常数减去真实分布 p d a t a p_{data} pdata 与生成器分布 p g p_g pg 之间 JS 散度的两倍。最小化 L ( D ∗ ) \mathcal{L}(D^*) L(D∗) 意味着需要最大化 D J S ( p d a t a ∥ p g ) D_{JS}(p_{data} \| p_g) DJS(pdata∥pg),即要求判别器必须准确区分真实数据与生成数据。

同时我们可以确证,当生成器分布等于真实数据分布时,才能得到最优生成器:
G ∗ ( x ) → p g = p d a t a G^*(x) \rightarrow p_g = p_{data} G∗(x)→pg=pdata

这一结论符合直观理解:生成器的目标正是通过学习真实数据分布来欺骗判别器。实际上,我们可以通过最小化 JS 散度或使 p g → p d a t a p_g \rightarrow p_{data} pg→pdata 来获得最优生成器。当生成器最优时,最优判别器为 D ∗ ( x ) = 1 2 \mathcal{D}^*(x) = \frac{1}{2} D∗(x)=21,此时 L ( D ∗ ) = 2 log ⁡ 2 ≈ 0.60 \mathcal{L}(D^*) = 2\log 2 \approx 0.60 L(D∗)=2log2≈0.60。

问题在于,当两个分布没有重叠时,不存在平滑函数能够弥合其间的差距。此时基于梯度下降的GAN训练将无法收敛。例如假设:
p d a t a = ( x , y ) 其中 x = 0 , y ∼ U ( 0 , 1 ) p g = ( x , y ) 其中 x = θ , y ∼ U ( 0 , 1 ) p_{data} = (x,y)\ \ \ \ \ \text{其中 } x=0,\ y \sim U(0,1)\\ p_g = (x,y)\ \ \ \ \ \text{其中 } x=\theta,\ y \sim U(0,1) pdata=(x,y) 其中 x=0, y∼U(0,1)pg=(x,y) 其中 x=θ, y∼U(0,1)

其中 U ( 0 , 1 ) U(0,1) U(0,1) 表示均匀分布。各距离函数对应的散度值如下:

  • D K L ( p d a t a ∥ p g ) = E x = 0 , y ∼ U ( 0 , 1 ) log ⁡ p d a t a ( x , y ) p g ( x , y ) = ∑ 1 log ⁡ 1 0 = + ∞ D_{KL}(p_{data} \| p_g) = \mathbb{E}{x=0,y \sim U(0,1)} \log \frac{p{data}(x,y)}{p_g(x,y)} = \sum 1 \log \frac{1}{0} = +\infty DKL(pdata∥pg)=Ex=0,y∼U(0,1)logpg(x,y)pdata(x,y)=∑1log01=+∞
  • D K L ( p g ∥ p d a t a ) = E x = θ , y ∼ U ( 0 , 1 ) log ⁡ p g ( x , y ) p d a t a ( x , y ) = ∑ 1 log ⁡ 1 0 = + ∞ D_{KL}(p_g \| p_{data}) = \mathbb{E}{x=\theta,y \sim U(0,1)} \log \frac{p_g(x,y)}{p{data}(x,y)} = \sum 1 \log \frac{1}{0} = +\infty DKL(pg∥pdata)=Ex=θ,y∼U(0,1)logpdata(x,y)pg(x,y)=∑1log01=+∞
  • D J S ( p d a t a ∥ p g ) = 1 2 E x = 0 , y ∼ U ( 0 , 1 ) log ⁡ p d a t a ( x , y ) p d a t a ( x , y ) + p g ( x , y ) 2 + 1 2 E x = θ , y ∼ U ( 0 , 1 ) log ⁡ p g ( x , y ) p d a t a ( x , y ) + p g ( x , y ) 2 = 1 2 ∑ 1 log ⁡ 1 1 2 + 1 2 ∑ 1 log ⁡ 1 1 2 = log ⁡ 2 D_{JS}(p_{data} \| p_g) = \frac{1}{2} \mathbb{E}{x=0,y \sim U(0,1)} \log \frac{p{data}(x,y)}{\frac{p_{data}(x,y)+p_g(x,y)}{2}} + \frac{1}{2} \mathbb{E}{x=\theta,y \sim U(0,1)} \log \frac{p_g(x,y)}{\frac{p{data}(x,y)+p_g(x,y)}{2}} = \frac{1}{2} \sum 1 \log \frac{1}{\frac{1}{2}} + \frac{1}{2} \sum 1 \log \frac{1}{\frac{1}{2}} = \log 2 DJS(pdata∥pg)=21Ex=0,y∼U(0,1)log2pdata(x,y)+pg(x,y)pdata(x,y)+21Ex=θ,y∼U(0,1)log2pdata(x,y)+pg(x,y)pg(x,y)=21∑1log211+21∑1log211=log2
  • W ( p d a t a , p g ) = ∣ θ ∣ W(p_{data}, p_g) = |\theta| W(pdata,pg)=∣θ∣

由于 JS 散度为常数,生成对抗网络将缺乏足够的梯度推动 p g → p d a t a p_g \rightarrow p_{data} pg→pdata。同时可发现 KL 散度及其反向形式均无法提供有效优化梯度。然而,通过 Wasserstein 距离 W ( p d a t a , p g ) W(p_{data}, p_g) W(pdata,pg),我们可以获得平滑函数,从而利用梯度下降实现 p g → p d a t a p_g \rightarrow p_{data} pg→pdata。当两分布重叠度极低时,JS 散度会失效,而 Wasserstein 距离则能成为更合理的 GAN 损失函数。

4. Wasserstein 损失函数

在使用 Wasserstein 距离之前,还有一个问题需要解决------穷举所有可能的联合分布空间 ∏ ( p d a t a , p g ) \prod(p_{data}, p_g) ∏(pdata,pg) 以找到下确界 inf ⁡ γ ∈ ∏ ( p d a t a , p g ) \underset {\gamma \in \prod(p_{data}, p_g)} {\inf} γ∈∏(pdata,pg)inf 是难以实现的。解决方案是采用其 Kantorovich-Rubinstein 对偶形式:
W ( p d a t a , p g ) = 1 K sup ⁡ ∥ f ∥ L ≤ K [ E x ∼ p d a t a [ f ( x ) ] − E x ∼ p g [ f ( x ) ] ] W(p_{data}, p_g) = \frac{1}{K} \sup_{\|f\|L \leq K} \left[ \mathbb{E}{x \sim p_{data}}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)] \right] W(pdata,pg)=K1∥f∥L≤Ksup[Ex∼pdata[f(x)]−Ex∼pg[f(x)]]

等价地,Wasserstein 距离在 sup ⁡ ∥ f ∥ L ≤ K \underset {\|f\|_L \leq K}{\sup} ∥f∥L≤Ksup 条件下的上确界(可近似理解为最大值)就是针对所有 K-Lipschitz 函数 f : X → R f:X \to \mathbb{R} f:X→R 取的上界。K-Lipschitz 函数满足以下约束条件,对所有 x 1 , x 2 ∈ R x_1, x_2 \in \mathbb{R} x1,x2∈R,有:
∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ K ∣ x 1 − x 2 ∣ |f(x_1) - f(x_2)| \leq K|x_1 - x_2| ∣f(x1)−f(x2)∣≤K∣x1−x2∣

K-Lipschitz 函数具有有界导数的特性,且几乎处处连续可微(例如函数 f ( x ) = ∣ x ∣ f(x) = |x| f(x)=∣x∣ 虽在 x = 0 x=0 x=0 处不可导,但其导数有界且函数连续)。

通过寻找一族 K-Lipschitz 函数 { f w } w ∈ W \{f_w\}{w \in \mathcal{W}} {fw}w∈W,以上公式可转化为:
W ( p d a t a , p g ) = max ⁡ w ∈ W [ E x ∼ p d a t a [ f w ( x ) ] − E x ∼ p g [ f w ( x ) ] ] W(p
{data}, p_g) = \max_{w \in \mathcal{W}} \left[ \mathbb{E}{x \sim p{data}}[f_w(x)] - \mathbb{E}_{x \sim p_g}[f_w(x)] \right] W(pdata,pg)=w∈Wmax[Ex∼pdata[fw(x)]−Ex∼pg[fw(x)]]

GAN 框架下,上式可通过从噪声分布 z z z 中采样,并用判别器函数 D w D_w Dw 替代 f w f_w fw 重新表述为:
W ( p d a t a , p g ) = max ⁡ w ∈ W [ E x ∼ p d a t a [ D w ( x ) ] − E z [ D w ( G ( z ) ) ] ] W(p_{data}, p_g) = \max_{w \in \mathcal{W}} \left[ \mathbb{E}{\boldsymbol{x} \sim p{data}}[\mathcal{D}w(\boldsymbol{x})] - \mathbb{E}{\boldsymbol{z}}[\mathcal{D}_w(\mathcal{G}(\boldsymbol{z}))] \right] W(pdata,pg)=w∈Wmax[Ex∼pdata[Dw(x)]−Ez[Dw(G(z))]]

最后的关键在于如何确定函数族 w ∈ W w \in \mathcal{W} w∈W。解决方案是在每次梯度更新时,将判别器权重 w w w 裁剪至指定区间(如 -0.010.01):
w ← c l i p ( w , − 0.01 , 0.01 ) w \leftarrow clip(w, -0.01, 0.01) w←clip(w,−0.01,0.01)

通过限制权重取值范围,可将判别器约束在紧参数空间内,从而保证 Lipschitz 连续性。

我们可以将上式作为构建新 GAN 损失函数的基础。Wasserstein 距离既作为生成器力求最小化的损失函数,也作为判别器试图最大化(或最小化其负值)的目标函数:
L ( D ) = − E x ∼ p d a t a D w ( x ) + E z D w ( G ( z ) ) L ( G ) = − E z D w ( G ( z ) ) \mathcal{L}^{(D)} = -\mathbb{E}{x \sim p{data}}D_w(x) + \mathbb{E}{z}D_w(G(z)) \\ \mathcal{L}^{(G)} = -\mathbb{E}{z}D_w(G(z)) L(D)=−Ex∼pdataDw(x)+EzDw(G(z))L(G)=−EzDw(G(z))

在生成器损失函数中,由于不直接对真实数据进行优化,因此不包含第一项。

下表展示了 GANWGAN 损失函数的区别。为简洁起见,我们对 L ( D ) \mathcal{L}(D) L(D) 和 L ( G ) \mathcal{L}(G) L(G) 的表示进行了简化。

网络类型 损失函数
GAN L ( D ) = − E x ∼ p d a t a l o g D ( x ) − E z l o g ( 1 − D ( G ( z ) ) ) L ( G ) = − E z l o g D ( G ( z ) ) \mathcal L^{(D)}=-\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_zlog(1-D(G(z)))\\ \mathcal L^{(G)} = −\mathbb E_z logD(G(z)) L(D)=−Ex∼pdatalogD(x)−Ezlog(1−D(G(z)))L(G)=−EzlogD(G(z))
WGAN L ( D ) = − E x ∼ p d a t a D w ( x ) + E z D w ( G ( z ) ) L ( G ) = − E z D w ( G ( z ) ) w ← c l i p ( w , − 0.01 , 0.01 ) \mathcal{L}^{(D)} = -\mathbb{E}{x \sim p{data}}D_w(x) + \mathbb{E}{z}D_w(G(z)) \\ \mathcal{L}^{(G)} = -\mathbb{E}{z}D_w(G(z)) \\ w \leftarrow clip(w, −0.01, 0.01) L(D)=−Ex∼pdataDw(x)+EzDw(G(z))L(G)=−EzDw(G(z))w←clip(w,−0.01,0.01)

这些损失函数用于训练 WGAN,算法如下所示。

定义学习率 α \alpha α,裁剪参数 c c c,批大小 m m m,每次生成器迭代时判别器的迭代次数 n c r i t i c n_{critic} ncritic;初始判别器参数 w 0 w_0 w0,初始生成器参数 θ 0 \theta_0 θ0
while θ \theta θ 未收敛 do

for t=1,..., n c r i t i c n_{critic} ncritic do

从真实数据中采样批次 { x ( i ) } i = 1 m ∼ p d a t a \{x^{(i)}\}{i=1}^m\sim p{data} {x(i)}i=1m∼pdata

从均匀噪声分布中采样批次 { z ( i ) } i = 1 m ∼ p ( z ) \{z^{(i)}\}_{i=1}^m\sim p(z) {z(i)}i=1m∼p(z)

计算判别器梯度: g w ← ∇ w [ − 1 m ∑ i = 1 m D w ( x ( i ) ) + 1 m ∑ i = 1 m D w ( G θ ( z ( i ) ) ) ] g_w \leftarrow\nabla_w[-\frac 1m \sum_{i=1}^mD_w(x^{(i)}) +\frac 1m\sum_{i=1}^m D_w(G_\theta(z^{(i)}))] gw←∇w[−m1∑i=1mDw(x(i))+m1∑i=1mDw(Gθ(z(i)))]

更新判别器参数: w ← w − α × R M S P r o p ( w , g w ) w\leftarrow w − \alpha\times RMSProp(w, g_w) w←w−α×RMSProp(w,gw)

裁剪判别器权重: w ← c l i p ( w , − c , c ) w\leftarrow clip(w, −c, c) w←clip(w,−c,c)

end for

从均匀噪声分布中采样批次 { z ( i ) } i = 1 m ∼ p ( z ) \{z^{(i)}\}_{i=1}^m\sim p(z) {z(i)}i=1m∼p(z)

计算生成器梯度: g θ ← ∇ θ − 1 m ∑ i = 1 m D w ( G θ ( z ( i ) ) ) g_\theta \leftarrow\nabla_\theta-\frac 1m\sum_{i=1}^m D_w(G_\theta(z^{(i)})) gθ←∇θ−m1∑i=1mDw(Gθ(z(i)))

更新生成器参数: θ ← θ − α × R M S P r o p ( θ , g θ ) \theta\leftarrow \theta -\alpha\times RMSProp(\theta, g_\theta) θ←θ−α×RMSProp(θ,gθ)
end while

WGAN 模型与 DCGAN 实际结构基本相同,主要区别仅在于真假数据标签和损失函数。

GAN 类似,WGAN 通过对抗方式交替训练判别器和生成器。不过 WGAN 中,判别器(亦称评论家)需先进行 n c r i t i c n_{critic} ncritic 轮迭代训练,随后才执行一次生成器训练。这与 GAN 中判别器与生成器训练次数相等的模式不同,换言之,在 GAN 中 n c r i t i c = 1 n_{critic}=1 ncritic=1。

判别器训练意味着学习其参数(权重和偏置)。该过程需要从真实数据中采样一个批次,从生成数据中采样一个批次,将采样数据输入判别器网络后计算参数梯度,判别器参数通过 RMSProp 算法进行优化。

最后,通过裁剪判别器参数来满足 Wasserstein 距离优化中的 Lipschitz 约束。完成 n c r i t i c n_{critic} ncritic 轮判别器训练后,固定判别器参数。生成器训练首先采样一批生成数据,将这些数据标记为真实 (1.0) 以欺骗判别器网络,随后计算生成器梯度并使用 RMSProp 进行优化。

生成器训练完成后,解除判别器参数固定,开启新一轮 n c r i t i c n_{critic} ncritic 次判别器训练。需要说明的是,在判别器训练期间无需固定生成器参数,因为生成器仅参与数据生成过程。与 GAN 类似,判别器可作为独立网络进行训练,但生成器训练始终需要判别器通过对抗网络参与,因为损失值需基于生成器网络的输出计算。

GAN 不同,WGAN 中真实数据标签为 1.0,生成数据标签为 -1.0,这是为了适配梯度计算的特殊处理:
L = − y l a b e l × 1 m ∑ i = 1 m y p r e d \mathcal{L} = -y_{label} \times \frac{1}{m} \sum_{i=1}^{m} y_{pred} L=−ylabel×m1i=1∑mypred

其中真实数据对应 y l a b e l = 1.0 y_{label}=1.0 ylabel=1.0,生成数据对应 y l a b e l = − 1.0 y_{label}=-1.0 ylabel=−1.0。为简化表示省略上标 i i i。对于判别器而言,WGAN 在真实数据训练时通过增大预测值 D w ( x ) D_w(x) Dw(x) 来最小化损失函数。

使用生成数据训练时,WGAN 通过降低预测值 y p r e d = D w ( G ( z ) ) y_{pred} = D_w(G(z)) ypred=Dw(G(z)) 来最小化损失函数。对于生成器而言,当生成数据在训练过程中被标记为真实时,WGAN 通过提高预测值 y p r e d = D w ( G ( z ) ) y_{pred} = D_w(G(z)) ypred=Dw(G(z)) 来最小化损失函数。需要注意的是, y l a b e l y_{label} ylabel 除符号作用外,对损失函数没有直接影响。

5. 使用 Keras 实现 WGAN

本节最重要的内容是使用了基于 Wasserstein 距离的新损失函数,用于实现 GAN 的稳定训练。整体网络模型在 tf.keras 中的构建方式与 DCGAN 结构相似,接下来使用 Keras 实现 WGAN。判别器的训练有一个小的调整,先训练一批真实数据,然后再训练一批伪造数据,而不是同时训练一批真实数据和虚假数据。这种调整将防止梯度消失,因为真实和伪造数据标签中的符号相反,并且由于裁剪而导致的权重较小。

python 复制代码
import tensorflow as tf
from tensorflow import keras
import numpy as np
import gan

def wasserstein_loss(y_label,y_pred):
    return -keras.backend.mean(y_label*y_pred)

def build_and_train_models():
    (x_train,_),_ = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.expand_dims(x_train,axis=-1)
    x_train = x_train.astype('float32') / 255.

    model_name = 'wgan_mnist'
    latent_size = 100
    # 超参数 from WGAN paper
    n_critic = 5
    clip_value = 0.01
    batch_size = 64
    lr = 5e-5
    train_step = 40000
    input_shape = (image_size,image_size,1)

    inputs = keras.layers.Input(shape=input_shape,name='discriminator_input')
    # WGAN 使用线性激活函数
    discriminator = gan.discriminator(inputs,activation='linear')
    optimizer = keras.optimizers.RMSprop(lr=lr)
    discriminator.compile(loss=wasserstein_loss,optimizer=optimizer,metrics=['acc'])
    discriminator.summary()
    # generator
    input_shape = (latent_size,)
    inputs = keras.layers.Input(shape=input_shape,name='z_input')
    generator = gan.generator(inputs,image_size)
    generator.summary()

    discriminator.trainable = False
    adversarial = keras.Model(inputs,discriminator(generator(inputs)),
            name=model_name)
    adversarial.compile(loss=wasserstein_loss,optimizer=optimizer,metrics='acc')
    adversarial.summary()

    models = (generator,discriminator,adversarial)
    params = (batch_size,latent_size,n_critic,clip_value,train_step,model_name)
    train(models,x_train,params)

def train(models,x_train,params):
    generator,discriminator,adversarial = models
    batch_size,latent_size,n_critic,clip_value,train_step,model_name = params
    save_interval = 500
    noise_input = np.random.uniform(-1.,1.,size=[16,latent_size])
    train_size = x_train.shape[0]
    real_labels = np.ones((batch_size,1))
    for i in range(train_step):
        loss = 0
        acc = 0
        for _ in range(n_critic):
            rand_indexes = np.random.randint(0,train_size,size=batch_size)
            real_images = x_train[rand_indexes]
            noise = np.random.uniform(-1.0,1.0,size=[batch_size,latent_size])
            fake_images = generator.predict(noise)
            #fake data labels = -1
            real_loss,real_acc = discriminator.train_on_batch(real_images,real_labels)
            fake_loss,fake_acc = discriminator.train_on_batch(fake_images,-real_labels)
            loss += 0.5 * (real_loss + fake_loss)
            acc += 0.5 * (real_acc + fake_acc)
            for layer in discriminator.layers:
                weights = layer.get_weights()
                weights = [np.clip(weight,-clip_value,clip_value) for weight in weights]
                layer.set_weights(weights)
        loss /= n_critic
        acc /= n_critic
        log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)
        print(log)
        noise = np.random.uniform(-1.,1.,size=[batch_size,latent_size])
        loss,acc = adversarial.train_on_batch(noise,real_labels)
        log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)
        print(log)
        if (i+1) % save_interval == 0:
            gan.plot_images(generator,noise_input=noise_input,
                    show=False,step=(i+1),model_name=model_name)
    generator.save(model_name+'.h5')

if __name__ == '__main__':
    build_and_train_models()

WGAN 即使在网络配置更改下也很稳定。 例如,当在判别器网络中的 ReLU 之前插入批归一化时,会导致 DCGAN 不稳定。在 WGAN 中,相同的配置是稳定的。WGAN 通过使用 Wasserstein 损失函数来解决模式坍塌问题,该函数即使在两个分布之间很少或没有重叠时也具有平滑的微分函数。但是,WGAN 没有改进生成的图像质量。

相关推荐
EasyCVR9 小时前
视频融合平台EasyCVR在智慧水利中的实战应用:构建全域感知与智能预警平台
人工智能·音视频
DisonTangor9 小时前
阿里开源Qwen3-Omni-30B-A3B三剑客——Instruct、Thinking 和 Captioner
人工智能·语言模型·开源·aigc
独孤--蝴蝶9 小时前
AI人工智能-机器学习-第一周(小白)
人工智能·机器学习
西柚小萌新9 小时前
【深入浅出PyTorch】--上采样+下采样
人工智能·pytorch·python
丁学文武10 小时前
大语言模型(LLM)是“预制菜”? 从应用到底层原理,在到中央厨房的深度解析
人工智能·语言模型·自然语言处理·大语言模型·大模型应用·预制菜
fie888910 小时前
基于MATLAB的声呐图像特征提取与显示
开发语言·人工智能
文火冰糖的硅基工坊11 小时前
[嵌入式系统-100]:常见的IoT(物联网)开发板
人工智能·物联网·架构
刘晓倩11 小时前
实战任务二:用扣子空间通过任务提示词制作精美PPT
人工智能
shut up11 小时前
LangChain - 如何使用阿里云百炼平台的Qwen-plus模型构建一个桌面文件查询AI助手 - 超详细
人工智能·python·langchain·智能体