一、代码
python
class GANLoss(nn.Module):
"""Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
"""
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
""" Initialize the GANLoss class.
Parameters:
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool) - - label for a real image
target_fake_label (bool) - - label of a fake image
Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
"""
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
self.gan_mode = gan_mode
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
elif gan_mode == 'RidgeRegressionaLoss':
self.loss = RidgeLoss1(alpha=0.1)
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ['wgangp']:
self.loss = None
else:
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
def get_target_tensor(self, prediction, target_is_real):
"""Create label tensors with the same size as the input.
Parameters:
prediction (tensor) - - tpyically the prediction from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
A label tensor filled with ground truth label, and with the size of the input
"""
if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)
def __call__(self, prediction, target_is_real):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
prediction (tensor) - - tpyically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
the calculated loss.
"""
if self.gan_mode in ['lsgan', 'vanilla','RidgeRegressionaLoss']:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == 'wgangp':
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
return loss
二、讲解
target_tensor.expand_as(prediction)
的意思是将target_tensor
张量的尺寸扩展为与prediction
张量相同的尺寸。
在生成对抗网络(GAN)中,判别器的输出通常是一个张量,表示样本为真实样本的概率或得分。为了计算损失,需要创建与判别器输出相同尺寸的目标标签张量。target_tensor
在get_target_tensor
方法中获得,表示目标标签,可以是真实样本标签或虚假样本标签。为了与判别器的输出张量进行元素级别的比较,需要将目标标签张量的尺寸扩展为与判别器输出相同的形状。
expand_as(prediction)
方法是一个张量的方法,它返回一个尺寸与prediction
张量相同的新张量,其中新张量的元素以target_tensor
的元素进行填充或重复,以便与prediction
进行逐元素比较。
通过将目标标签张量的尺寸扩展为与判别器输出相同的尺寸,可以确保在计算损失时每个生成样本或真实样本的标签都与对应的判别器输出进行比较。