GAN生成对抗网络数学原理解释并实现MNIST数据集生产(附代码演示)

什么是GAN?

  • GAN,全称 Generative Adversarial Network,即生成对抗网络。该网络模型由Ian J. Goodfellow 在2014年首次提出,以下是该论文原文下载地址:Generative Adversarial Nets

  • 生成对抗网络(GAN)是一种通过框架内两个核心模块------生成模型(Generative Model)判别模型(Discriminative Model)------相互博弈学习,从而产生高质量输出的深度学习模型。作为当前最具前景和活跃度的生成式模型之一,GAN 在样本数据生成、图像合成、图像修复、图像转换以及文本生成等多个领域展现出强大能力,标志着生成式人工智能(AIGC)的关键突破。

    GAN 的核心思想是通过生成器判别器的对抗训练,使生成器能够不断优化以生成逼真的数据,而判别器则不断提升鉴别真伪的能力。这种动态博弈机制使得 GAN 能够生成高度接近真实分布的图像或数据,成为现代生成式 AI 的重要基石之一。

    生成对抗网络(GAN)的核心思想是通过**生成器(Generator)判别器(Discriminator)**的对抗训练,使生成数据的分布逐步逼近真实数据的分布。在训练过程中,生成器从随机噪声中合成样本,并不断优化其生成能力,力求使生成的样本与真实数据尽可能相似,从而"欺骗"判别器。与此同时,判别器则通过对比生成样本和真实样本,持续提升自身的鉴别能力,以更精准地区分两者的差异。这种动态博弈机制推动双方不断优化,最终使生成器能够输出高度逼真的数据。

GAN的工作原理

  • 核心构成

    GAN由两个重要的部分构成:生成器(Generator,简写作G)和判别器(Discriminator,简写作D)。

    • 生成器:通过机器生成数据,目的是尽可能"骗过"判别器,生成的数据记做G(z);
    • 判别器:判断数据是真实数据还是「生成器」生成的数据,目的是尽可能找出「生成器」造的"假数据"。它的输入参数是x,x代表数据,输出D(x)代表x为真实数据的概率,如果为1,就代表100%是真实的数据,而输出为0,就代表不可能是真实的数据。

    经过这样的设计,G和D就构成了一个动态对抗的过程,随着多次训练之后,G生成的数据越来越接近真实数据,D判断数据真伪的水平也越来越高。最后在训练的后期,G所生成的数据足够欺骗D,对于D来讲,它则难以判断数据究竟是G生成还是真实数据,因此最后的D(G(z))=0.5。这样我们就得到了一个生成模型可以生成足够以假乱真的数据。

  • 训练步骤

    • 第一阶段:固定判别器D,训练生成器G。首先使用一个性能不错的判别器D,G通过噪声不断生成假数据,将其丢给D去判断。实验开始时,G生成数据能力还比较弱,很容易就被判别出来。但随着训练的继续,G的生成能力逐渐提升,最终骗过判别器D,这时候D判断是否为假数据的概率为0.5。
    • 第二阶段:固定生成器G,训练判别器D。当D判断是否为假数据的概率为0.5,再训练G就没有意义了,此时我们需要训练D。训练D之前,我们先固定G,然后不断训练D。通过不断训练,D提高了自己的鉴别能力,又能够判断出假数据了。
    • 不断重复第一阶段与第二阶段:通过不断的训练循环,生成器G和判别器D的能力都很强了,我们就能得到一个生成数据效果很好的生成器G。

GAN的数学原理

注意:该章主要是对GAN文献原文中所涉及到的部分数学原理做介绍,内容相对有难度,请读者按需阅读!


  • GAN中各种数据变量解释

    GAN原文的应用是分别训练两个多层感知机来扮演生成器G和判别器D,首先为了训在真实数据 上的真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x) ,我们定义了一个噪声数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z上的噪声数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z),通常该分布可以使用均匀分布、高斯分布等,是实验者人为定义的分布。

    接下来,我们定义一个多层感知机 <math xmlns="http://www.w3.org/1998/Math/MathML"> G ( z ; θ g ) G(z;θ_g) </math>G(z;θg),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z是噪声数据, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g θ_g </math>θg为生成器多层感知机的训练参数。再将上文提到的噪声数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)作为生成器的输入,并其映射为一个新的数据分布,即生成样本分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x),该分布不同于噪声数据分布,该分布可能十分复杂。接下来的训练过程就是将生成样本分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)不断逼近 真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)

    那么以上的各种表达式就满足以下的数学关系:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x = G ( z ) , z ∼ p z ( z ) ⟹ x ∼ p g ( x ) x=G(z), z∼p_z(z) ⟹ x∼p_g(x) </math>x=G(z),z∼pz(z)⟹x∼pg(x)

    接一下我们定义第二个多层感知机 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x ; θ d ) D(x;θ_d) </math>D(x;θd),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x是判别器的输入,它可能来源于生成器生成的假数据,也可能来自于真实数据, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d θ_d </math>θd为判别器多层感知机的训练参数。那么判别器D的输出为一个标量即判别该 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x是真数据还是假数据。

    我们可以使用下面这个表格再次理解一下其中的各个变量。

    变量 含义
    <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 噪声向量
    <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z) 噪声向量的先验分布(如高斯分布、均匀分布等)
    <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x) 真实数据的概率分布
    <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x) 生成器生成的隐式分布
    <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g θ_g </math>θg 生成器网络的训练参数
    <math xmlns="http://www.w3.org/1998/Math/MathML"> G ( z ; θ g ) G(z;θ_g) </math>G(z;θg) 生成器网络,将 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)映射为 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)(GAN原文使用的是多层感知机)
    <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d θ_d </math>θd 判别器网络的训练参数
    <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x ; θ d ) D(x;θ_d) </math>D(x;θd) 判别器网络,判别 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x是否来自真实数据(GAN原文使用的是多层感知机)
  • GAN的损失函数解析

    训练网络得少不了解析损失函数,我们直接给出GAN原文中提到的损失函数,我们再对其进行解析。

    损失函数如下:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \min_G\max_DV(D,G)=E_{x∼p_{data}(x)}[logD(x)]+E_{z∼p_z(z)}[log(1-D(G(z)))] </math>GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]

    这个公式看似很复杂,其实是可以理解为两个公式。

    • 针对生成器G,损失函数可以理解为:

      <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> min ⁡ G V ( G ) = E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( Z ) ) ) ] \min_GV(G)=E_{z∼p_z(z)}[log(1-D(G(Z)))] </math>GminV(G)=Ez∼pz(z)[log(1−D(G(Z)))]

      • 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z为噪声,G(z)为生成器由噪声生成的假数据,D(G(z))为判别器判别由生成器送来数据的结果。

        如果此时D(G(z)) = 0,则代表判别器成功判断出该数据是假数据,那么此时log(1-D(G(z)))就会等于0。如果此时D(G(z)) = 1,则代表判别器没能判断出该数据是假数据,那么此时log(1-D(G(z)))就会趋向于负无穷。所以我们训练生成器的目标就是尽量让判别器出错,这样该损失函数的值就能取得最小值。

        :在GAN原文中指出,早期训练log(1-D(G(z)))时,由于此时的生成器太弱,容易出现判别器赢得对抗,导致生成器无法进行训练优化的情况,在数学上的表现就是训练过程中梯度消失,所以我们在训练早期改用最大化log(D(G(z)))来训练生成器。

    • 针对判别器D,损失函数可以理解为:

      <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> max ⁡ D V ( D ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \max_DV(D)=E_{x∼p_{data}(x)}[logD(x)]+E_{z∼p_z(z)}[log(1-D(G(z)))] </math>DmaxV(D)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]

      • 其中,x是真实数据,D(x)是判别器判断真实数据的结果。

        上文中,我们已经解释加号后部分工作原理,即该部分越大,判别器越能判断出数据是否是假数据,所以该部分对于判别器来说应当取得最大值。接一下我们主要解释加号前部分的工作原理。

        此时,若D(x)=1,则判别器成功判别出该真实数据为真实数据,那么log(D(x))就会等于0。若此时D(x)=0,则代表判别器将真实数据判断为假数据,那么log(D(x))就会趋向于负无穷。所以,我们为了训练判别器D,我们就需要让判别器尽量正确判别出数据是否为真数据,即要让该公式取得最大值。

  • GAN训练过程的图解

    注:该图来源于GAN原文

    • 图中元素解析

      • 黑色虚线:真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)
      • 绿色实线:生成器所拟合的数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)
      • 蓝色虚线:判别器的输出概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x ) D(x) </math>D(x)
        • 判别器最佳时,x为真数据时,D(x)=1,x为假数据,D(x)=0。生成器最佳时,D(x)=0.5即判别器只能乱猜数据是否为真。
      • 上方水平线:数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x的分布空间
      • 下方水平线:噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z的采样空间
      • 箭头:生成器G将噪声z映射到数据空间x的过程,即将 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)映射为 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)的过程。
    • 图中各个阶段解读

      (a)初始阶段

      • 绿色实线与黑色虚线差别很大,即生成分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_g </math>pg与真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata差异过大,生成数据质量比较低。
      • 蓝色虚线在绿色实线低的位置高,在绿色实线高的位置低,代表判别器D能够初步区分出真实数据与生成数据。
      • 总结该阶段:此时生成器还没有能力生成足够欺骗判别器的数据,判别器已经有了初步的判别能力。

      (b)判别器优化

      • 从(a)到(b)的主要差异是蓝色虚线的变化,蓝色虚线从(a)阶段的有高低起伏趋向于稳定。
      • 判别器D趋向于最优解,即判别器在生成数据少的部分能够有效判断出为真实数据,在生成数据多的部分也能有效判断出假数据。
      • 总结该阶段:此时生成器被固定依然没有能力生成足够欺骗判别器的数据,而判别器的判别能力趋向最优解。

      (c )生成器优化

      • 从(b)到(c)的主要差异是绿色实线的变化和下方箭头的变化,绿色实线开始向黑色虚线趋近,箭头也从指向右侧变为指向中部。
      • 这两个变化的含义相同,由于生成器的能力不断优化,箭头的变化代表生成器正将噪声z映射到数据空间 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)的变化, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)越来越接近真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x),这就造成了绿色实线不断靠近黑色虚线,即生成数据的数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)正逐渐趋近于真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)。
      • 总结该阶段:生成器不断优化生成数据的能力,生成数据不断接近真实数据。

      (d)收敛阶段

      • 从(c)到(d)的主要差异是绿色实线与黑色虚线重合,蓝色虚线变为一条无变化的直线,箭头更加趋近于中部。

      • 这些变化的含义都表示此时生成器已经达到最优,箭头的变化代表生成器已经有能力将噪声z映射到数据空间 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_g </math>pg,并且该分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)与真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)完全相同,这就造成绿色实线与黑色虚线完全重合,同时,蓝色虚线代表的判别器的输出概率公式为

        <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> D ( x ) = p d a t a ( x ) p d a t a ( x ) + p g ( x ) D(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)} </math>D(x)=pdata(x)+pg(x)pdata(x)

        由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_g(x) </math>pg(x)与真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)完全相同,使得D(x)恒等于0.5,所以蓝色虚线就变为了一条直线。

      • 总结该阶段:此时生成器已经有了能力生成足够欺骗判别器的数据,判别器没有能力再判断数据的真伪,陷入只能瞎猜的境地。

    • 针对图解常见问题解答

      1. 该过程就是GAN训练的全过程吗?

      • 实际上,该图是用训练过程中几个理想片段来表达GAN的训练过程,其中(a)是训练一开始阶段,(d)是训练达到收敛的阶段,而(b)和(c)在实际训练中需要经过多许多次迭代,才能达到(d),即真实训练中,(a)需要经过很多(b)和(c)阶段才能达到(d)。

      2. 该图解表述为先训练D,而上文步骤中表述为先训练G,究竟是先训练哪一个?

      • 由上一问我们得知,在一次迭代中,生成器和训练器都要进行一次参数更新优化,其中一个网络的性能提升都会带动另外一个网络的性能提升,所以在完整的一个训练过程中一次细微迭代中究竟是先训练G还是D并不会对结果造成太大的影响。
  • GAN的训练算法步骤

    注:本章是对GAN原文所提及的算法做解释,可能与实际生成中算法有一定出入

    以下是GAN原本中提及的算法伪代码:

    python 复制代码
    for 训练迭代次数 do
        # 步骤1:优化判别器 D(k 次更新)
        for k steps do
            1. 从噪声先验中采样批噪声:{z^(1), ..., z^(m)} ∼ p_z(z)
            2. 从真实数据中采样批样本:{x^(1), ..., x^(m)} ∼ p_data(x)
            3. 更新判别器参数 θ_d,通过梯度上升:
        end for
    
        # 步骤2:优化生成器 G(1 次更新)
        1. 从噪声先验中采样批噪声:{z^(1), ..., z^(m)} ∼ p_z(z)
        2. 更新生成器参数 θ_g,通过梯度下降:
    end for

    可能第一次看不明白以上代码究竟是什么含义,接下来我们会做完整介绍。

    • 第一层循环是重复迭代次数个循环,这个循环等同于上文中重复多次(b)和(c)的过程。

    • 第二层循环是重复k次,k是一个超参数是由实验者人为指定的参数,该层循环等同上文中图解中的(b)过程,只是训练(b)时,我们需要重复k次。

    • 判别器的优化过程

      • 首先,我们从噪声先验 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)中采样批噪声: <math xmlns="http://www.w3.org/1998/Math/MathML"> z ( 1 ) , . . . , z ( m ) z^{(1)}, ..., z^{(m)} </math>z(1),...,z(m)。

      • 然后,我们再从真实数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x)中采样批样本: <math xmlns="http://www.w3.org/1998/Math/MathML"> x ( 1 ) , . . . , x ( m ) x^{(1)}, ..., x^{(m)} </math>x(1),...,x(m)。

      • 将这一批的噪声与数据同时送入到以下损失函数中并计算梯度:

        判别器的损失函数构成:
        <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> max ⁡ D V ( D ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \max_DV(D)=E_{x∼p_{data}(x)}[logD(x)]+E_{z∼p_z(z)}[log(1-D(G(z)))] </math>DmaxV(D)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]

        所以,我们将真实数据送入到加号前一项,噪声数据送入到加号后一项,然后我们计算该批次梯度,梯度计算公式如下:
        <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ θ d 1 m ∑ i = 1 m [ l o g D ( x i ) + l o g ( 1 − D ( G ( z i ) ) ) ] ∇{θ_d}\frac{1}{m}\sum{i=1}^m[logD(x^i)+log(1-D(G(z^{i})))] </math>∇θdm1i=1∑m[logD(xi)+log(1−D(G(zi)))]

      • 接下来,我们做参数更新,由于我们要求的最大值,所以此时应该是梯度上升:

        <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ d + 1 = θ d + η ∇ θ d θ_{d+1} = θ_{d} + η∇_{θ_d} </math>θd+1=θd+η∇θd

        • 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d + 1 θ_{d+1} </math>θd+1和 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ d θ_{d} </math>θd是更新前后的参数, <math xmlns="http://www.w3.org/1998/Math/MathML"> η η </math>η是学习率, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ θ d ∇_{θ_d} </math>∇θd为本次计算得到梯度。
      • 重复上述过程k次。

        注:k的选择需要保证判别器有一定的优化空间,又不至于优化太好,使得生成器的优化受限。

    • 生成器的优化过程:

      • 首先,我们也是从噪声先验 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( z ) p_z(z) </math>pz(z)中采样批噪声: <math xmlns="http://www.w3.org/1998/Math/MathML"> z ( 1 ) , . . . , z ( m ) z^{(1)}, ..., z^{(m)} </math>z(1),...,z(m)。

      • 再将这一批噪声送入到以下损失函数并计算梯度:

        生成器的损失函数构成:
        <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> min ⁡ G V ( G ) = E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( Z ) ) ) ] \min_GV(G)=E_{z∼p_z(z)}[log(1-D(G(Z)))] </math>GminV(G)=Ez∼pz(z)[log(1−D(G(Z)))]

        然后,我们计算该批次噪声的梯度:
        <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ θ g 1 m ∑ i = 1 m l o g ( 1 − D ( G ( z i ) ) ) ∇{θ_g}\frac{1}{m}\sum{i=1}^mlog(1-D(G(z^{i}))) </math>∇θgm1i=1∑mlog(1−D(G(zi)))

      • 接下来,做参数更新,由于我们此时要求最小值,所以应当使用梯度下降:

        <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ g + 1 = θ g − η ∇ θ g θ_{g+1} = θ_{g} - η∇_{θ_g} </math>θg+1=θg−η∇θg

        • 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g + 1 θ_{g+1} </math>θg+1和 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ g θ_{g} </math>θg是更新前后的参数, <math xmlns="http://www.w3.org/1998/Math/MathML"> η η </math>η是学习率, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ θ g ∇_{θ_g} </math>∇θg为本次计算得到梯度。
    • 最后将以上判别器和生成器的优化过程重复迭代次数即可。

