相对自适应判别器(Relative Adaptive Discriminator, RAD)

文献[1]提出了一种改进的生成对抗网络(GAN)框架,旨在解决传统GAN和进化GAN(EGAN)中常见的训练不稳定、模式崩溃、梯度消失和高计算成本等问题。其中有一个创新点指出传统判别器基于绝对分数判断真假样本,容易导致训练不平衡。RAD基于相对logit差异来评估真实样本和生成样本。

传统GAN判别器的损失:

这篇文章提出的:

其中:

生成器的损失:

在这里回顾一下GAN的设计思路,这涉及了一个博弈论的思想,简单说就是让生成器尽力去欺骗判别器(xfake→xreal),而判别器要尽力去区分这个数据是否来自生成器(D(xreal)→1、D(xfake)→0)。

再来看这篇论文的改进思路,其核心思想是:比较同一小批中的真实样本和生成样本,并从它们的差异(相对logits)中学习,而不是绝对概率。

他给了伪代码,其中他设计了两个生成样本,但并未见其如何使用。

在这里我用Pytorch实现一下:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
python 复制代码
def rad_generator_loss(real_logits, fake_logits, eps=1e-8):
    """
    相对自适应生成器损失
    参数:
        real_logits: 判别器对真实样本的输出, shape (batch_size,)
        fake_logits: 判别器对生成样本的输出, shape (batch_size,)
    返回:
        生成器损失标量
    """
    # 生成器希望 fake_logits - real_logits 尽量大
    diff_fake_real = fake_logits - real_logits
    loss_gen = -torch.log(torch.sigmoid(diff_fake_real) + eps).mean()
    return loss_gen
python 复制代码
def rad_discriminator_loss(real_logits, fake_logits, eps=1e-8):
    """
    相对自适应判别器损失 (RAD)
    参数:
        real_logits: 判别器对真实样本的输出, shape (batch_size,)
        fake_logits: 判别器对生成样本的输出, shape (batch_size,)
        eps: 防止 log 数值不稳定
    返回:
        判别器损失标量
    """
    # 保持广播机制,对每个样本对计算差值
    # 第一项: log σ(real - fake)
    diff_real_fake = real_logits - fake_logits
    loss_real = -torch.log(torch.sigmoid(diff_real_fake) + eps).mean()
    # 第二项: log σ(fake - real)
    diff_fake_real = fake_logits - real_logits
    loss_fake = -torch.log(torch.sigmoid(diff_fake_real) + eps).mean()
    
    return loss_real + loss_fake

参考文献:

1\]Atifa Rafique, Xue Yu, Kashif Iqbal, Mujahid Tabassum, Amir Hussain, Khursheed Aurangzeb, MD-EGAN: Evolutionary GAN with dynamic latent sampling and relative adaptive discriminator for improved performance,Neurocomputing,Volume664,2026,131951, https://doi.org/10.1016/j.ne ucom.2025.131951. \*\*\*\*\*\*\*\*\*\*\*\*\*\*\*END\*\*\*\*\*\*\*\*\*\*\*\*\*\*

相关推荐
技术小黑18 小时前
CNN算法实战系列02 | ResNet50V2算法实战与解析
pytorch·深度学习·算法·cnn
Febu421 小时前
Nano-vLLM-MS
pytorch·深度学习·transformer
盼小辉丶1 天前
PyTorch强化学习实战——使用交叉熵方法解决 FrozenLake 环境
人工智能·pytorch·python·强化学习
郝学胜-神的一滴2 天前
反向传播:神经网络的「灵魂」修炼法则
人工智能·pytorch·深度学习·神经网络·机器学习·数据挖掘
Jmayday2 天前
Pytorch:问题整理
人工智能·pytorch·python
盼小辉丶3 天前
PyTorch强化学习实战(6)——交叉熵方法详解与实现
人工智能·pytorch·python·强化学习
ZhengEnCi3 天前
06-多头注意力机制 🎯
人工智能·pytorch·python
赵优秀一一3 天前
AI入门学习
人工智能·pytorch·深度学习
盼小辉丶3 天前
PyTorch强化学习实战(5)——PyTorch Ignite 事件驱动机制与实践
人工智能·pytorch·python·强化学习