什么是GAN?
-
GAN,全称 Generative Adversarial Network,即生成对抗网络。该网络模型由Ian J. Goodfellow 在2014年首次提出,以下是该论文原文下载地址:Generative Adversarial Nets。
-
生成对抗网络(GAN)是一种通过框架内两个核心模块------生成模型(Generative Model)和判别模型(Discriminative Model)------相互博弈学习,从而产生高质量输出的深度学习模型。作为当前最具前景和活跃度的生成式模型之一,GAN 在样本数据生成、图像合成、图像修复、图像转换以及文本生成等多个领域展现出强大能力,标志着生成式人工智能(AIGC)的关键突破。
GAN 的核心思想是通过生成器 和判别器的对抗训练,使生成器能够不断优化以生成逼真的数据,而判别器则不断提升鉴别真伪的能力。这种动态博弈机制使得 GAN 能够生成高度接近真实分布的图像或数据,成为现代生成式 AI 的重要基石之一。
生成对抗网络(GAN)的核心思想是通过**生成器(Generator)和判别器(Discriminator)**的对抗训练,使生成数据的分布逐步逼近真实数据的分布。在训练过程中,生成器从随机噪声中合成样本,并不断优化其生成能力,力求使生成的样本与真实数据尽可能相似,从而"欺骗"判别器。与此同时,判别器则通过对比生成样本和真实样本,持续提升自身的鉴别能力,以更精准地区分两者的差异。这种动态博弈机制推动双方不断优化,最终使生成器能够输出高度逼真的数据。
GAN的工作原理
-
核心构成
GAN由两个重要的部分构成:生成器(Generator,简写作G)和判别器(Discriminator,简写作D)。
- 生成器:通过机器生成数据,目的是尽可能"骗过"判别器,生成的数据记做G(z);
- 判别器:判断数据是真实数据还是「生成器」生成的数据,目的是尽可能找出「生成器」造的"假数据"。它的输入参数是x,x代表数据,输出D(x)代表x为真实数据的概率,如果为1,就代表100%是真实的数据,而输出为0,就代表不可能是真实的数据。
经过这样的设计,G和D就构成了一个动态对抗的过程,随着多次训练之后,G生成的数据越来越接近真实数据,D判断数据真伪的水平也越来越高。最后在训练的后期,G所生成的数据足够欺骗D,对于D来讲,它则难以判断数据究竟是G生成还是真实数据,因此最后的D(G(z))=0.5。这样我们就得到了一个生成模型可以生成足够以假乱真的数据。
-
训练步骤
- 第一阶段:固定判别器D,训练生成器G。首先使用一个性能不错的判别器D,G通过噪声不断生成假数据,将其丢给D去判断。实验开始时,G生成数据能力还比较弱,很容易就被判别出来。但随着训练的继续,G的生成能力逐渐提升,最终骗过判别器D,这时候D判断是否为假数据的概率为0.5。
- 第二阶段:固定生成器G,训练判别器D。当D判断是否为假数据的概率为0.5,再训练G就没有意义了,此时我们需要训练D。训练D之前,我们先固定G,然后不断训练D。通过不断训练,D提高了自己的鉴别能力,又能够判断出假数据了。
- 不断重复第一阶段与第二阶段:通过不断的训练循环,生成器G和判别器D的能力都很强了,我们就能得到一个生成数据效果很好的生成器G。
GAN的数学原理
注意:该章主要是对GAN文献原文中所涉及到的部分数学原理做介绍,内容相对有难度,请读者按需阅读!
-
GAN中各种数据变量解释
GAN原文的应用是分别训练两个多层感知机来扮演生成器G和判别器D,首先为了训在真实数据 上的真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x) ,我们定义了一个噪声数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z上的噪声数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z),通常该分布可以使用均匀分布、高斯分布等,是实验者人为定义的分布。
接下来,我们定义一个多层感知机 <math xmlns="http://www.w3.org/1998/Math/MathML"> G ( z ; θ g ) G(z;θ_g) </math>G(z;θg),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z是噪声数据, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g θ_g </math>θg为生成器多层感知机的训练参数。再将上文提到的噪声数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)作为生成器的输入,并其映射为一个新的数据分布,即生成样本分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x),该分布不同于噪声数据分布,该分布可能十分复杂。接下来的训练过程就是将生成样本分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)不断逼近 真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)。
那么以上的各种表达式就满足以下的数学关系:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x = G ( z ) , z ∼ p z ( z ) ⟹ x ∼ p g ( x ) x=G(z), z∼p_z(z) ⟹ x∼p_g(x) </math>x=G(z),z∼pz(z)⟹x∼pg(x)接一下我们定义第二个多层感知机 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x ; θ d ) D(x;θ_d) </math>D(x;θd),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x是判别器的输入,它可能来源于生成器生成的假数据,也可能来自于真实数据, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d θ_d </math>θd为判别器多层感知机的训练参数。那么判别器D的输出为一个标量即判别该 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x是真数据还是假数据。
我们可以使用下面这个表格再次理解一下其中的各个变量。
变量 含义 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 噪声向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z) 噪声向量的先验分布(如高斯分布、均匀分布等) <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x) 真实数据的概率分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x) 生成器生成的隐式分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g θ_g </math>θg 生成器网络的训练参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> G ( z ; θ g ) G(z;θ_g) </math>G(z;θg) 生成器网络,将 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)映射为 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)(GAN原文使用的是多层感知机) <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d θ_d </math>θd 判别器网络的训练参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x ; θ d ) D(x;θ_d) </math>D(x;θd) 判别器网络,判别 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x是否来自真实数据(GAN原文使用的是多层感知机) -
GAN的损失函数解析
训练网络得少不了解析损失函数,我们直接给出GAN原文中提到的损失函数,我们再对其进行解析。
损失函数如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> min G max D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \min_G\max_DV(D,G)=E_{x∼p_{data}(x)}[logD(x)]+E_{z∼p_z(z)}[log(1-D(G(z)))] </math>GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]这个公式看似很复杂,其实是可以理解为两个公式。
-
针对生成器G,损失函数可以理解为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> min G V ( G ) = E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( Z ) ) ) ] \min_GV(G)=E_{z∼p_z(z)}[log(1-D(G(Z)))] </math>GminV(G)=Ez∼pz(z)[log(1−D(G(Z)))]
-
其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z为噪声,G(z)为生成器由噪声生成的假数据,D(G(z))为判别器判别由生成器送来数据的结果。
如果此时D(G(z)) = 0,则代表判别器成功判断出该数据是假数据,那么此时log(1-D(G(z)))就会等于0。如果此时D(G(z)) = 1,则代表判别器没能判断出该数据是假数据,那么此时log(1-D(G(z)))就会趋向于负无穷。所以我们训练生成器的目标就是尽量让判别器出错,这样该损失函数的值就能取得最小值。
注:在GAN原文中指出,早期训练log(1-D(G(z)))时,由于此时的生成器太弱,容易出现判别器赢得对抗,导致生成器无法进行训练优化的情况,在数学上的表现就是训练过程中梯度消失,所以我们在训练早期改用最大化log(D(G(z)))来训练生成器。
-
-
针对判别器D,损失函数可以理解为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> max D V ( D ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \max_DV(D)=E_{x∼p_{data}(x)}[logD(x)]+E_{z∼p_z(z)}[log(1-D(G(z)))] </math>DmaxV(D)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
-
其中,x是真实数据,D(x)是判别器判断真实数据的结果。
上文中,我们已经解释加号后部分工作原理,即该部分越大,判别器越能判断出数据是否是假数据,所以该部分对于判别器来说应当取得最大值。接一下我们主要解释加号前部分的工作原理。
此时,若D(x)=1,则判别器成功判别出该真实数据为真实数据,那么log(D(x))就会等于0。若此时D(x)=0,则代表判别器将真实数据判断为假数据,那么log(D(x))就会趋向于负无穷。所以,我们为了训练判别器D,我们就需要让判别器尽量正确判别出数据是否为真数据,即要让该公式取得最大值。
-
-
-
GAN训练过程的图解
注:该图来源于GAN原文
-
图中元素解析
- 黑色虚线:真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)
- 绿色实线:生成器所拟合的数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)
- 蓝色虚线:判别器的输出概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x ) D(x) </math>D(x)
- 判别器最佳时,x为真数据时,D(x)=1,x为假数据,D(x)=0。生成器最佳时,D(x)=0.5即判别器只能乱猜数据是否为真。
- 上方水平线:数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x的分布空间
- 下方水平线:噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z的采样空间
- 箭头:生成器G将噪声z映射到数据空间x的过程,即将 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)映射为 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)的过程。
-
图中各个阶段解读
(a)初始阶段
- 绿色实线与黑色虚线差别很大,即生成分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_g </math>pg与真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata差异过大,生成数据质量比较低。
- 蓝色虚线在绿色实线低的位置高,在绿色实线高的位置低,代表判别器D能够初步区分出真实数据与生成数据。
- 总结该阶段:此时生成器还没有能力生成足够欺骗判别器的数据,判别器已经有了初步的判别能力。
(b)判别器优化
- 从(a)到(b)的主要差异是蓝色虚线的变化,蓝色虚线从(a)阶段的有高低起伏趋向于稳定。
- 判别器D趋向于最优解,即判别器在生成数据少的部分能够有效判断出为真实数据,在生成数据多的部分也能有效判断出假数据。
- 总结该阶段:此时生成器被固定依然没有能力生成足够欺骗判别器的数据,而判别器的判别能力趋向最优解。
(c )生成器优化
- 从(b)到(c)的主要差异是绿色实线的变化和下方箭头的变化,绿色实线开始向黑色虚线趋近,箭头也从指向右侧变为指向中部。
- 这两个变化的含义相同,由于生成器的能力不断优化,箭头的变化代表生成器正将噪声z映射到数据空间 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)的变化, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)越来越接近真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x),这就造成了绿色实线不断靠近黑色虚线,即生成数据的数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)正逐渐趋近于真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)。
- 总结该阶段:生成器不断优化生成数据的能力,生成数据不断接近真实数据。
(d)收敛阶段
-
从(c)到(d)的主要差异是绿色实线与黑色虚线重合,蓝色虚线变为一条无变化的直线,箭头更加趋近于中部。
-
这些变化的含义都表示此时生成器已经达到最优,箭头的变化代表生成器已经有能力将噪声z映射到数据空间 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_g </math>pg,并且该分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)与真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)完全相同,这就造成绿色实线与黑色虚线完全重合,同时,蓝色虚线代表的判别器的输出概率公式为
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> D ( x ) = p d a t a ( x ) p d a t a ( x ) + p g ( x ) D(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)} </math>D(x)=pdata(x)+pg(x)pdata(x)
由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)与真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)完全相同,使得D(x)恒等于0.5,所以蓝色虚线就变为了一条直线。
-
总结该阶段:此时生成器已经有了能力生成足够欺骗判别器的数据,判别器没有能力再判断数据的真伪,陷入只能瞎猜的境地。
-
针对图解常见问题解答
1. 该过程就是GAN训练的全过程吗?
- 实际上,该图是用训练过程中几个理想片段来表达GAN的训练过程,其中(a)是训练一开始阶段,(d)是训练达到收敛的阶段,而(b)和(c)在实际训练中需要经过多许多次迭代,才能达到(d),即真实训练中,(a)需要经过很多(b)和(c)阶段才能达到(d)。
2. 该图解表述为先训练D,而上文步骤中表述为先训练G,究竟是先训练哪一个?
- 由上一问我们得知,在一次迭代中,生成器和训练器都要进行一次参数更新优化,其中一个网络的性能提升都会带动另外一个网络的性能提升,所以在完整的一个训练过程中一次细微迭代中究竟是先训练G还是D并不会对结果造成太大的影响。
-
-
GAN的训练算法步骤
注:本章是对GAN原文所提及的算法做解释,可能与实际生成中算法有一定出入
以下是GAN原本中提及的算法伪代码:
pythonfor 训练迭代次数 do # 步骤1:优化判别器 D(k 次更新) for k steps do 1. 从噪声先验中采样批噪声:{z^(1), ..., z^(m)} ∼ p_z(z) 2. 从真实数据中采样批样本:{x^(1), ..., x^(m)} ∼ p_data(x) 3. 更新判别器参数 θ_d,通过梯度上升: end for # 步骤2:优化生成器 G(1 次更新) 1. 从噪声先验中采样批噪声:{z^(1), ..., z^(m)} ∼ p_z(z) 2. 更新生成器参数 θ_g,通过梯度下降: end for
可能第一次看不明白以上代码究竟是什么含义,接下来我们会做完整介绍。
-
第一层循环是重复迭代次数个循环,这个循环等同于上文中重复多次(b)和(c)的过程。
-
第二层循环是重复k次,k是一个超参数是由实验者人为指定的参数,该层循环等同上文中图解中的(b)过程,只是训练(b)时,我们需要重复k次。
-
判别器的优化过程:
-
首先,我们从噪声先验 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)中采样批噪声: <math xmlns="http://www.w3.org/1998/Math/MathML"> z ( 1 ) , . . . , z ( m ) z^{(1)}, ..., z^{(m)} </math>z(1),...,z(m)。
-
然后,我们再从真实数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)中采样批样本: <math xmlns="http://www.w3.org/1998/Math/MathML"> x ( 1 ) , . . . , x ( m ) x^{(1)}, ..., x^{(m)} </math>x(1),...,x(m)。
-
将这一批的噪声与数据同时送入到以下损失函数中并计算梯度:
判别器的损失函数构成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> max D V ( D ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \max_DV(D)=E_{x∼p_{data}(x)}[logD(x)]+E_{z∼p_z(z)}[log(1-D(G(z)))] </math>DmaxV(D)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]所以,我们将真实数据送入到加号前一项,噪声数据送入到加号后一项,然后我们计算该批次梯度,梯度计算公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ θ d 1 m ∑ i = 1 m [ l o g D ( x i ) + l o g ( 1 − D ( G ( z i ) ) ) ] ∇{θ_d}\frac{1}{m}\sum{i=1}^m[logD(x^i)+log(1-D(G(z^{i})))] </math>∇θdm1i=1∑m[logD(xi)+log(1−D(G(zi)))] -
接下来,我们做参数更新,由于我们要求的最大值,所以此时应该是梯度上升:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ d + 1 = θ d + η ∇ θ d θ_{d+1} = θ_{d} + η∇_{θ_d} </math>θd+1=θd+η∇θd
- 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d + 1 θ_{d+1} </math>θd+1和 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d θ_{d} </math>θd是更新前后的参数, <math xmlns="http://www.w3.org/1998/Math/MathML"> η η </math>η是学习率, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ θ d ∇_{θ_d} </math>∇θd为本次计算得到梯度。
-
重复上述过程k次。
注:k的选择需要保证判别器有一定的优化空间,又不至于优化太好,使得生成器的优化受限。
-
-
生成器的优化过程:
-
首先,我们也是从噪声先验 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)中采样批噪声: <math xmlns="http://www.w3.org/1998/Math/MathML"> z ( 1 ) , . . . , z ( m ) z^{(1)}, ..., z^{(m)} </math>z(1),...,z(m)。
-
再将这一批噪声送入到以下损失函数并计算梯度:
生成器的损失函数构成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> min G V ( G ) = E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( Z ) ) ) ] \min_GV(G)=E_{z∼p_z(z)}[log(1-D(G(Z)))] </math>GminV(G)=Ez∼pz(z)[log(1−D(G(Z)))]然后,我们计算该批次噪声的梯度:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ θ g 1 m ∑ i = 1 m l o g ( 1 − D ( G ( z i ) ) ) ∇{θ_g}\frac{1}{m}\sum{i=1}^mlog(1-D(G(z^{i}))) </math>∇θgm1i=1∑mlog(1−D(G(zi))) -
接下来,做参数更新,由于我们此时要求最小值,所以应当使用梯度下降:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ g + 1 = θ g − η ∇ θ g θ_{g+1} = θ_{g} - η∇_{θ_g} </math>θg+1=θg−η∇θg
- 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g + 1 θ_{g+1} </math>θg+1和 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g θ_{g} </math>θg是更新前后的参数, <math xmlns="http://www.w3.org/1998/Math/MathML"> η η </math>η是学习率, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ θ g ∇_{θ_g} </math>∇θg为本次计算得到梯度。
-
-
最后将以上判别器和生成器的优化过程重复迭代次数即可。
-
GAN代码实例演示------实现手写数字
数据集选择与加载:MNIST数据集
-
MNIST数据集是机器学习领域最经典的入门数据集之一,主要用于手写数字识别任务,该数据集的内容主要包括0到9的手写数字的灰度图片,每张图片大小为28x28像素。该数据集的数据量训练集有60,000张图片,测试集10,000张图片。本文只使用MNIST数据集的训练集部分。
数据集加载与显示代码部分(本文最后设计有全部代码)
注:其中有部分设计到超参数的设置,在一章会有说明
python# 加载MNIST数据集 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) # 将像素值归一化到[-1, 1] ]) dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 可视化部分 # 定义类别标签 class_names = ['0','1','2','3','4','5','6','7','8','9'] # 从训练集中随机取一个batch的图像 images, labels = next(iter(loader)) # 获取一个batch(64张图) # iter()转换为迭代器,next()获取下一个批次的数据 # images为一个形状为[64, 1, 28, 28]的张量 labels为[64]的张量 # 显示图像函数 def imshow(img): img = img.numpy() img = np.squeeze(img) # 移除单通道维度 (1,28,28) -> (28,28) img = img * 0.5 + 0.5 # 反归一化到[0,1] plt.imshow(img, cmap='gray') plt.axis('off') # 画出一个4x8的网格(共32张图) plt.figure(figsize=(12, 6)) for i in range(32): # 显示前32张 plt.subplot(4, 8, i+1) imshow(images[i]) plt.title(class_names[labels[i].item()], fontsize=8) plt.tight_layout() plt.show()
-
数据集图片演示
超参数设置与网络设计
-
超参数设置一般放在代码的最前面,这一部分并非必需,也可以在后面的代码部分手动设置,这里只是习惯问题。
python# 设置超参数 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备 print(device) lr = 0.0002 # 学习率 z_dim = 64 # 噪声维度 image_dim = 28 * 28 * 1 # MNIST图像维度 batch_size = 64 # 批量大小 epochs = 50 # 训练轮数
-
网络设计:
-
生成器G网络设计
生成器采用三层感知机,其中激活函数选用LeakyReLU函数,斜率设置为0.1,最后的激活函数选择Tan函数。
python# 生成器网络设计 class Generator(nn.Module): def __init__(self, z_dim, img_dim): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(z_dim, 128), nn.LeakyReLU(0.1), nn.Linear(128, 256), nn.LeakyReLU(0.1), nn.Linear(256, img_dim), nn.Tanh() ) def forward(self, x): return self.model(x)
-
判别器D网络设计
生成器采用三层感知机,其中激活函数选用LeakyReLU函数,斜率设置为0.1,最后的激活函数选择Sigmoid函数。
python# 判别器网络设计 class Discriminator(nn.Module): def __init__(self, img_dim): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(img_dim, 256), nn.LeakyReLU(0.1), nn.Linear(256, 128), nn.LeakyReLU(0.1), nn.Linear(128, 1), nn.Sigmoid() ) def forward(self, x): return self.model(x)
-
网络实例化与循环训练
python
# 循环轮次
for epoch in range(epochs):
# 提取数据
for i, (real_img, _) in enumerate(loader):
# 1. 训练判别器 log(D(real)) + log(1 - D(G(z)))
# 采样真实数据
batch_size = real_img.shape[0]
real_img = real_img.view(-1, image_dim).to(device)
# 进行判别得到损失函数值
disc_real = discriminator(real_img).flatten()
real_labels = torch.ones_like(disc_real).to(device)
loss_real = criterion(disc_real, real_labels)
# 采样噪声数据
noise = torch.randn(batch_size, z_dim).to(device)
fake_img = generator(noise)
# 进行判别得到损失函数值
disc_fake = discriminator(fake_img.detach()).flatten()
fake_labels = torch.zeros_like(disc_fake).to(device)
loss_fake = criterion(disc_fake, fake_labels)
# 将两者损失值求和除以二,以免其中一个损失值过大影响训练
loss_disc = (loss_real + loss_fake) / 2
# 更新参数
opt_disc.zero_grad()
loss_disc.backward()
opt_disc.step()
# 训练生成器 最小化 log(1 - D(G(z))) → 最大化 log(D(G(z)))
# 将噪声数据采样进行判别
output = discriminator(fake_img).flatten()
# 计算损失函数值
loss_gen = criterion(output, real_labels)
# 更新参数
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
打印结果与结果保存
python
# 生成结果保存文件夹
os.makedirs("generated_images", exist_ok=True)
# 打印数据并保存图像数据
if i == 0:
print(
f"Epoch [{epoch+1}/{epochs}] "
f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
)
with torch.no_grad():
noise = torch.randn(batch_size, z_dim).to(device)
fake = generator(noise).reshape(-1, 1, 28, 28)
img_grid = torchvision.utils.make_grid(fake, nrow=4, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title(f"Epoch {epoch+1}")
plt.savefig(f"generated_images/epoch{epoch+1}.png")
#plt.show()
plt.close()
全部代码
python
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # 允许重复加载OpenMP库
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 设置超参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备
print(device)
lr = 0.0002 # 学习率
z_dim = 64 # 噪声维度
image_dim = 28 * 28 * 1 # MNIST图像维度
batch_size = 64 # 批量大小
epochs = 50 # 训练轮数
os.makedirs("generated_images", exist_ok=True)
# 加载MNIST数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # 将像素值归一化到[-1, 1]
])
dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 可视化部分
# 定义类别标签
class_names = ['0','1','2','3','4','5','6','7','8','9']
# 从训练集中随机取一个batch的图像
images, labels = next(iter(loader)) # 获取一个batch(64张图)
# iter()转换为迭代器,next()获取下一个批次的数据
# images为一个形状为[64, 1, 28, 28]的张量 labels为[64]的张量
# 显示图像函数
def imshow(img):
img = img.numpy()
img = np.squeeze(img) # 移除单通道维度 (1,28,28) -> (28,28)
img = img * 0.5 + 0.5 # 反归一化到[0,1]
plt.imshow(img, cmap='gray')
plt.axis('off')
# 画出一个4x8的网格(共32张图)
plt.figure(figsize=(12, 6))
for i in range(32): # 显示前32张
plt.subplot(4, 8, i+1)
imshow(images[i])
plt.title(class_names[labels[i].item()], fontsize=8)
plt.tight_layout()
plt.show()
# 生成器网络设计
class Generator(nn.Module):
def __init__(self, z_dim, img_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(z_dim, 128),
nn.LeakyReLU(0.1),
nn.Linear(128, 256),
nn.LeakyReLU(0.1),
nn.Linear(256, img_dim),
nn.Tanh()
)
def forward(self, x):
return self.model(x)
# 判别器网络设计
class Discriminator(nn.Module):
def __init__(self, img_dim):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_dim, 256),
nn.LeakyReLU(0.1),
nn.Linear(256, 128),
nn.LeakyReLU(0.1),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# 网络实例化
generator = Generator(z_dim, image_dim).to(device)
discriminator = Discriminator(image_dim).to(device)
# 损失函数选择交叉熵损失函数
criterion = nn.BCELoss()
# 优化器选择Adam优化器
opt_gen = torch.optim.Adam(generator.parameters(), lr=lr)
opt_disc = torch.optim.Adam(discriminator.parameters(), lr=lr)
# 循环轮次
for epoch in range(epochs):
# 提取数据
for i, (real_img, _) in enumerate(loader):
# 1. 训练判别器 log(D(real)) + log(1 - D(G(z)))
# 采样真实数据
batch_size = real_img.shape[0]
real_img = real_img.view(-1, image_dim).to(device)
# 进行判别得到损失函数值
disc_real = discriminator(real_img).flatten()
real_labels = torch.ones_like(disc_real).to(device)
loss_real = criterion(disc_real, real_labels)
# 采样噪声数据
noise = torch.randn(batch_size, z_dim).to(device)
fake_img = generator(noise)
# 进行判别得到损失函数值
disc_fake = discriminator(fake_img.detach()).flatten()
fake_labels = torch.zeros_like(disc_fake).to(device)
loss_fake = criterion(disc_fake, fake_labels)
# 将两者损失值求和除以二,以免其中一个损失值过大影响训练
loss_disc = (loss_real + loss_fake) / 2
# 更新参数
opt_disc.zero_grad()
loss_disc.backward()
opt_disc.step()
# 训练生成器 最小化 log(1 - D(G(z))) → 最大化 log(D(G(z)))
# 将噪声数据采样进行判别
output = discriminator(fake_img).flatten()
# 计算损失函数值
loss_gen = criterion(output, real_labels)
# 更新参数
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
# 打印数据并保存图像数据
if i == 0:
print(
f"Epoch [{epoch+1}/{epochs}] "
f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
)
with torch.no_grad():
noise = torch.randn(batch_size, z_dim).to(device)
fake = generator(noise).reshape(-1, 1, 28, 28)
img_grid = torchvision.utils.make_grid(fake, nrow=4, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title(f"Epoch {epoch+1}")
plt.savefig(f"generated_images/epoch{epoch+1}.png")
#plt.show()
plt.close()
结果展示
- 显然该网络并没有训练到完全拟合,还可以继续增加训练的轮数使得网络训练更加趋近于拟合。