系列文章目录
一 Conditional Generative Adversarial Nets
二 cGANs with Projection Discriminator
三 Conditional Image Synthesis with Auxiliary Classifier GANs
四 InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
- 系列文章目录
- [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](#InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets)
-
- [1. 生成对抗网络](#1. 生成对抗网络)
- 2.用于引导潜在代码的互信息
- 鉴别器与Q同时训练
InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
InfoGAN是生成对抗网络的信息论扩展,能够以完全无监督的方式学习解开的表示。InfoGAN 是一个生成对抗网络,它最大化了一小部分潜在变量和观察之间的互信息。
InfoGAN 在 GAN 的基础上只增加了可以忽略不计的计算成本,并且易于训练。使用互信息诱导表示的核心思想可以应用于 VAE等其他方法。实验结果表示:互信息成本增强的生成模型可能是学习解缠结表示的有效方法
1. 生成对抗网络
生成对抗网络(GAN)是一种使用最大最小化博弈训练深度生成模型的框架。目标是学习与真实数据分布 Pdata(x) 匹配的生成器分布 PG(x)。 GAN 没有尝试显式地为数据分布中的每个 x 分配概率,而是学习生成器网络 G,该生成器网络 G 通过将噪声变量 z ∼ Pnoise(z) 转换为样本 G(z),从生成器分布 PG 生成样本。该生成器通过与对抗性判别器网络 D 进行训练,该网络旨在区分真实数据分布 Pdata 和生成器分布 PG 中的样本。因此,对于给定的生成器,最佳判别器是 D(x) = Pdata(x)/(Pdata(x) + PG(x))。更正式地说,极小极大博弈由以下表达式给出:
2.用于引导潜在代码的互信息
GAN 公式使用简单分解的连续输入噪声向量 z,同时对生成器使用该噪声的方式没有限制。因此,噪声可能会被生成器以高度纠缠的方式使用,导致 z 的各个维度与数据的语义特征不对应 。
然而,许多领域自然地分解为一组语义上有意义的变化因素。例如,当从 MNIST 数据集生成图像时,**如果模型自动选择分配一个离散随机变量来表示数字 (0-9) 的数字身份,并选择两个数字角度与粗细作为连续变量来表示。**在这种情况下,这些属性既是独立的又是显着的,如果能够在没有任何监督的情况下恢复这些概念,通过简单地指定 MNIST 数字是由独立的 1-of-10 变量和两个独立的连续变量生成的,那将会非常有用变量。
论文建议将输入噪声向量分解为两部分,而不是使用单个非结构化噪声向量:(i)z:它被视为不可压缩噪声源; (ii) c,潜在代码: 针对数据分布的显着结构化语义特征。
- 实现逻辑:
为生成器网络提供不可压缩噪声 z 和潜在代码 c,因此生成器的形式变为 G(z, c)。然而,在标准 GAN 中,生成器可以通过找到满足 PG(x|c) = PG(x) 的解决方案来自由忽略额外的潜在代码 c。为了解决平凡码的问题,我们提出了一种信息论正则化:潜在码 c 和生成器分布 G(z, c) 之间应该存在高互信息。因此 I(c; G(z, c)) 应该很高。
在信息论中,X 和 Y 之间的互信息,I(X; Y ),衡量从随机变量 Y 的知识中学到的关于另一个随机变量 X 的"信息量"。互信息可以表示为两个随机变量的差值熵项:
这个定义有一个直观的解释:I(X; Y ) 是观察到 Y 时 X 的不确定性的减少。如果 X 和 Y 是独立的,则 I(X; Y ) = 0,因为了解一个变量并不能揭示另一个变量;相反,如果 X 和 Y 通过确定性可逆函数相关,则可以获得最大互信息。这种解释使得很容易制定成本:给定任何 x ∼ PG(x),我们希望 PG(c|x) 具有较小的熵。换句话说,潜在代码c中的信息不应在生成过程中丢失。类似的互信息启发目标之前已经在聚类背景下被考虑过。因此,我们建议解决以下信息正则化极小极大博弈:
实际上,互信息项 I(c; G(z, c)) 很难直接最大化,因为它需要访问后验 P (c|x)。幸运的是,我们可以通过定义辅助分布 Q(c|x) 来近似 P(c|x) 来获得它的下界:
在实践中,我们将辅助分布 Q 参数化为神经网络。在大多数实验中,Q 和 D 共享所有卷积层,并且有一个最终的全连接层来输出条件分布 Q(c|x) 的参数,这意味着 InfoGAN 只为 GAN 添加了可以忽略不计的计算成本。我们还观察到,LI (G, Q) 总是比正常的 GAN 目标收敛得更快,因此 InfoGAN 本质上是免费随 GAN 提供。
代码实现
python
class QInfoGAN2(nn.Module):
def __init__(self, x_dim, c_dim, dim=96, norm='batch_norm', weight_norm='none'):
super(QInfoGAN2, self).__init__()
norm_fn = _get_norm_fn_2d(norm)
weight_norm_fn = _get_weight_norm_fn(weight_norm)
def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
return nn.Sequential(
weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
norm_fn(out_dim),
nn.LeakyReLU(0.2)
)
self.ls = nn.Sequential( # (N, x_dim, 32, 32)
conv_norm_lrelu(x_dim, dim),
conv_norm_lrelu(dim, dim),
conv_norm_lrelu(dim, dim, stride=2), # (N, dim , 16, 16)
conv_norm_lrelu(dim, dim * 2),
conv_norm_lrelu(dim * 2, dim * 2),
conv_norm_lrelu(dim * 2, dim * 2, stride=2), # (N, dim*2, 8, 8)
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0), # (N, dim*2, 6, 6)
nn.AvgPool2d(kernel_size=6), # (N, dim*2, 1, 1)
torchlib.Reshape(-1, dim * 2), # (N, dim*2)
nn.Linear(dim * 2, c_dim) # (N, c_dim)
)
def forward(self, x):
# x: (N, x_dim, 32, 32)
logit = self.ls(x)
return logit
class DiscriminatorInfoGAN2(nn.Module):
def __init__(self, x_dim, dim=96, norm='none', weight_norm='spectral_norm'):
super(DiscriminatorInfoGAN2, self).__init__()
norm_fn = _get_norm_fn_2d(norm)
weight_norm_fn = _get_weight_norm_fn(weight_norm)
def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
return nn.Sequential(
weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
norm_fn(out_dim),
nn.LeakyReLU(0.2)
)
self.ls = nn.Sequential( # (N, x_dim, 32, 32)
conv_norm_lrelu(x_dim, dim),
conv_norm_lrelu(dim, dim),
conv_norm_lrelu(dim, dim, stride=2), # (N, dim , 16, 16)
conv_norm_lrelu(dim, dim * 2),
conv_norm_lrelu(dim * 2, dim * 2),
conv_norm_lrelu(dim * 2, dim * 2, stride=2), # (N, dim*2, 8, 8)
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0), # (N, dim*2, 6, 6)
nn.AvgPool2d(kernel_size=6), # (N, dim*2, 1, 1)
torchlib.Reshape(-1, dim * 2), # (N, dim*2)
weight_norm_fn(nn.Linear(dim * 2, 1)) # (N, 1)
)
def forward(self, x):
# x: (N, x_dim, 32, 32)
logit = self.ls(x)
return logit
class GeneratorCGAN(nn.Module):
def __init__(self, z_dim, c_dim, dim=128):
super(GeneratorCGAN, self).__init__()
def dconv_bn_relu(in_dim, out_dim, kernel_size=4, stride=2, padding=1, output_padding=0):
return nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, output_padding),
nn.BatchNorm2d(out_dim),
nn.ReLU()
)
self.ls = nn.Sequential(
dconv_bn_relu(z_dim + c_dim, dim * 4, 4, 1, 0, 0), # (N, dim * 4, 4, 4)
dconv_bn_relu(dim * 4, dim * 2), # (N, dim * 2, 8, 8)
dconv_bn_relu(dim * 2, dim), # (N, dim, 16, 16)
nn.ConvTranspose2d(dim, 3, 4, 2, padding=1), nn.Tanh() # (N, 3, 32, 32)
)
def forward(self, z, c):
# z: (N, z_dim), c: (N, c_dim) ->[64, 110]
x = torch.cat([z, c], 1)
# [64, 110] -> [64, 3, 32, 32]
x = self.ls(x.view(x.size(0), x.size(1), 1, 1))
# print(x.shape)
return x
# model
D = DiscriminatorInfoGAN2(x_dim=3).to(device)
Q = model.QInfoGAN2(x_dim=3, c_dim=c_dim).to(device)
G = model.GeneratorInfoGAN2(z_dim=z_dim, c_dim=c_dim).to(device)
模式1
鉴别器与Q同时训练
python
# train D and Q
x = x.to(device)
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
z = torch.randn(batch_size, z_dim).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
x_f = G(z, c).detach()
x_gan_logit = D(x)
x_f_gan_logit = D(x_f)
x_f_c_logit = Q(x_f)
d_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)
gp = model.gradient_penalty(D, x, x_f, mode=gp_mode)
d_loss = d_x_gan_loss + d_x_f_gan_loss + gp * gp_coef
d_q_loss = d_x_f_c_logit
python
## 训练生成器
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
z = torch.randn(batch_size, z_dim).to(device)
x_f = G(z, c)
x_f_gan_logit = D(x_f)
x_f_c_logit = Q(x_f)
g_gan_loss = g_loss_fn(x_f_gan_logit)
d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)
模式二
生成器与Q同时进行训练
python
# train D
x = x.to(device)
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
z = torch.randn(batch_size, z_dim).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
# 输入随机图像与随机条件
x_f = G(z, c).detach()
# 输入图像,鉴别器判断
x_gan_logit = D(x)
# 输入伪图像鉴别器进行判断
x_f_gan_logit = D(x_f)
d_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
python
# train G and Q
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
z = torch.randn(batch_size, z_dim).to(device)
x_f = G(z, c)
x_f_gan_logit = D(x_f)
# Q与类别计算损失
x_f_c_logit = Q(x_f)
g_gan_loss = g_loss_fn(x_f_gan_logit)
d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)
g_loss = g_gan_loss + d_x_f_c_logit