GAN代码实例演示------实现手写数字

数据集选择与加载:MNIST数据集

  • MNIST数据集是机器学习领域最经典的入门数据集之一,主要用于手写数字识别任务,该数据集的内容主要包括0到9的手写数字的灰度图片,每张图片大小为28x28像素。该数据集的数据量训练集有60,000张图片,测试集10,000张图片。本文只使用MNIST数据集的训练集部分。

    数据集加载与显示代码部分(本文最后设计有全部代码)

    注:其中有部分设计到超参数的设置,在一章会有说明

    python 复制代码
    # 加载MNIST数据集
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 将像素值归一化到[-1, 1]
    ])
    
    dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # 可视化部分
    # 定义类别标签
    class_names = ['0','1','2','3','4','5','6','7','8','9']
    
    # 从训练集中随机取一个batch的图像
    images, labels = next(iter(loader))  # 获取一个batch(64张图)
    # iter()转换为迭代器,next()获取下一个批次的数据
    # images为一个形状为[64, 1, 28, 28]的张量 labels为[64]的张量
    
    # 显示图像函数
    def imshow(img):
        img = img.numpy()
        img = np.squeeze(img)  # 移除单通道维度 (1,28,28) -> (28,28)
        img = img * 0.5 + 0.5  # 反归一化到[0,1]
        plt.imshow(img, cmap='gray')
        plt.axis('off')
    
    # 画出一个4x8的网格(共32张图)
    plt.figure(figsize=(12, 6))
    for i in range(32):  # 显示前32张
        plt.subplot(4, 8, i+1)
        imshow(images[i])
        plt.title(class_names[labels[i].item()], fontsize=8)
    plt.tight_layout()
    plt.show()
  • 数据集图片演示

