深度学习Day-30:CGAN入门丨生成手势图像丨可控制生成

🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客

🍖 原作者:[K同学啊 | 接辅导、项目定制]

要求:

  1. 结合代码进一步了解CGAN
  2. 学习如何运用生成好的生成器生成指定图像

一、 基础配置

  • 语言环境:Python3.8
  • 编译器选择:Pycharm
  • 深度学习环境:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

二、 前期准备

1. 导入第三方库

python 复制代码
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import os

os.makedirs('./images', exist_ok=True)
os.makedirs('./training_weights', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

得到如下输出:

python 复制代码
cuda

2. 导入数据

python 复制代码
batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="GAN-3-data", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)

3. 数据可视化

运行下述代码:

python 复制代码
def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

输出图像为:

4. 定义超参数

运行下述代码:

python 复制代码
latent_dim = 100
n_classes = 3
embedding_dim = 100

5. 构建模型

5.1.初始化权重

python 复制代码
def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)

    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

5.2.定义生成器

python 复制代码
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

from torchinfo import summary
summary(generator)

输出为:

python 复制代码
Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================

5.3.定义鉴别器

python 复制代码
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128)
        )

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        img, label = inputs

        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)

        concat = torch.cat((img, label_output), dim=1)

        output = self.model(concat)
        return output

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

summary(discriminator)

输出为:

python 复制代码
Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================

三、 训练模型

1. 定义训练参数

python 复制代码
adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss

2. 定义优化器

python 复制代码
learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3. 训练模型

python 复制代码
num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        D_optimizer.zero_grad()

        real_images = real_images.to(device)
        labels = labels.to(device)

        labels = labels.unsqueeze(1).long()

        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))

        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

输出为:

