深度学习周报(12.22~12.28)

目录

摘要

Abstract

[1 生成对抗网络(GAN)](#1 生成对抗网络(GAN))

[1.1 核心思想](#1.1 核心思想)

[1.2 变体网络及应用](#1.2 变体网络及应用)

[1.3 代码梳理](#1.3 代码梳理)

[2 Wasserstein距离](#2 Wasserstein距离)

[2.1 数学表示](#2.1 数学表示)

[2.2 常见用法](#2.2 常见用法)

[2.2.1 1-Wasserstein 距离](#2.2.1 1-Wasserstein 距离)

[2.2.2 2-Wasserstein 距离](#2.2.2 2-Wasserstein 距离)

[2.2.3 对比](#2.2.3 对比)

[3 总结](#3 总结)


摘要

本周首先学习了生成对抗网络(GAN)的核心思想、关键组件以及训练流程,了解了它的相关变体网络(如 DCGAN)及应用领域,并用代码梳理了其关键组件的结构与训练过程;其次,学习了Wasserstein距离的数学表示与常见用法,主要了解了一阶和二阶 Wasserstein 距离及它们之间的区别。

Abstract

This week, I first studied the core concepts, key components, and training process of Generative Adversarial Networks , gained an understanding of their related variant networks, such as DCGAN, and application areas, and used code to outline the structure of its key components and training process. Secondly, I learned about the mathematical representation and common applications of the Wasserstein distance, focusing primarily on the first-order and second-order Wasserstein distances and the differences between them.

1 生成对抗网络(GAN)

GAN(Generative Adversarial Networks,生成对抗网络)由Ian Goodfellow等人于2014年提出,它的灵感源于博弈论中的零和博弈**,**通过两个神经网络相互博弈的方式来学习数据分布。

p.s. 零和博弈是博弈论中的一个核心模型,即在博弈中,所有参与者的收益总和始终为零,故不存在双赢或共赢的局面,因为利益的总量固定,一方多得,其他方必然少得。

1.1 核心思想

GAN 的核心思想就是对抗训练,它主要包括两个关键组件,生成器与判别器。前者负责根据随机噪声向量生成假数据,最终达到欺骗判别器,让其认为生成的数据是真实数据的目的;后者则负责判断数据真伪。在训练过程中,生成器生成的数据会越来越逼真,而判别器对于数据真伪的判别能力也会越来越强。这个训练过程可以形式化为一个最小最大博弈问题:

p.s. 最小最大博弈问题是指在两人零和博弈中,玩家 1 采用最大化自己最小收益的策略,而玩家 2 采用采用最小化对方最大收益的策略,然后在最优策略下,双方能达到一个均衡点。

公式中, 代表真实数据分布,而 表示噪声分布; 代表判别器认为样本 x 是真实的概率; 代表生成器从噪声 z 生成的数据。

在每次迭代中,其训练的具体步骤如下:

第一,固定生成器,训练判别器。先分别从真实分布与噪声分布中采样,根据从噪声分布的采样样本生成假样本,然后更新判别器梯度(公式如下),以达到最大化目标:

第二,固定判别器,训练生成器。从噪声分布中采样样本,更新生成器梯度(公式如下),以达到最小化目标:

GAN 面临的挑战主要是训练不稳定,即损失值剧烈振荡、难以收敛,这是其核心思想带来的,其生成器和判别器的优化目标相互冲突,当一方过于强大时,另一方有效度就会降低。另一个挑战则是模式崩溃,即指生成器只产生数据分布中有限几种模式的样本,缺乏多样性。这是因为生成器找到了一个能欺骗判别器的"捷径",即不断重复相似的样本。此外,高质量 GAN 训练需要大量计算资源,在生成高分辨率图像时尤为明显。

1.2 变体网络及应用

GAN 包括 DCGAN、WGAN、CycleGAN 与 Conditional GAN 等变体网络。

DCGAN(深度卷积生成对抗网络)将 CNN 引入 GAN 框架,使用步幅卷积替代池化层,在判别器中使用步幅卷积进行下采样,在生成器中使用转置卷积进行上采样。它能够学习到有意义的特征表示,同时还具有一定的的算术性质;

WGAN 从根本上改变了 GAN 的训练目标,用 Wasserstein 距离(后文学习)替代原始的 JS 散度,解决了原始 GAN 训练中常见的梯度消失问题。它的关键创新是移除了判别器最后的 Sigmoid 激活函数,输出一个未限制的分数而不是概率;

CycleGAN(循环一致生成对抗网络)则解决了传统方法中严格配对的训练数据实际难以获得的问题,它包含两个生成器和两个判别器,形成两个对称的转换循环,从而不需要配对数据,只需要两个域的图像集合,极大地扩展了应用范围;

Conditional GAN(条件生成对抗网络)则通过将额外信息(如类别标签、文本描述等)同时输入生成器和判别器,实现对生成过程的控制。在cGAN中,生成器接收噪声和条件信息,生成符合条件的样本;判别器则同时判断样本的真实性和条件匹配性。这个框架打破了GAN只能随机生成的限制,为用户提供了可控的生成能力,被扩展到了许多具体应用上。

GAN 主要应用于图像生成、编辑与转换,数据增强,隐私保护,音乐生成等许多领域。

1.3 代码梳理

生成器:

python 复制代码
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        # latent_dim 潜在空间维度(噪声向量大小),img_shape 输出图像形状
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),     # 全连接层
            nn.LeakyReLU(0.2),              # 带泄露的ReLU,负斜率0.2
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),            # 一维批量归一化,加速收敛
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()                       # 输出激活函数 
        )
        self.img_shape = img_shape          # 保存图像形状
    
    def forward(self, z):
        img = self.model(z)                 # 生成展平的"假图像"
        return img.view(img.size(0), *self.img_shape)     # 恢复形状

LeakyReLU 是为了解决ReLU的神经元死亡问题所被提出的一个改进。它与标准 ReLU 将所有负值直接置零不同,LeakyReLU 为负输入值赋予一个小的非零斜率**,**从而保持梯度流动,对于对抗训练的稳定性至关重要;

噪声通过多层非线性变换转变为展平的新图像。

判别器:

python 复制代码
# 判别器网络
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
        self.model = nn.Sequential(
            # torch.prod 计算图像所有维度的乘积
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),   
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()        # 将输出映射到[0,1]概率范围
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)      # 展平为(batch_size,像素),-1表示自动计算剩余维度大小
        validity = self.model(img_flat)
        return validity

判别器网络通常不使用批量归一化,是因为判别器的判断基于混合统计,而非纯粹的内容,如果使用可能导致真实数据通过梯度影响生成器。

训练过程:

python 复制代码
# 训练循环
#(生成器模型实例,判别器模型实例,数据加载器,训练轮数)
def train_gan(generator, discriminator, dataloader, epochs):
    adversarial_loss = nn.BCELoss()         # 对抗损失函数:二元交叉熵
    # 分别为两个网络创建优化器
    optimizer_G = torch.optim.Adam(generator.parameters())         
    optimizer_D = torch.optim.Adam(discriminator.parameters())
    
    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            
            real = torch.ones(imgs.size(0), 1)         # 真实图像标签 全1
            fake = torch.zeros(imgs.size(0), 1)        # 生成图像标签 全0
            
            # 训练判别器
            optimizer_D.zero_grad()           # 清空判别器梯度,防止累积
            real_loss = adversarial_loss(discriminator(imgs), real)     # 计算真实图像的损失,希望判别器对真实图像输出接近1
            z = torch.randn(imgs.size(0), latent_dim)       # 采样噪声
            fake_imgs = generator(z)                        # 生成假图像
            fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake)    # 计算生成图像的损失,希望判别器对其输出接近0
            d_loss = (real_loss + fake_loss) / 2          # 总损失
            d_loss.backward()               # 反向传播
            optimizer_D.step()              # 更新判别器参数
            
            # 训练生成器
            optimizer_G.zero_grad()
            z = torch.randn(imgs.size(0), latent_dim)
            gen_imgs = generator(z)
            g_loss = adversarial_loss(discriminator(gen_imgs), real)      # 使用real标签!
            g_loss.backward()
            optimizer_G.step()

训练判别器时,真实数据样本已经通过 dataloader 从数据集中采样得到,无需多余步骤再采样;训练生成器时,需要为生成的图像打上真实标签(全1),因为生成器希望判别器输出1或接近1的值。

2 Wasserstein距离

Wasserstein距离,也称为推土机距离,是一种度量两个概率分布之间距离的方法。它起源于最优传输理论,由俄罗斯数学家Leonid Vaseršteĭn引入,也因此得名。基于前面对 Monge 问题与 Kantorovich 松弛的学习,Wasserstein距离指得就是最低的总运输成本。

2.1 数学表示

给定两个分别定义在空间 X 和 Y 上的概率分布 P 和 Q,以及一个连续成本函数 c(x,y)表示将单位质量从位置 x 移动到 y 的成本。

对于这两个概率分布,p 阶 Wasserstein 距离(也可表示为 p-Wasserstein 距离)的数学公式如下:

其中:

表示所有以 P 和 Q 为边缘分布的联合分布 的集合;

表示对于一个给定的运输计划 ,计算其总成本;

表示对所有可能的运输计划取下确界(即最小值),也对应最优的传输计划的成本;

**,**最后开 p 次方根,使其量纲与原始距离一致,从而满足度量的齐次性。

2.2 常见用法

通常,当 X 和 Y 是欧几里得空间时,代价函数是两点间距离的 p 次方,即 。由此,p 能够决定如何惩罚长距离的运输,当 p 值为 1 时,成本与距离相等,总成本即为运输质量与距离的乘积的简单求和;而当 p 值大于 1 时,长距离的代价会被指数级放大,会更加惩罚长距离的运输。在实际应用中,p 值为 1 和 2 最为常见,且两者的性质和适用场景有着微妙而重要的差异。

2.2.1 1-Wasserstein 距离

1-Wasserstein 距离(p = 1时)是最直观且广泛使用的形式,对应最优运输理论中的原始问题。它也被称为 Earth Mover's Distance(EMD),其总成本就像前面说的是运输质量与距离两者乘积的和,它的公式如下:

这个距离衡量的是将分布 P 转换为分布 Q 所需的最小平均移动距离。

在数学性质上,1-Wasserstein 距离具有对偶表示:

其中:

整个式子实际上等价于前面学习的 Kantorovich 对偶问题,前面学习的为一般形式,本文中则更适用于1-Wasserstein 距离;

表示函数 f 在分布 P 下的期望值与在分布 Q 下的期望值之差;

表示函数 f 的 Lipschitz 常数不超过 1,意味着函数 f 是 1-Lipschitz 连续的,这是核心约束条件。

p.s. Lipschitz 常数衡量了函数变化的最大速率,如果一个函数 f 是 1-Lipschitz 连续的,则对于任意两点 x 与 y,有

这个对偶形式在理论分析和实际计算中都极为重要,特别是在 Wasserstein 生成对抗网络中,判别器被限制为Lipschitz连续函数,其目标函数直接来源于此对偶形式。

从几何直观来看,1-Wasserstein 距离对离群点相对敏感但不过分惩罚,在考虑整体运输效率的同时能够兼顾个别较远的点。

2.2.2 2-Wasserstein 距离

2-Wasserstein 距离在数学上更为光滑且具有更丰富的几何结构,其公式为:

它最小化的是平方距离的期望,放大了远距离点的贡献,使得其对长距离运输更加敏感。

从几何视角看,2-Wasserstein 距离在概率分布空间上诱导了一个黎曼几何结构,这个空间被称为Wasserstein 空间或最优运输空间。在这个空间中,概率分布之间的测地线对应于质量的最优运输路径,为研究分布之间的插值和演化提供了强有力的框架。

应用场景上,2-Wasserstein 距离在图像处理中常用于颜色直方图匹配,在计算流体力学中用于比较密度场,在统计学中用于定义分布的质心(barycenter)。同时,由于其对异常值更加敏感的特性,它在需要强调分布尾部差异的场景中特别有用。

2.2.3 对比

一阶和二阶 Wasserstein 距离的核心差异源于它们的惩罚函数性质。一阶距离使用线性惩罚,运输成本与距离成正比,这使其对分布的细微变化相对稳健,更适合存在噪声或异常值的场景;而二阶距离使用平方惩罚,更加强调长距离运输的成本,因此对分布的尾部行为更加敏感,能够更好地区分具有相似主体但尾部不同的分布。也因此,它们在不同领域的适用性不同,一阶距离广泛用于稳健统计和生成模型训练,二阶距离则在需要精确匹配的物理建模和几何分析中更有优势。

在计算复杂度上,两者都面临维度灾难的挑战,但具体算法有所不同。一阶距离的线性规划方法在低维离散情况下效率较高,而对偶形式在高维连续情况下更易处理;二阶距离由于平方项的存在,其计算常涉及更复杂的优化问题,但对于高斯分布等特定家族有解析解。

在实际应用中,选择一阶还是二阶 Wasserstein 距离取决于具体问题的需求。如果目标是稳健地比较分布而不希望被少数远离主体的点过度影响,一阶距离是合适的选择;如果需要精确度量分布间的几何差异,特别是关注分布的扩散程度和形状变化,二阶距离通常能提供更多信息。此外在某些情况下,可以使用更一般的 p 阶 Wasserstein 距离,其中 p 的选择成为调节敏感性的连续参数。

3 总结

本周主要学习了生成对抗网络和 Wasserstein 距离的相关知识。GAN 主要是两个神经网络互相博弈的想法感觉可以扩展学习和训练思路; Wasserstein 距离学习的时候主要还是会涉及到部分没怎么接触过的概念,比如Lipschitz 常数,学起来比较慢,不好理解。

相关推荐
NAGNIP8 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab9 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab9 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP13 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年13 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼13 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS13 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区14 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈14 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang15 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx