InfoGAN:通过信息最大化生成对抗网络进行可解释的表示学习

系列文章目录

Conditional Generative Adversarial Nets

cGANs with Projection Discriminator

Conditional Image Synthesis with Auxiliary Classifier GANs

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets


InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

InfoGAN是生成对抗网络的信息论扩展,能够以完全无监督的方式学习解开的表示。InfoGAN 是一个生成对抗网络,它最大化了一小部分潜在变量和观察之间的互信息。

InfoGAN 在 GAN 的基础上只增加了可以忽略不计的计算成本,并且易于训练。使用互信息诱导表示的核心思想可以应用于 VAE等其他方法。实验结果表示:互信息成本增强的生成模型可能是学习解缠结表示的有效方法

1. 生成对抗网络

生成对抗网络(GAN)是一种使用最大最小化博弈训练深度生成模型的框架。目标是学习与真实数据分布 Pdata(x) 匹配的生成器分布 PG(x)。 GAN 没有尝试显式地为数据分布中的每个 x 分配概率,而是学习生成器网络 G,该生成器网络 G 通过将噪声变量 z ∼ Pnoise(z) 转换为样本 G(z),从生成器分布 PG 生成样本。该生成器通过与对抗性判别器网络 D 进行训练,该网络旨在区分真实数据分布 Pdata 和生成器分布 PG 中的样本。因此,对于给定的生成器,最佳判别器是 D(x) = Pdata(x)/(Pdata(x) + PG(x))。更正式地说,极小极大博弈由以下表达式给出:

2.用于引导潜在代码的互信息

GAN 公式使用简单分解的连续输入噪声向量 z,同时对生成器使用该噪声的方式没有限制。因此,噪声可能会被生成器以高度纠缠的方式使用,导致 z 的各个维度与数据的语义特征不对应

然而,许多领域自然地分解为一组语义上有意义的变化因素。例如,当从 MNIST 数据集生成图像时,**如果模型自动选择分配一个离散随机变量来表示数字 (0-9) 的数字身份,并选择两个数字角度与粗细作为连续变量来表示。**在这种情况下,这些属性既是独立的又是显着的,如果能够在没有任何监督的情况下恢复这些概念,通过简单地指定 MNIST 数字是由独立的 1-of-10 变量和两个独立的连续变量生成的,那将会非常有用变量。

论文建议将输入噪声向量分解为两部分,而不是使用单个非结构化噪声向量:(i)z:它被视为不可压缩噪声源; (ii) c,潜在代码: 针对数据分布的显着结构化语义特征

  • 实现逻辑:
    为生成器网络提供不可压缩噪声 z 和潜在代码 c,因此生成器的形式变为 G(z, c)。然而,在标准 GAN 中,生成器可以通过找到满足 PG(x|c) = PG(x) 的解决方案来自由忽略额外的潜在代码 c。为了解决平凡码的问题,我们提出了一种信息论正则化:潜在码 c 和生成器分布 G(z, c) 之间应该存在高互信息。因此 I(c; G(z, c)) 应该很高。

在信息论中,X 和 Y 之间的互信息,I(X; Y ),衡量从随机变量 Y 的知识中学到的关于另一个随机变量 X 的"信息量"。互信息可以表示为两个随机变量的差值熵项:

这个定义有一个直观的解释:I(X; Y ) 是观察到 Y 时 X 的不确定性的减少。如果 X 和 Y 是独立的,则 I(X; Y ) = 0,因为了解一个变量并不能揭示另一个变量;相反,如果 X 和 Y 通过确定性可逆函数相关,则可以获得最大互信息。这种解释使得很容易制定成本:给定任何 x ∼ PG(x),我们希望 PG(c|x) 具有较小的熵。换句话说,潜在代码c中的信息不应在生成过程中丢失。类似的互信息启发目标之前已经在聚类背景下被考虑过。因此,我们建议解决以下信息正则化极小极大博弈:

实际上,互信息项 I(c; G(z, c)) 很难直接最大化,因为它需要访问后验 P (c|x)。幸运的是,我们可以通过定义辅助分布 Q(c|x) 来近似 P(c|x) 来获得它的下界:

在实践中,我们将辅助分布 Q 参数化为神经网络。在大多数实验中,Q 和 D 共享所有卷积层,并且有一个最终的全连接层来输出条件分布 Q(c|x) 的参数,这意味着 InfoGAN 只为 GAN 添加了可以忽略不计的计算成本。我们还观察到,LI (G, Q) 总是比正常的 GAN 目标收敛得更快,因此 InfoGAN 本质上是免费随 GAN 提供。

代码实现

python 复制代码
class QInfoGAN2(nn.Module):

    def __init__(self, x_dim, c_dim, dim=96, norm='batch_norm', weight_norm='none'):
        super(QInfoGAN2, self).__init__()

        norm_fn = _get_norm_fn_2d(norm)
        weight_norm_fn = _get_weight_norm_fn(weight_norm)

        def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
            return nn.Sequential(
                weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
                norm_fn(out_dim),
                nn.LeakyReLU(0.2)
            )

        self.ls = nn.Sequential(  # (N, x_dim, 32, 32)
            conv_norm_lrelu(x_dim, dim),
            conv_norm_lrelu(dim, dim),
            conv_norm_lrelu(dim, dim, stride=2),  # (N, dim , 16, 16)

            conv_norm_lrelu(dim, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2, stride=2),  # (N, dim*2, 8, 8)

            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),  # (N, dim*2, 6, 6)

            nn.AvgPool2d(kernel_size=6),  # (N, dim*2, 1, 1)
            torchlib.Reshape(-1, dim * 2),  # (N, dim*2)
            nn.Linear(dim * 2, c_dim)  # (N, c_dim)
        )

    def forward(self, x):
        # x: (N, x_dim, 32, 32)
        logit = self.ls(x)
        return logit

class DiscriminatorInfoGAN2(nn.Module):

    def __init__(self, x_dim, dim=96, norm='none', weight_norm='spectral_norm'):
        super(DiscriminatorInfoGAN2, self).__init__()

        norm_fn = _get_norm_fn_2d(norm)
        weight_norm_fn = _get_weight_norm_fn(weight_norm)

        def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
            return nn.Sequential(
                weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
                norm_fn(out_dim),
                nn.LeakyReLU(0.2)
            )

        self.ls = nn.Sequential(  # (N, x_dim, 32, 32)
            conv_norm_lrelu(x_dim, dim),
            conv_norm_lrelu(dim, dim),
            conv_norm_lrelu(dim, dim, stride=2),  # (N, dim , 16, 16)

            conv_norm_lrelu(dim, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2, stride=2),  # (N, dim*2, 8, 8)

            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),  # (N, dim*2, 6, 6)

            nn.AvgPool2d(kernel_size=6),  # (N, dim*2, 1, 1)
            torchlib.Reshape(-1, dim * 2),  # (N, dim*2)
            weight_norm_fn(nn.Linear(dim * 2, 1))  # (N, 1)
        )

    def forward(self, x):
        # x: (N, x_dim, 32, 32)
        logit = self.ls(x)
        return logit

class GeneratorCGAN(nn.Module):

    def __init__(self, z_dim, c_dim, dim=128):
        super(GeneratorCGAN, self).__init__()

        def dconv_bn_relu(in_dim, out_dim, kernel_size=4, stride=2, padding=1, output_padding=0):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, output_padding),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )

        self.ls = nn.Sequential(
            dconv_bn_relu(z_dim + c_dim, dim * 4, 4, 1, 0, 0),  # (N, dim * 4, 4, 4)
            dconv_bn_relu(dim * 4, dim * 2),  # (N, dim * 2, 8, 8)
            dconv_bn_relu(dim * 2, dim),   # (N, dim, 16, 16)
            nn.ConvTranspose2d(dim, 3, 4, 2, padding=1), nn.Tanh()  # (N, 3, 32, 32)
        )

    def forward(self, z, c):
        # z: (N, z_dim), c: (N, c_dim) ->[64, 110]
        x = torch.cat([z, c], 1)
        # [64, 110] -> [64, 3, 32, 32]
        x = self.ls(x.view(x.size(0), x.size(1), 1, 1))
        # print(x.shape)
        return x
# model
D = DiscriminatorInfoGAN2(x_dim=3).to(device)
Q = model.QInfoGAN2(x_dim=3, c_dim=c_dim).to(device)
G = model.GeneratorInfoGAN2(z_dim=z_dim, c_dim=c_dim).to(device)

模式1

鉴别器与Q同时训练

python 复制代码
 # train D and Q
 x = x.to(device)
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
z = torch.randn(batch_size, z_dim).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)

x_f = G(z, c).detach()
x_gan_logit = D(x)
x_f_gan_logit = D(x_f)
x_f_c_logit = Q(x_f)

d_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)
gp = model.gradient_penalty(D, x, x_f, mode=gp_mode)
d_loss = d_x_gan_loss + d_x_f_gan_loss + gp * gp_coef
d_q_loss = d_x_f_c_logit
python 复制代码
## 训练生成器
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
            c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
            z = torch.randn(batch_size, z_dim).to(device)

            x_f = G(z, c)
            x_f_gan_logit = D(x_f)
            x_f_c_logit = Q(x_f)

            g_gan_loss = g_loss_fn(x_f_gan_logit)
            d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)

模式二

生成器与Q同时进行训练

python 复制代码
# train D
x = x.to(device)
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
z = torch.randn(batch_size, z_dim).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
# 输入随机图像与随机条件
x_f = G(z, c).detach()
# 输入图像,鉴别器判断
x_gan_logit = D(x)
# 输入伪图像鉴别器进行判断
x_f_gan_logit = D(x_f)
d_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
python 复制代码
# train G and Q
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
z = torch.randn(batch_size, z_dim).to(device)
x_f = G(z, c)
x_f_gan_logit = D(x_f)
# Q与类别计算损失
 x_f_c_logit = Q(x_f)

g_gan_loss = g_loss_fn(x_f_gan_logit)
d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)
 g_loss = g_gan_loss + d_x_f_c_logit


源代码链接
参考链接

相关推荐
ZStack开发者社区1 小时前
AI应用、轻量云、虚拟化|云轴科技ZStack参编金融行标与报告
人工智能·科技·金融
真想骂*3 小时前
人工智能如何重塑音频、视觉及多模态领域的应用格局
人工智能·音视频
赛丽曼5 小时前
机器学习-K近邻算法
人工智能·机器学习·近邻算法
架构文摘JGWZ6 小时前
FastJson很快,有什么用?
后端·学习
啊波次得饿佛哥7 小时前
7. 计算机视觉
人工智能·计算机视觉·视觉检测
XianxinMao7 小时前
RLHF技术应用探析:从安全任务到高阶能力提升
人工智能·python·算法
Swift社区8 小时前
【分布式日志篇】从工具选型到实战部署:全面解析日志采集与管理路径
人工智能·spring boot·分布式
量子-Alex8 小时前
【多视图学习】显式视图-标签问题:多视图聚类的多方面互补性研究
学习
Quz8 小时前
OpenCV:高通滤波之索贝尔、沙尔和拉普拉斯
图像处理·人工智能·opencv·计算机视觉·矩阵
去往火星8 小时前
OpenCV文字绘制支持中文显示
人工智能·opencv·计算机视觉