超参数设置与网络设计

  • 超参数设置一般放在代码的最前面,这一部分并非必需,也可以在后面的代码部分手动设置,这里只是习惯问题。

    python 复制代码
    # 设置超参数
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备
    print(device)
    lr = 0.0002 # 学习率
    z_dim = 64  # 噪声维度
    image_dim = 28 * 28 * 1  # MNIST图像维度
    batch_size = 64 # 批量大小
    epochs = 50 # 训练轮数
  • 网络设计:

    • 生成器G网络设计

      生成器采用三层感知机,其中激活函数选用LeakyReLU函数,斜率设置为0.1,最后的激活函数选择Tan函数。

      python 复制代码
      # 生成器网络设计
      class Generator(nn.Module):
          def __init__(self, z_dim, img_dim):
              super(Generator, self).__init__()
              self.model = nn.Sequential(
                  nn.Linear(z_dim, 128),
                  nn.LeakyReLU(0.1),
                  nn.Linear(128, 256),
                  nn.LeakyReLU(0.1),
                  nn.Linear(256, img_dim),
                  nn.Tanh()
              )
      
          def forward(self, x):
              return self.model(x)
    • 判别器D网络设计

      生成器采用三层感知机,其中激活函数选用LeakyReLU函数,斜率设置为0.1,最后的激活函数选择Sigmoid函数。

      python 复制代码
      # 判别器网络设计
      class Discriminator(nn.Module):
          def __init__(self, img_dim):
              super(Discriminator, self).__init__()
              self.model = nn.Sequential(
                  nn.Linear(img_dim, 256),
                  nn.LeakyReLU(0.1),
                  nn.Linear(256, 128),
                  nn.LeakyReLU(0.1),
                  nn.Linear(128, 1),
                  nn.Sigmoid()
              )
      
          def forward(self, x):
              return self.model(x)