python 复制代码
Epoch: [1/100]: D_loss: 0.285, G_loss: 2.018
Epoch: [2/100]: D_loss: 0.331, G_loss: 2.298
Epoch: [3/100]: D_loss: 0.403, G_loss: 1.715
Epoch: [4/100]: D_loss: 0.467, G_loss: 1.416
Epoch: [5/100]: D_loss: 0.490, G_loss: 1.618
Epoch: [6/100]: D_loss: 0.490, G_loss: 1.585
Epoch: [7/100]: D_loss: 0.379, G_loss: 1.674
Epoch: [8/100]: D_loss: 0.443, G_loss: 1.889
Epoch: [9/100]: D_loss: 0.541, G_loss: 2.067
Epoch: [10/100]: D_loss: 0.565, G_loss: 1.751
Epoch: [11/100]: D_loss: 0.528, G_loss: 1.495
Epoch: [12/100]: D_loss: 0.555, G_loss: 1.461
Epoch: [13/100]: D_loss: 0.569, G_loss: 1.490
Epoch: [14/100]: D_loss: 0.531, G_loss: 1.498
Epoch: [15/100]: D_loss: 0.504, G_loss: 1.532
Epoch: [16/100]: D_loss: 0.487, G_loss: 1.612
Epoch: [17/100]: D_loss: 0.457, G_loss: 1.776
Epoch: [18/100]: D_loss: 0.462, G_loss: 1.767
Epoch: [19/100]: D_loss: 0.437, G_loss: 1.946
Epoch: [20/100]: D_loss: 0.446, G_loss: 1.848
Epoch: [21/100]: D_loss: 0.463, G_loss: 1.718
Epoch: [22/100]: D_loss: 0.473, G_loss: 1.748
Epoch: [23/100]: D_loss: 0.503, G_loss: 1.579
Epoch: [24/100]: D_loss: 0.482, G_loss: 1.410
Epoch: [25/100]: D_loss: 0.489, G_loss: 1.440
Epoch: [26/100]: D_loss: 0.494, G_loss: 1.425
Epoch: [27/100]: D_loss: 0.510, G_loss: 1.398
Epoch: [28/100]: D_loss: 0.475, G_loss: 1.410
Epoch: [29/100]: D_loss: 0.473, G_loss: 1.459
Epoch: [30/100]: D_loss: 0.473, G_loss: 1.489
Epoch: [31/100]: D_loss: 0.462, G_loss: 1.484
Epoch: [32/100]: D_loss: 0.448, G_loss: 1.520
Epoch: [33/100]: D_loss: 0.457, G_loss: 1.548
Epoch: [34/100]: D_loss: 0.418, G_loss: 1.558
Epoch: [35/100]: D_loss: 0.433, G_loss: 1.667
Epoch: [36/100]: D_loss: 0.402, G_loss: 1.665
Epoch: [37/100]: D_loss: 0.401, G_loss: 1.709
Epoch: [38/100]: D_loss: 0.425, G_loss: 1.841
Epoch: [39/100]: D_loss: 0.399, G_loss: 1.711
Epoch: [40/100]: D_loss: 0.429, G_loss: 1.873
Epoch: [41/100]: D_loss: 0.374, G_loss: 1.857
Epoch: [42/100]: D_loss: 0.382, G_loss: 1.869
Epoch: [43/100]: D_loss: 0.431, G_loss: 1.935
Epoch: [44/100]: D_loss: 0.355, G_loss: 1.871
Epoch: [45/100]: D_loss: 0.363, G_loss: 1.875
Epoch: [46/100]: D_loss: 0.485, G_loss: 2.011
Epoch: [47/100]: D_loss: 0.391, G_loss: 1.994
Epoch: [48/100]: D_loss: 0.331, G_loss: 1.924
Epoch: [49/100]: D_loss: 0.317, G_loss: 1.930
Epoch: [50/100]: D_loss: 0.353, G_loss: 2.035
Epoch: [51/100]: D_loss: 0.334, G_loss: 2.072
Epoch: [52/100]: D_loss: 0.387, G_loss: 2.092
Epoch: [53/100]: D_loss: 0.380, G_loss: 2.139
Epoch: [54/100]: D_loss: 0.302, G_loss: 2.077
Epoch: [55/100]: D_loss: 0.311, G_loss: 2.055
Epoch: [56/100]: D_loss: 0.326, G_loss: 2.169
Epoch: [57/100]: D_loss: 0.309, G_loss: 2.239
Epoch: [58/100]: D_loss: 0.323, G_loss: 2.207
Epoch: [59/100]: D_loss: 0.285, G_loss: 2.239
Epoch: [60/100]: D_loss: 0.306, G_loss: 2.304
Epoch: [61/100]: D_loss: 0.287, G_loss: 2.254
Epoch: [62/100]: D_loss: 0.295, G_loss: 2.406
Epoch: [63/100]: D_loss: 0.305, G_loss: 2.499
Epoch: [64/100]: D_loss: 0.298, G_loss: 2.462
Epoch: [65/100]: D_loss: 0.255, G_loss: 2.418
Epoch: [66/100]: D_loss: 0.480, G_loss: 2.714
Epoch: [67/100]: D_loss: 0.265, G_loss: 2.379
Epoch: [68/100]: D_loss: 0.256, G_loss: 2.453
Epoch: [69/100]: D_loss: 0.252, G_loss: 2.465
Epoch: [70/100]: D_loss: 0.240, G_loss: 2.600
Epoch: [71/100]: D_loss: 0.250, G_loss: 2.516
Epoch: [72/100]: D_loss: 0.228, G_loss: 2.534
Epoch: [73/100]: D_loss: 0.249, G_loss: 2.566
Epoch: [74/100]: D_loss: 0.385, G_loss: 2.915
Epoch: [75/100]: D_loss: 0.232, G_loss: 2.566
Epoch: [76/100]: D_loss: 0.335, G_loss: 2.776
Epoch: [77/100]: D_loss: 0.243, G_loss: 2.703
Epoch: [78/100]: D_loss: 0.232, G_loss: 2.650
Epoch: [79/100]: D_loss: 0.216, G_loss: 2.736
Epoch: [80/100]: D_loss: 0.219, G_loss: 2.725
Epoch: [81/100]: D_loss: 0.272, G_loss: 2.869
Epoch: [82/100]: D_loss: 0.218, G_loss: 2.839
Epoch: [83/100]: D_loss: 0.219, G_loss: 2.836
Epoch: [84/100]: D_loss: 0.233, G_loss: 2.948
Epoch: [85/100]: D_loss: 0.209, G_loss: 2.952
Epoch: [86/100]: D_loss: 0.251, G_loss: 3.052
Epoch: [87/100]: D_loss: 0.198, G_loss: 2.905
Epoch: [88/100]: D_loss: 0.193, G_loss: 3.054
Epoch: [89/100]: D_loss: 0.215, G_loss: 2.995
Epoch: [90/100]: D_loss: 0.193, G_loss: 3.081
Epoch: [91/100]: D_loss: 0.446, G_loss: 3.269
Epoch: [92/100]: D_loss: 0.227, G_loss: 2.871
Epoch: [93/100]: D_loss: 0.191, G_loss: 3.008
Epoch: [94/100]: D_loss: 0.200, G_loss: 3.066
Epoch: [95/100]: D_loss: 0.200, G_loss: 3.142
Epoch: [96/100]: D_loss: 0.186, G_loss: 3.113
Epoch: [97/100]: D_loss: 0.207, G_loss: 3.159
Epoch: [98/100]: D_loss: 0.219, G_loss: 3.213
Epoch: [99/100]: D_loss: 0.177, G_loss: 3.205
Epoch: [100/100]: D_loss: 0.184, G_loss: 3.258

