Deep Learning Day 30: Getting Started with CGAN | Generating Gesture Images | Controllable Generation

🍨 This post is a learning-record blog from the [🔗365天深度学习训练营] (365-day deep learning training camp)

🍖 Original author: [K同学啊 | tutoring and custom projects available]

Objectives:

  1. Work through the code to deepen your understanding of CGAN
  2. Learn how to use a trained generator to produce images of a specified class

I. Basic Setup

  • Language environment: Python 3.8
  • IDE: PyCharm
  • Deep learning environment:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

II. Preliminaries

1. Import third-party libraries

python
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import os

os.makedirs('./images', exist_ok=True)
os.makedirs('./training_weights', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

This produces the following output:

text
cuda

2. Load the data

python
batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="GAN-3-data", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)
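
Note that `datasets.ImageFolder` derives each image's class label from its parent folder name, so `GAN-3-data` is assumed to contain one sub-folder per gesture class, along these lines (the folder names here are hypothetical):

text
GAN-3-data/
├── class_0/   img001.png, img002.png, ...
├── class_1/   ...
└── class_2/   ...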

3. Data visualization

Run the following code:

python
def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

The output image is:
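
One caveat with the visualization above: the transform normalized pixels to [-1, 1], so `plt.imshow` clips the float values into [0, 1] and the grid appears washed out. A minimal variant (an optional tweak, not in the original post) maps the tensor back to [0, 1] first:

python
def show_images_denorm(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        grid = make_grid(images.detach(), nrow=16).permute(1, 2, 0)
        # undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5 maps [-1, 1] to [0, 1]
        ax.imshow((grid * 0.5 + 0.5).clamp(0, 1))
        break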

4. Define hyperparameters

Run the following code:

python
latent_dim = 100      # length of the noise vector z
n_classes = 3         # number of gesture classes in the dataset
embedding_dim = 100   # size of the label embedding

5. Build the model

5.1 Weight initialization

python
def weights_init(m):
    classname = m.__class__.__name__

    # DCGAN-style initialization: conv weights drawn from N(0, 0.02)
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)

    # BatchNorm: scale from N(1, 0.02), bias set to zero
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

5.2 Define the generator

python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        # embed the class label and reshape it into a single 4x4 feature map
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        # project the noise vector onto 512 feature maps of size 4x4
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        # concatenate along channels: 512 + 1 = 513 input channels for the model
        concat = torch.cat((latent_output, label_output), dim=1)
        # five stride-2 transposed convolutions upsample 4x4 -> 128x128
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

from torchinfo import summary
summary(generator)

Output:

text
Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================
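
As a quick sanity check (not part of the original post): the five stride-2 transposed convolutions upsample the 4×4 feature map to 128×128, so a dummy batch should come out as 128×128 RGB images:

python
# hypothetical smoke test: 4 noise vectors paired with random class labels
z = torch.randn(4, latent_dim, device=device)
y = torch.randint(0, n_classes, (4, 1), device=device)
with torch.no_grad():
    fake = generator((z, y))
print(fake.shape)  # expected: torch.Size([4, 3, 128, 128])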

5.3 Define the discriminator

python

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128)
        )

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        img, label = inputs

        # embed the label and reshape it into a 3-channel 128x128 "label image"
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)

        # concatenate along channels: 3 (image) + 3 (label) = 6 input channels
        concat = torch.cat((img, label_output), dim=1)

        output = self.model(concat)
        return output

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

summary(discriminator)

Output:

text
Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================
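
A matching check for the discriminator (again an addition, not in the original post): tracing the spatial sizes through the four convolutions gives 128 → 64 → 22 → 8 → 3, so the flattened features have 512 × 3 × 3 = 4608 elements, matching the final linear layer. Each (image, label) pair should map to a single probability:

python
# hypothetical smoke test: random 128x128 RGB "images" with random labels
x = torch.randn(4, 3, 128, 128, device=device)
y = torch.randint(0, n_classes, (4, 1), device=device)
with torch.no_grad():
    prob = discriminator((x, y))
print(prob.shape)  # expected: torch.Size([4, 1])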

III. Training the Model

1. Define the loss functions

python
adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
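
Both helpers are thin wrappers around `nn.BCELoss`, the binary cross-entropy between the discriminator's output ŷ ∈ (0, 1) and a target y ∈ {0, 1}:

latex
\ell(\hat{y}, y) = -\bigl[\, y \log \hat{y} + (1 - y)\log(1 - \hat{y}) \,\bigr]

With targets of 1 for real and 0 for fake images, the discriminator below minimizes the average of -log D(x, y) and -log(1 - D(G(z, y))), while the generator minimizes -log D(G(z, y)), the standard non-saturating generator loss.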

2. Define the optimizers

python
learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3. Train the model

python
num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        D_optimizer.zero_grad()

        real_images = real_images.to(device)
        labels = labels.to(device)

        labels = labels.unsqueeze(1).long()  # shape (B, 1), dtype long, as nn.Embedding expects

        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        generated_image = generator((noise_vector, labels))

        # detach() blocks gradients so the generator is untouched by the discriminator update
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss.item())  # .item() releases the graph instead of retaining it

        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d.png' % epoch, nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

Output:

text
Epoch: [1/100]: D_loss: 0.285, G_loss: 2.018
Epoch: [2/100]: D_loss: 0.331, G_loss: 2.298
Epoch: [3/100]: D_loss: 0.403, G_loss: 1.715
Epoch: [4/100]: D_loss: 0.467, G_loss: 1.416
Epoch: [5/100]: D_loss: 0.490, G_loss: 1.618
Epoch: [6/100]: D_loss: 0.490, G_loss: 1.585
Epoch: [7/100]: D_loss: 0.379, G_loss: 1.674
Epoch: [8/100]: D_loss: 0.443, G_loss: 1.889
Epoch: [9/100]: D_loss: 0.541, G_loss: 2.067
Epoch: [10/100]: D_loss: 0.565, G_loss: 1.751
Epoch: [11/100]: D_loss: 0.528, G_loss: 1.495
Epoch: [12/100]: D_loss: 0.555, G_loss: 1.461
Epoch: [13/100]: D_loss: 0.569, G_loss: 1.490
Epoch: [14/100]: D_loss: 0.531, G_loss: 1.498
Epoch: [15/100]: D_loss: 0.504, G_loss: 1.532
Epoch: [16/100]: D_loss: 0.487, G_loss: 1.612
Epoch: [17/100]: D_loss: 0.457, G_loss: 1.776
Epoch: [18/100]: D_loss: 0.462, G_loss: 1.767
Epoch: [19/100]: D_loss: 0.437, G_loss: 1.946
Epoch: [20/100]: D_loss: 0.446, G_loss: 1.848
Epoch: [21/100]: D_loss: 0.463, G_loss: 1.718
Epoch: [22/100]: D_loss: 0.473, G_loss: 1.748
Epoch: [23/100]: D_loss: 0.503, G_loss: 1.579
Epoch: [24/100]: D_loss: 0.482, G_loss: 1.410
Epoch: [25/100]: D_loss: 0.489, G_loss: 1.440
Epoch: [26/100]: D_loss: 0.494, G_loss: 1.425
Epoch: [27/100]: D_loss: 0.510, G_loss: 1.398
Epoch: [28/100]: D_loss: 0.475, G_loss: 1.410
Epoch: [29/100]: D_loss: 0.473, G_loss: 1.459
Epoch: [30/100]: D_loss: 0.473, G_loss: 1.489
Epoch: [31/100]: D_loss: 0.462, G_loss: 1.484
Epoch: [32/100]: D_loss: 0.448, G_loss: 1.520
Epoch: [33/100]: D_loss: 0.457, G_loss: 1.548
Epoch: [34/100]: D_loss: 0.418, G_loss: 1.558
Epoch: [35/100]: D_loss: 0.433, G_loss: 1.667
Epoch: [36/100]: D_loss: 0.402, G_loss: 1.665
Epoch: [37/100]: D_loss: 0.401, G_loss: 1.709
Epoch: [38/100]: D_loss: 0.425, G_loss: 1.841
Epoch: [39/100]: D_loss: 0.399, G_loss: 1.711
Epoch: [40/100]: D_loss: 0.429, G_loss: 1.873
Epoch: [41/100]: D_loss: 0.374, G_loss: 1.857
Epoch: [42/100]: D_loss: 0.382, G_loss: 1.869
Epoch: [43/100]: D_loss: 0.431, G_loss: 1.935
Epoch: [44/100]: D_loss: 0.355, G_loss: 1.871
Epoch: [45/100]: D_loss: 0.363, G_loss: 1.875
Epoch: [46/100]: D_loss: 0.485, G_loss: 2.011
Epoch: [47/100]: D_loss: 0.391, G_loss: 1.994
Epoch: [48/100]: D_loss: 0.331, G_loss: 1.924
Epoch: [49/100]: D_loss: 0.317, G_loss: 1.930
Epoch: [50/100]: D_loss: 0.353, G_loss: 2.035
Epoch: [51/100]: D_loss: 0.334, G_loss: 2.072
Epoch: [52/100]: D_loss: 0.387, G_loss: 2.092
Epoch: [53/100]: D_loss: 0.380, G_loss: 2.139
Epoch: [54/100]: D_loss: 0.302, G_loss: 2.077
Epoch: [55/100]: D_loss: 0.311, G_loss: 2.055
Epoch: [56/100]: D_loss: 0.326, G_loss: 2.169
Epoch: [57/100]: D_loss: 0.309, G_loss: 2.239
Epoch: [58/100]: D_loss: 0.323, G_loss: 2.207
Epoch: [59/100]: D_loss: 0.285, G_loss: 2.239
Epoch: [60/100]: D_loss: 0.306, G_loss: 2.304
Epoch: [61/100]: D_loss: 0.287, G_loss: 2.254
Epoch: [62/100]: D_loss: 0.295, G_loss: 2.406
Epoch: [63/100]: D_loss: 0.305, G_loss: 2.499
Epoch: [64/100]: D_loss: 0.298, G_loss: 2.462
Epoch: [65/100]: D_loss: 0.255, G_loss: 2.418
Epoch: [66/100]: D_loss: 0.480, G_loss: 2.714
Epoch: [67/100]: D_loss: 0.265, G_loss: 2.379
Epoch: [68/100]: D_loss: 0.256, G_loss: 2.453
Epoch: [69/100]: D_loss: 0.252, G_loss: 2.465
Epoch: [70/100]: D_loss: 0.240, G_loss: 2.600
Epoch: [71/100]: D_loss: 0.250, G_loss: 2.516
Epoch: [72/100]: D_loss: 0.228, G_loss: 2.534
Epoch: [73/100]: D_loss: 0.249, G_loss: 2.566
Epoch: [74/100]: D_loss: 0.385, G_loss: 2.915
Epoch: [75/100]: D_loss: 0.232, G_loss: 2.566
Epoch: [76/100]: D_loss: 0.335, G_loss: 2.776
Epoch: [77/100]: D_loss: 0.243, G_loss: 2.703
Epoch: [78/100]: D_loss: 0.232, G_loss: 2.650
Epoch: [79/100]: D_loss: 0.216, G_loss: 2.736
Epoch: [80/100]: D_loss: 0.219, G_loss: 2.725
Epoch: [81/100]: D_loss: 0.272, G_loss: 2.869
Epoch: [82/100]: D_loss: 0.218, G_loss: 2.839
Epoch: [83/100]: D_loss: 0.219, G_loss: 2.836
Epoch: [84/100]: D_loss: 0.233, G_loss: 2.948
Epoch: [85/100]: D_loss: 0.209, G_loss: 2.952
Epoch: [86/100]: D_loss: 0.251, G_loss: 3.052
Epoch: [87/100]: D_loss: 0.198, G_loss: 2.905
Epoch: [88/100]: D_loss: 0.193, G_loss: 3.054
Epoch: [89/100]: D_loss: 0.215, G_loss: 2.995
Epoch: [90/100]: D_loss: 0.193, G_loss: 3.081
Epoch: [91/100]: D_loss: 0.446, G_loss: 3.269
Epoch: [92/100]: D_loss: 0.227, G_loss: 2.871
Epoch: [93/100]: D_loss: 0.191, G_loss: 3.008
Epoch: [94/100]: D_loss: 0.200, G_loss: 3.066
Epoch: [95/100]: D_loss: 0.200, G_loss: 3.142
Epoch: [96/100]: D_loss: 0.186, G_loss: 3.113
Epoch: [97/100]: D_loss: 0.207, G_loss: 3.159
Epoch: [98/100]: D_loss: 0.219, G_loss: 3.213
Epoch: [99/100]: D_loss: 0.177, G_loss: 3.205
Epoch: [100/100]: D_loss: 0.184, G_loss: 3.258

4. Visualization

4.1 Loss curves

python
G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

plt.figure(figsize=(8,4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_loss_list,label = "G")
plt.plot(D_loss_list,label = "D")
plt.xlabel("epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

The output image is:

4.2 Generating images of a specified class

python
generator.load_state_dict(torch.load("./training_weights/generator_epoch_100.pth"), strict=False)
generator.eval()

# a single noise vector of shape (1, latent_dim)
interpolated = torch.randn(1, latent_dim, device=device)

# condition on class 0
label = 0
labels = torch.full((1, 1), label, dtype=torch.long, device=device)

predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1) * 127.5  # map [-1, 1] back to [0, 255]
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

The output image is:
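
Since the label is just another input to the generator, one can sweep it while keeping the noise fixed to compare what each class produces. A sketch along those lines (an extension of the post, reusing the `generator`, `latent_dim`, and `n_classes` defined above):

python
# hypothetical: one image per class from the same noise vector
noise = torch.randn(1, latent_dim, device=device)
fig, axes = plt.subplots(1, n_classes, figsize=(9, 3))
for c in range(n_classes):
    lbl = torch.full((1, 1), c, dtype=torch.long, device=device)
    with torch.no_grad():
        img = generator((noise, lbl))[0].permute(1, 2, 0).cpu()
    axes[c].imshow(((img + 1) / 2).clamp(0, 1))  # map [-1, 1] to [0, 1]
    axes[c].set_title('label = %d' % c)
    axes[c].axis('off')
plt.show()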

IV. Theoretical Background

The principle of CGAN (Conditional Generative Adversarial Network) is to build on the original GAN by providing additional conditioning information to both the generator and the discriminator.

CGAN injects conditioning information (such as class labels or other auxiliary data) into the inputs of both the generator and the discriminator: the generator learns to produce data of a specific type according to the condition, while the discriminator judges whether data is real or generated and whether it matches the condition. This gives the generator a clear direction during generation, improving both the quality and the relevance of the generated data.
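
Formally, following the original CGAN paper (Mirza & Osindero, 2014), the GAN minimax objective becomes conditional, with both networks receiving the condition y:

latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\bigl[\log D(x \mid y)\bigr] + \mathbb{E}_{z \sim p_z(z)}\bigl[\log\bigl(1 - D(G(z \mid y))\bigr)\bigr]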

The characteristics of CGAN include supervised learning, a joint hidden representation, controllability, and the use of convolutional architectures. In detail:

  1. **Supervised learning:** By using additional information, CGAN turns the originally unsupervised GAN into a supervised learning setup, which gives training a clearer objective and makes the generated results better match expectations.
  2. **Joint hidden representation:** In the generator, the noise input and the conditioning information together form a joint hidden representation, which helps produce data that is both more diverse and endowed with specific attributes.
  3. **Controllability:** A key feature of CGAN is the improved controllability of the generation process: one can steer the model toward generating data of a specific type by adjusting the conditioning information.
  4. **Convolutional architecture:** CGAN can use a convolutional neural network as its backbone, which is especially effective in image-related tasks because it captures local features and improves the model's handling of detail.

Compared with a vanilla GAN, CGAN differs mainly in its conditional input, training stability, loss function, and network architecture. In detail:

  1. **Conditional input:** CGAN introduces a conditioning variable, so both the generator and the discriminator receive extra information to guide training, something a vanilla GAN lacks.
  2. **Training stability:** A vanilla GAN is prone to mode collapse during training, whereas the extra conditioning information in CGAN can improve training stability and the diversity of the generated data.
  3. **Loss function:** CGAN keeps the adversarial loss of the vanilla GAN, but the added conditioning information makes the loss computation more involved and more targeted.
  4. **Network architecture:** In practice, CGAN can adopt deeper and more complex architectures, such as convolutional neural networks, which helps it handle more complex data types like high-resolution images.

The CGAN network structure is shown in the figure below. As the diagram shows, the condition y is introduced into the adversarial network as an extra input: in the generator it is merged with the noise z to form the hidden-layer representation, while in the discriminator D it is merged with the real data x as the input to the discriminant function. This modification has since proven highly effective across many lines of research and has provided valuable guidance for subsequent work.

In summary, the core of CGAN is that introducing conditioning information enhances both the generative power and the controllability of the model; compared with a vanilla GAN, it provides a clearer training objective and better generation results.