网络实例化与循环训练

python 复制代码
# 循环轮次
for epoch in range(epochs):
    # 提取数据
    for i, (real_img, _) in enumerate(loader):
        # 1. 训练判别器 log(D(real)) + log(1 - D(G(z)))
        # 采样真实数据
        batch_size = real_img.shape[0]
        real_img = real_img.view(-1, image_dim).to(device)
        # 进行判别得到损失函数值
        disc_real = discriminator(real_img).flatten()
        real_labels = torch.ones_like(disc_real).to(device)
        loss_real = criterion(disc_real, real_labels)
        # 采样噪声数据
        noise = torch.randn(batch_size, z_dim).to(device)
        fake_img = generator(noise)
        # 进行判别得到损失函数值
        disc_fake = discriminator(fake_img.detach()).flatten()
        fake_labels = torch.zeros_like(disc_fake).to(device)
        loss_fake = criterion(disc_fake, fake_labels)
        # 将两者损失值求和除以二,以免其中一个损失值过大影响训练
        loss_disc = (loss_real + loss_fake) / 2
        # 更新参数
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # 训练生成器 最小化 log(1 - D(G(z))) → 最大化 log(D(G(z)))
        # 将噪声数据采样进行判别
        output = discriminator(fake_img).flatten()
        # 计算损失函数值
        loss_gen = criterion(output, real_labels)
        # 更新参数
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