4. 可视化

4.1.LOSS图

python 复制代码
G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

plt.figure(figsize=(8,4))
plt.title("Generator and Descriminator Loss During Training")
plt.plot(G_loss_list,label = "G")
plt.plot(D_loss_list,label = "D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出图像为:

4.2.生成指定图像

python 复制代码
from numpy.random import randn

generator.load_state_dict(torch.load("./training_weights/generator_epoch_100.pth"), strict = False)
generator.eval()

interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100


plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1 ) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

输出图像为:

四、理论基础

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

  1. **有监督学习:**CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
  2. **联合隐层表征:**在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
  3. **可控性:**CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
  4. **使用卷积结构:**CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

相比于传统的GAN,CGAN的主要异同点包括条件信息的输入、训练稳定性、损失函数、网络结构等,其具体内容为:

  1. **条件信息的输入:**CGAN引入了条件变量,使得生成器和判别器都能接收到更多的信息来指导训练过程,这是传统GAN所不具备的。
  2. **训练稳定性:**传统GAN在训练过程中容易产生模式崩溃(mode collapse)的问题,而CGAN由于有了额外的条件信息,可以提高训练的稳定性和生成数据的多样性。
  3. **损失函数:**虽然CGAN的损失函数仍然保留了传统GAN的对抗损失函数的形式,但额外添加的条件信息使得损失计算更加复杂且有针对性。
  4. **网络结构:**在实现上,CGAN可以采用更深更复杂的网络结构,如卷积神经网络,这有助于处理更为复杂的数据类型,比如高分辨率图像。

CGAN网络结构如下图所示: 由上图的网络结构可知,条件信息y作为额外的输入被引入对抗网络中,与生成器中的噪声z合并作为隐含层表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的诸多方面研究中被证明是非常有效的,也为后续的相关工作提供了积极的指导作用

综上所述,CGAN的核心在于它通过引入条件信息来增强模型的生成能力和可控性,与传统GAN相比,它提供了更明确的训练目标和更好的生成效果。

相关推荐
草莓屁屁我不吃8 分钟前
200美元/月的ChatGPT Pro版上线?OpenAI草莓模型曝两周内发布,但模型表现要打个问号?
人工智能·chatgpt
AI-入门12 分钟前
AI 产品经理:2024 年职场新航标 ——AI 产品经理的未来与契机
人工智能·chatgpt·prompt·产品经理·agi
北极冰雨24 分钟前
Agent、RAG、LangChain的概念及作用
人工智能
极客小张25 分钟前
基于OpenCV与MQTT的停车场车牌识别系统:结合SQLite和Flask的设计流程
arm开发·人工智能·单片机·opencv·物联网·flask·毕业设计
z千鑫44 分钟前
【数据分析】利用Python+AI+工作流实现自动化数据分析-全流程讲解
人工智能·python·ai·数据分析·自动化·ai编程·ai工作流
雪碧有白泡泡1 小时前
Stable Diffusion AI算法,实现一键式后期处理与图像修复魔法
人工智能·算法·stable diffusion
高登先生1 小时前
科技之光,照亮未来之路“2024南京国际人工智能展会”
大数据·人工智能·科技·数学建模·能源
.5481 小时前
AI基础 L19 Quantifying Uncertainty and Reasoning with Probabilities I 量化不确定性和概率推理
人工智能
菌菌的快乐生活1 小时前
树莓派安装 OpenCV 教程
人工智能·opencv·计算机视觉
SEVEN-YEARS1 小时前
深度学习入门:探索神经网络、感知器与损失函数
人工智能·深度学习·神经网络