深度学习周报(12.22~12.28)

目录

摘要

Abstract

[1 生成对抗网络(GAN)](#1 生成对抗网络(GAN))

[1.1 核心思想](#1.1 核心思想)

[1.2 变体网络及应用](#1.2 变体网络及应用)

[1.3 代码梳理](#1.3 代码梳理)

[2 Wasserstein距离](#2 Wasserstein距离)

[2.1 数学表示](#2.1 数学表示)

[2.2 常见用法](#2.2 常见用法)

[2.2.1 1-Wasserstein 距离](#2.2.1 1-Wasserstein 距离)

[2.2.2 2-Wasserstein 距离](#2.2.2 2-Wasserstein 距离)

[2.2.3 对比](#2.2.3 对比)

[3 总结](#3 总结)


摘要

本周首先学习了生成对抗网络(GAN)的核心思想、关键组件以及训练流程,了解了它的相关变体网络(如 DCGAN)及应用领域,并用代码梳理了其关键组件的结构与训练过程;其次,学习了Wasserstein距离的数学表示与常见用法,主要了解了一阶和二阶 Wasserstein 距离及它们之间的区别。

Abstract

This week, I first studied the core concepts, key components, and training process of Generative Adversarial Networks , gained an understanding of their related variant networks, such as DCGAN, and application areas, and used code to outline the structure of its key components and training process. Secondly, I learned about the mathematical representation and common applications of the Wasserstein distance, focusing primarily on the first-order and second-order Wasserstein distances and the differences between them.

1 生成对抗网络(GAN)

GAN(Generative Adversarial Networks,生成对抗网络)由Ian Goodfellow等人于2014年提出,它的灵感源于博弈论中的零和博弈**,**通过两个神经网络相互博弈的方式来学习数据分布。

p.s. 零和博弈是博弈论中的一个核心模型,即在博弈中,所有参与者的收益总和始终为零,故不存在双赢或共赢的局面,因为利益的总量固定,一方多得,其他方必然少得。

1.1 核心思想

GAN 的核心思想就是对抗训练,它主要包括两个关键组件,生成器与判别器。前者负责根据随机噪声向量生成假数据,最终达到欺骗判别器,让其认为生成的数据是真实数据的目的;后者则负责判断数据真伪。在训练过程中,生成器生成的数据会越来越逼真,而判别器对于数据真伪的判别能力也会越来越强。这个训练过程可以形式化为一个最小最大博弈问题:

p.s. 最小最大博弈问题是指在两人零和博弈中,玩家 1 采用最大化自己最小收益的策略,而玩家 2 采用采用最小化对方最大收益的策略,然后在最优策略下,双方能达到一个均衡点。

公式中, 代表真实数据分布,而 表示噪声分布; 代表判别器认为样本 x 是真实的概率; 代表生成器从噪声 z 生成的数据。

在每次迭代中,其训练的具体步骤如下:

第一,固定生成器,训练判别器。先分别从真实分布与噪声分布中采样,根据从噪声分布的采样样本生成假样本,然后更新判别器梯度(公式如下),以达到最大化目标:

第二,固定判别器,训练生成器。从噪声分布中采样样本,更新生成器梯度(公式如下),以达到最小化目标:

GAN 面临的挑战主要是训练不稳定,即损失值剧烈振荡、难以收敛,这是其核心思想带来的,其生成器和判别器的优化目标相互冲突,当一方过于强大时,另一方有效度就会降低。另一个挑战则是模式崩溃,即指生成器只产生数据分布中有限几种模式的样本,缺乏多样性。这是因为生成器找到了一个能欺骗判别器的"捷径",即不断重复相似的样本。此外,高质量 GAN 训练需要大量计算资源,在生成高分辨率图像时尤为明显。

1.2 变体网络及应用

GAN 包括 DCGAN、WGAN、CycleGAN 与 Conditional GAN 等变体网络。

DCGAN(深度卷积生成对抗网络)将 CNN 引入 GAN 框架,使用步幅卷积替代池化层,在判别器中使用步幅卷积进行下采样,在生成器中使用转置卷积进行上采样。它能够学习到有意义的特征表示,同时还具有一定的的算术性质;

WGAN 从根本上改变了 GAN 的训练目标,用 Wasserstein 距离(后文学习)替代原始的 JS 散度,解决了原始 GAN 训练中常见的梯度消失问题。它的关键创新是移除了判别器最后的 Sigmoid 激活函数,输出一个未限制的分数而不是概率;

CycleGAN(循环一致生成对抗网络)则解决了传统方法中严格配对的训练数据实际难以获得的问题,它包含两个生成器和两个判别器,形成两个对称的转换循环,从而不需要配对数据,只需要两个域的图像集合,极大地扩展了应用范围;

Conditional GAN(条件生成对抗网络)则通过将额外信息(如类别标签、文本描述等)同时输入生成器和判别器,实现对生成过程的控制。在cGAN中,生成器接收噪声和条件信息,生成符合条件的样本;判别器则同时判断样本的真实性和条件匹配性。这个框架打破了GAN只能随机生成的限制,为用户提供了可控的生成能力,被扩展到了许多具体应用上。

GAN 主要应用于图像生成、编辑与转换,数据增强,隐私保护,音乐生成等许多领域。

1.3 代码梳理

生成器:

python 复制代码
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        # latent_dim 潜在空间维度(噪声向量大小),img_shape 输出图像形状
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),     # 全连接层
            nn.LeakyReLU(0.2),              # 带泄露的ReLU,负斜率0.2
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),            # 一维批量归一化,加速收敛
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()                       # 输出激活函数 
        )
        self.img_shape = img_shape          # 保存图像形状
    
    def forward(self, z):
        img = self.model(z)                 # 生成展平的"假图像"
        return img.view(img.size(0), *self.img_shape)     # 恢复形状

LeakyReLU 是为了解决ReLU的神经元死亡问题所被提出的一个改进。它与标准 ReLU 将所有负值直接置零不同,LeakyReLU 为负输入值赋予一个小的非零斜率**,**从而保持梯度流动,对于对抗训练的稳定性至关重要;

噪声通过多层非线性变换转变为展平的新图像。

判别器:

python 复制代码
# 判别器网络
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
        self.model = nn.Sequential(
            # torch.prod 计算图像所有维度的乘积
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),   
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()        # 将输出映射到[0,1]概率范围
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)      # 展平为(batch_size,像素),-1表示自动计算剩余维度大小
        validity = self.model(img_flat)
        return validity

判别器网络通常不使用批量归一化,是因为判别器的判断基于混合统计,而非纯粹的内容,如果使用可能导致真实数据通过梯度影响生成器。

训练过程:

python 复制代码
# 训练循环
#(生成器模型实例,判别器模型实例,数据加载器,训练轮数)
def train_gan(generator, discriminator, dataloader, epochs):
    adversarial_loss = nn.BCELoss()         # 对抗损失函数:二元交叉熵
    # 分别为两个网络创建优化器
    optimizer_G = torch.optim.Adam(generator.parameters())         
    optimizer_D = torch.optim.Adam(discriminator.parameters())
    
    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            
            real = torch.ones(imgs.size(0), 1)         # 真实图像标签 全1
            fake = torch.zeros(imgs.size(0), 1)        # 生成图像标签 全0
            
            # 训练判别器
            optimizer_D.zero_grad()           # 清空判别器梯度,防止累积
            real_loss = adversarial_loss(discriminator(imgs), real)     # 计算真实图像的损失,希望判别器对真实图像输出接近1
            z = torch.randn(imgs.size(0), latent_dim)       # 采样噪声
            fake_imgs = generator(z)                        # 生成假图像
            fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake)    # 计算生成图像的损失,希望判别器对其输出接近0
            d_loss = (real_loss + fake_loss) / 2          # 总损失
            d_loss.backward()               # 反向传播
            optimizer_D.step()              # 更新判别器参数
            
            # 训练生成器
            optimizer_G.zero_grad()
            z = torch.randn(imgs.size(0), latent_dim)
            gen_imgs = generator(z)
            g_loss = adversarial_loss(discriminator(gen_imgs), real)      # 使用real标签!
            g_loss.backward()
            optimizer_G.step()

训练判别器时,真实数据样本已经通过 dataloader 从数据集中采样得到,无需多余步骤再采样;训练生成器时,需要为生成的图像打上真实标签(全1),因为生成器希望判别器输出1或接近1的值。

2 Wasserstein距离

Wasserstein距离,也称为推土机距离,是一种度量两个概率分布之间距离的方法。它起源于最优传输理论,由俄罗斯数学家Leonid Vaseršteĭn引入,也因此得名。基于前面对 Monge 问题与 Kantorovich 松弛的学习,Wasserstein距离指得就是最低的总运输成本。

2.1 数学表示

给定两个分别定义在空间 X 和 Y 上的概率分布 P 和 Q,以及一个连续成本函数 c(x,y)表示将单位质量从位置 x 移动到 y 的成本。

对于这两个概率分布,p 阶 Wasserstein 距离(也可表示为 p-Wasserstein 距离)的数学公式如下:

其中:

表示所有以 P 和 Q 为边缘分布的联合分布 的集合;

表示对于一个给定的运输计划 ,计算其总成本;

表示对所有可能的运输计划取下确界(即最小值),也对应最优的传输计划的成本;

**,**最后开 p 次方根,使其量纲与原始距离一致,从而满足度量的齐次性。

2.2 常见用法

通常,当 X 和 Y 是欧几里得空间时,代价函数是两点间距离的 p 次方,即 。由此,p 能够决定如何惩罚长距离的运输,当 p 值为 1 时,成本与距离相等,总成本即为运输质量与距离的乘积的简单求和;而当 p 值大于 1 时,长距离的代价会被指数级放大,会更加惩罚长距离的运输。在实际应用中,p 值为 1 和 2 最为常见,且两者的性质和适用场景有着微妙而重要的差异。

2.2.1 1-Wasserstein 距离

1-Wasserstein 距离(p = 1时)是最直观且广泛使用的形式,对应最优运输理论中的原始问题。它也被称为 Earth Mover's Distance(EMD),其总成本就像前面说的是运输质量与距离两者乘积的和,它的公式如下:

这个距离衡量的是将分布 P 转换为分布 Q 所需的最小平均移动距离。

在数学性质上,1-Wasserstein 距离具有对偶表示:

其中:

整个式子实际上等价于前面学习的 Kantorovich 对偶问题,前面学习的为一般形式,本文中则更适用于1-Wasserstein 距离;

表示函数 f 在分布 P 下的期望值与在分布 Q 下的期望值之差;

表示函数 f 的 Lipschitz 常数不超过 1,意味着函数 f 是 1-Lipschitz 连续的,这是核心约束条件。

p.s. Lipschitz 常数衡量了函数变化的最大速率,如果一个函数 f 是 1-Lipschitz 连续的,则对于任意两点 x 与 y,有

这个对偶形式在理论分析和实际计算中都极为重要,特别是在 Wasserstein 生成对抗网络中,判别器被限制为Lipschitz连续函数,其目标函数直接来源于此对偶形式。

从几何直观来看,1-Wasserstein 距离对离群点相对敏感但不过分惩罚,在考虑整体运输效率的同时能够兼顾个别较远的点。

2.2.2 2-Wasserstein 距离

2-Wasserstein 距离在数学上更为光滑且具有更丰富的几何结构,其公式为:

它最小化的是平方距离的期望,放大了远距离点的贡献,使得其对长距离运输更加敏感。

从几何视角看,2-Wasserstein 距离在概率分布空间上诱导了一个黎曼几何结构,这个空间被称为Wasserstein 空间或最优运输空间。在这个空间中,概率分布之间的测地线对应于质量的最优运输路径,为研究分布之间的插值和演化提供了强有力的框架。

应用场景上,2-Wasserstein 距离在图像处理中常用于颜色直方图匹配,在计算流体力学中用于比较密度场,在统计学中用于定义分布的质心(barycenter)。同时,由于其对异常值更加敏感的特性,它在需要强调分布尾部差异的场景中特别有用。

2.2.3 对比

一阶和二阶 Wasserstein 距离的核心差异源于它们的惩罚函数性质。一阶距离使用线性惩罚,运输成本与距离成正比,这使其对分布的细微变化相对稳健,更适合存在噪声或异常值的场景;而二阶距离使用平方惩罚,更加强调长距离运输的成本,因此对分布的尾部行为更加敏感,能够更好地区分具有相似主体但尾部不同的分布。也因此,它们在不同领域的适用性不同,一阶距离广泛用于稳健统计和生成模型训练,二阶距离则在需要精确匹配的物理建模和几何分析中更有优势。

在计算复杂度上,两者都面临维度灾难的挑战,但具体算法有所不同。一阶距离的线性规划方法在低维离散情况下效率较高,而对偶形式在高维连续情况下更易处理;二阶距离由于平方项的存在,其计算常涉及更复杂的优化问题,但对于高斯分布等特定家族有解析解。

在实际应用中,选择一阶还是二阶 Wasserstein 距离取决于具体问题的需求。如果目标是稳健地比较分布而不希望被少数远离主体的点过度影响,一阶距离是合适的选择;如果需要精确度量分布间的几何差异,特别是关注分布的扩散程度和形状变化,二阶距离通常能提供更多信息。此外在某些情况下,可以使用更一般的 p 阶 Wasserstein 距离,其中 p 的选择成为调节敏感性的连续参数。

3 总结

本周主要学习了生成对抗网络和 Wasserstein 距离的相关知识。GAN 主要是两个神经网络互相博弈的想法感觉可以扩展学习和训练思路; Wasserstein 距离学习的时候主要还是会涉及到部分没怎么接触过的概念,比如Lipschitz 常数,学起来比较慢,不好理解。

相关推荐
颜颜yan_2 小时前
从 0 到 1 搭建一个「塔罗感」AI 智能体 —— 微光运势实践记录
人工智能·ai·智能体·modelengine
灰灰勇闯IT2 小时前
鸿蒙 ArkUI 声明式 UI 核心:状态管理(@State/@Prop/@Link)实战解析
人工智能·计算机视觉·harmonyos
质变科技AI就绪数据云2 小时前
AI记忆架构三大路线
人工智能·ai·ai agent·智能体·记忆体
WBluuue2 小时前
Codeforces Global 31 Div1+2(ABCD)
c++·算法
智算菩萨2 小时前
【Python机器学习】回归模型评估指标深度解析:MAE、MSE、RMSE与R²的理论与实践
python·机器学习·回归
dajun1811234562 小时前
简单快速跨职能流程图在线设计工具 中文
人工智能·架构·流程图
瀚岳-诸葛弩2 小时前
ViT(Vision Transformer)的理解、实现与应用拓展的思考
人工智能·深度学习·transformer
IT_陈寒2 小时前
Vite 3.0 实战:5个优化技巧让你的开发效率提升50%
前端·人工智能·后端
Mintopia2 小时前
🤖✨ 生成式应用架构师的修炼手册
人工智能·llm·aigc