打印结果与结果保存

python 复制代码
# 生成结果保存文件夹
os.makedirs("generated_images", exist_ok=True)
# 打印数据并保存图像数据
if i == 0:
    print(
        f"Epoch [{epoch+1}/{epochs}] "
        f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
    )
    with torch.no_grad():
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = generator(noise).reshape(-1, 1, 28, 28)
        img_grid = torchvision.utils.make_grid(fake, nrow=4, normalize=True)

        plt.figure(figsize=(8, 8))
        plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title(f"Epoch {epoch+1}")
        plt.savefig(f"generated_images/epoch{epoch+1}.png")
        #plt.show()
        plt.close()

全部代码

python 复制代码
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # 允许重复加载OpenMP库
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 设置超参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备
print(device)
lr = 0.0002 # 学习率
z_dim = 64  # 噪声维度
image_dim = 28 * 28 * 1  # MNIST图像维度
batch_size = 64 # 批量大小
epochs = 50 # 训练轮数

os.makedirs("generated_images", exist_ok=True)

# 加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 将像素值归一化到[-1, 1]
])

dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 可视化部分
# 定义类别标签
class_names = ['0','1','2','3','4','5','6','7','8','9']

# 从训练集中随机取一个batch的图像
images, labels = next(iter(loader))  # 获取一个batch(64张图)
# iter()转换为迭代器,next()获取下一个批次的数据
# images为一个形状为[64, 1, 28, 28]的张量 labels为[64]的张量

# 显示图像函数
def imshow(img):
    img = img.numpy()
    img = np.squeeze(img)  # 移除单通道维度 (1,28,28) -> (28,28)
    img = img * 0.5 + 0.5  # 反归一化到[0,1]
    plt.imshow(img, cmap='gray')
    plt.axis('off')

# 画出一个4x8的网格(共32张图)
plt.figure(figsize=(12, 6))
for i in range(32):  # 显示前32张
    plt.subplot(4, 8, i+1)
    imshow(images[i])
    plt.title(class_names[labels[i].item()], fontsize=8)
plt.tight_layout()
plt.show()


# 生成器网络设计
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# 判别器网络设计
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# 网络实例化
generator = Generator(z_dim, image_dim).to(device)
discriminator = Discriminator(image_dim).to(device)

# 损失函数选择交叉熵损失函数
criterion = nn.BCELoss()

# 优化器选择Adam优化器
opt_gen = torch.optim.Adam(generator.parameters(), lr=lr)
opt_disc = torch.optim.Adam(discriminator.parameters(), lr=lr)

# 循环轮次
for epoch in range(epochs):
    # 提取数据
    for i, (real_img, _) in enumerate(loader):
        # 1. 训练判别器 log(D(real)) + log(1 - D(G(z)))
        # 采样真实数据
        batch_size = real_img.shape[0]
        real_img = real_img.view(-1, image_dim).to(device)
        # 进行判别得到损失函数值
        disc_real = discriminator(real_img).flatten()
        real_labels = torch.ones_like(disc_real).to(device)
        loss_real = criterion(disc_real, real_labels)
        # 采样噪声数据
        noise = torch.randn(batch_size, z_dim).to(device)
        fake_img = generator(noise)
        # 进行判别得到损失函数值
        disc_fake = discriminator(fake_img.detach()).flatten()
        fake_labels = torch.zeros_like(disc_fake).to(device)
        loss_fake = criterion(disc_fake, fake_labels)
        # 将两者损失值求和除以二,以免其中一个损失值过大影响训练
        loss_disc = (loss_real + loss_fake) / 2
        # 更新参数
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # 训练生成器 最小化 log(1 - D(G(z))) → 最大化 log(D(G(z)))
        # 将噪声数据采样进行判别
        output = discriminator(fake_img).flatten()
        # 计算损失函数值
        loss_gen = criterion(output, real_labels)
        # 更新参数
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        # 打印数据并保存图像数据
        if i == 0:
            print(
                f"Epoch [{epoch+1}/{epochs}] "
                f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
            )
            with torch.no_grad():
                noise = torch.randn(batch_size, z_dim).to(device)
                fake = generator(noise).reshape(-1, 1, 28, 28)
                img_grid = torchvision.utils.make_grid(fake, nrow=4, normalize=True)

                plt.figure(figsize=(8, 8))
                plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
                plt.axis('off')
                plt.title(f"Epoch {epoch+1}")
                plt.savefig(f"generated_images/epoch{epoch+1}.png")
                #plt.show()
                plt.close()

结果展示

  • 显然该网络并没有训练到完全拟合,还可以继续增加训练的轮数使得网络训练更加趋近于拟合。
相关推荐
智驱力人工智能5 分钟前
无感通行与精准管控:AI单元楼安全方案的技术融合实践
人工智能·安全·智慧城市·智慧园区
Chrome深度玩家12 分钟前
谷歌翻译安卓版拍照翻译精准度与语音识别评测【轻松交流】
android·人工智能·语音识别
一点.点23 分钟前
李沐动手深度学习(pycharm中运行笔记)——04.数据预处理
pytorch·笔记·python·深度学习·pycharm·动手深度学习
机器之心25 分钟前
ICLR 2025 Oral|差分注意力机制引领变革,DIFF Transformer攻克长序列建模难题
人工智能
一点.点26 分钟前
李沐动手深度学习(pycharm中运行笔记)——07.自动求导
pytorch·笔记·python·深度学习·pycharm·动手深度学习
机器之心29 分钟前
字节Seed团队PHD-Transformer突破预训练长度扩展!破解KV缓存膨胀难题
人工智能
正宗咸豆花33 分钟前
开源提示词管理平台PromptMinder使用体验
人工智能·开源·prompt
Lilith的AI学习日记33 分钟前
AI提示词(Prompt)终极指南:从入门到精通(附实战案例)
大数据·人工智能·prompt·aigc·deepseek
夏之繁花34 分钟前
AI图像编辑器 Luminar Neo 便携版 Win1.24.0.14794
人工智能
L2ncE37 分钟前
【LanTech】DeepWiki 101 —— 以后不用自己写README了
人工智能·程序员·github