深度学习Day-30:CGAN入门丨生成手势图像丨可控制生成

🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客

🍖 原作者:[K同学啊 | 接辅导、项目定制]

要求:

  1. 结合代码进一步了解CGAN
  2. 学习如何运用生成好的生成器生成指定图像

一、 基础配置

  • 语言环境:Python3.8
  • 编译器选择:Pycharm
  • 深度学习环境:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

二、 前期准备

1. 导入第三方库

python 复制代码
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import os

os.makedirs('./images', exist_ok=True)
os.makedirs('./training_weights', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

得到如下输出:

python 复制代码
cuda

2. 导入数据

python 复制代码
batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="GAN-3-data", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)

3. 数据可视化

运行下述代码:

python 复制代码
def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

输出图像为:

4. 定义超参数

运行下述代码:

python 复制代码
latent_dim = 100
n_classes = 3
embedding_dim = 100

5. 构建模型

5.1.初始化权重

python 复制代码
def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)

    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

5.2.定义生成器

python 复制代码
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

from torchinfo import summary
summary(generator)

输出为:

python 复制代码
Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================

5.3.定义鉴别器

python 复制代码
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128)
        )

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        img, label = inputs

        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)

        concat = torch.cat((img, label_output), dim=1)

        output = self.model(concat)
        return output

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

summary(discriminator)

输出为:

python 复制代码
Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================

三、 训练模型

1. 定义训练参数

python 复制代码
adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss

2. 定义优化器

python 复制代码
learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3. 训练模型

python 复制代码
num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        D_optimizer.zero_grad()

        real_images = real_images.to(device)
        labels = labels.to(device)

        labels = labels.unsqueeze(1).long()

        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))

        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

输出为:

python 复制代码
Epoch: [1/100]: D_loss: 0.285, G_loss: 2.018
Epoch: [2/100]: D_loss: 0.331, G_loss: 2.298
Epoch: [3/100]: D_loss: 0.403, G_loss: 1.715
Epoch: [4/100]: D_loss: 0.467, G_loss: 1.416
Epoch: [5/100]: D_loss: 0.490, G_loss: 1.618
Epoch: [6/100]: D_loss: 0.490, G_loss: 1.585
Epoch: [7/100]: D_loss: 0.379, G_loss: 1.674
Epoch: [8/100]: D_loss: 0.443, G_loss: 1.889
Epoch: [9/100]: D_loss: 0.541, G_loss: 2.067
Epoch: [10/100]: D_loss: 0.565, G_loss: 1.751
Epoch: [11/100]: D_loss: 0.528, G_loss: 1.495
Epoch: [12/100]: D_loss: 0.555, G_loss: 1.461
Epoch: [13/100]: D_loss: 0.569, G_loss: 1.490
Epoch: [14/100]: D_loss: 0.531, G_loss: 1.498
Epoch: [15/100]: D_loss: 0.504, G_loss: 1.532
Epoch: [16/100]: D_loss: 0.487, G_loss: 1.612
Epoch: [17/100]: D_loss: 0.457, G_loss: 1.776
Epoch: [18/100]: D_loss: 0.462, G_loss: 1.767
Epoch: [19/100]: D_loss: 0.437, G_loss: 1.946
Epoch: [20/100]: D_loss: 0.446, G_loss: 1.848
Epoch: [21/100]: D_loss: 0.463, G_loss: 1.718
Epoch: [22/100]: D_loss: 0.473, G_loss: 1.748
Epoch: [23/100]: D_loss: 0.503, G_loss: 1.579
Epoch: [24/100]: D_loss: 0.482, G_loss: 1.410
Epoch: [25/100]: D_loss: 0.489, G_loss: 1.440
Epoch: [26/100]: D_loss: 0.494, G_loss: 1.425
Epoch: [27/100]: D_loss: 0.510, G_loss: 1.398
Epoch: [28/100]: D_loss: 0.475, G_loss: 1.410
Epoch: [29/100]: D_loss: 0.473, G_loss: 1.459
Epoch: [30/100]: D_loss: 0.473, G_loss: 1.489
Epoch: [31/100]: D_loss: 0.462, G_loss: 1.484
Epoch: [32/100]: D_loss: 0.448, G_loss: 1.520
Epoch: [33/100]: D_loss: 0.457, G_loss: 1.548
Epoch: [34/100]: D_loss: 0.418, G_loss: 1.558
Epoch: [35/100]: D_loss: 0.433, G_loss: 1.667
Epoch: [36/100]: D_loss: 0.402, G_loss: 1.665
Epoch: [37/100]: D_loss: 0.401, G_loss: 1.709
Epoch: [38/100]: D_loss: 0.425, G_loss: 1.841
Epoch: [39/100]: D_loss: 0.399, G_loss: 1.711
Epoch: [40/100]: D_loss: 0.429, G_loss: 1.873
Epoch: [41/100]: D_loss: 0.374, G_loss: 1.857
Epoch: [42/100]: D_loss: 0.382, G_loss: 1.869
Epoch: [43/100]: D_loss: 0.431, G_loss: 1.935
Epoch: [44/100]: D_loss: 0.355, G_loss: 1.871
Epoch: [45/100]: D_loss: 0.363, G_loss: 1.875
Epoch: [46/100]: D_loss: 0.485, G_loss: 2.011
Epoch: [47/100]: D_loss: 0.391, G_loss: 1.994
Epoch: [48/100]: D_loss: 0.331, G_loss: 1.924
Epoch: [49/100]: D_loss: 0.317, G_loss: 1.930
Epoch: [50/100]: D_loss: 0.353, G_loss: 2.035
Epoch: [51/100]: D_loss: 0.334, G_loss: 2.072
Epoch: [52/100]: D_loss: 0.387, G_loss: 2.092
Epoch: [53/100]: D_loss: 0.380, G_loss: 2.139
Epoch: [54/100]: D_loss: 0.302, G_loss: 2.077
Epoch: [55/100]: D_loss: 0.311, G_loss: 2.055
Epoch: [56/100]: D_loss: 0.326, G_loss: 2.169
Epoch: [57/100]: D_loss: 0.309, G_loss: 2.239
Epoch: [58/100]: D_loss: 0.323, G_loss: 2.207
Epoch: [59/100]: D_loss: 0.285, G_loss: 2.239
Epoch: [60/100]: D_loss: 0.306, G_loss: 2.304
Epoch: [61/100]: D_loss: 0.287, G_loss: 2.254
Epoch: [62/100]: D_loss: 0.295, G_loss: 2.406
Epoch: [63/100]: D_loss: 0.305, G_loss: 2.499
Epoch: [64/100]: D_loss: 0.298, G_loss: 2.462
Epoch: [65/100]: D_loss: 0.255, G_loss: 2.418
Epoch: [66/100]: D_loss: 0.480, G_loss: 2.714
Epoch: [67/100]: D_loss: 0.265, G_loss: 2.379
Epoch: [68/100]: D_loss: 0.256, G_loss: 2.453
Epoch: [69/100]: D_loss: 0.252, G_loss: 2.465
Epoch: [70/100]: D_loss: 0.240, G_loss: 2.600
Epoch: [71/100]: D_loss: 0.250, G_loss: 2.516
Epoch: [72/100]: D_loss: 0.228, G_loss: 2.534
Epoch: [73/100]: D_loss: 0.249, G_loss: 2.566
Epoch: [74/100]: D_loss: 0.385, G_loss: 2.915
Epoch: [75/100]: D_loss: 0.232, G_loss: 2.566
Epoch: [76/100]: D_loss: 0.335, G_loss: 2.776
Epoch: [77/100]: D_loss: 0.243, G_loss: 2.703
Epoch: [78/100]: D_loss: 0.232, G_loss: 2.650
Epoch: [79/100]: D_loss: 0.216, G_loss: 2.736
Epoch: [80/100]: D_loss: 0.219, G_loss: 2.725
Epoch: [81/100]: D_loss: 0.272, G_loss: 2.869
Epoch: [82/100]: D_loss: 0.218, G_loss: 2.839
Epoch: [83/100]: D_loss: 0.219, G_loss: 2.836
Epoch: [84/100]: D_loss: 0.233, G_loss: 2.948
Epoch: [85/100]: D_loss: 0.209, G_loss: 2.952
Epoch: [86/100]: D_loss: 0.251, G_loss: 3.052
Epoch: [87/100]: D_loss: 0.198, G_loss: 2.905
Epoch: [88/100]: D_loss: 0.193, G_loss: 3.054
Epoch: [89/100]: D_loss: 0.215, G_loss: 2.995
Epoch: [90/100]: D_loss: 0.193, G_loss: 3.081
Epoch: [91/100]: D_loss: 0.446, G_loss: 3.269
Epoch: [92/100]: D_loss: 0.227, G_loss: 2.871
Epoch: [93/100]: D_loss: 0.191, G_loss: 3.008
Epoch: [94/100]: D_loss: 0.200, G_loss: 3.066
Epoch: [95/100]: D_loss: 0.200, G_loss: 3.142
Epoch: [96/100]: D_loss: 0.186, G_loss: 3.113
Epoch: [97/100]: D_loss: 0.207, G_loss: 3.159
Epoch: [98/100]: D_loss: 0.219, G_loss: 3.213
Epoch: [99/100]: D_loss: 0.177, G_loss: 3.205
Epoch: [100/100]: D_loss: 0.184, G_loss: 3.258

4. 可视化

4.1.LOSS图

python 复制代码
G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

plt.figure(figsize=(8,4))
plt.title("Generator and Descriminator Loss During Training")
plt.plot(G_loss_list,label = "G")
plt.plot(D_loss_list,label = "D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出图像为:

4.2.生成指定图像

python 复制代码
from numpy.random import randn

generator.load_state_dict(torch.load("./training_weights/generator_epoch_100.pth"), strict = False)
generator.eval()

interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100


plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1 ) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

输出图像为:

四、理论基础

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

  1. **有监督学习:**CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
  2. **联合隐层表征:**在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
  3. **可控性:**CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
  4. **使用卷积结构:**CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

相比于传统的GAN,CGAN的主要异同点包括条件信息的输入、训练稳定性、损失函数、网络结构等,其具体内容为:

  1. **条件信息的输入:**CGAN引入了条件变量,使得生成器和判别器都能接收到更多的信息来指导训练过程,这是传统GAN所不具备的。
  2. **训练稳定性:**传统GAN在训练过程中容易产生模式崩溃(mode collapse)的问题,而CGAN由于有了额外的条件信息,可以提高训练的稳定性和生成数据的多样性。
  3. **损失函数:**虽然CGAN的损失函数仍然保留了传统GAN的对抗损失函数的形式,但额外添加的条件信息使得损失计算更加复杂且有针对性。
  4. **网络结构:**在实现上,CGAN可以采用更深更复杂的网络结构,如卷积神经网络,这有助于处理更为复杂的数据类型,比如高分辨率图像。

CGAN网络结构如下图所示: 由上图的网络结构可知,条件信息y作为额外的输入被引入对抗网络中,与生成器中的噪声z合并作为隐含层表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的诸多方面研究中被证明是非常有效的,也为后续的相关工作提供了积极的指导作用

综上所述,CGAN的核心在于它通过引入条件信息来增强模型的生成能力和可控性,与传统GAN相比,它提供了更明确的训练目标和更好的生成效果。

相关推荐
幂律智能1 分钟前
从AI使用风险到合同智能审查重构企业风控能力
人工智能·重构
视***间9 分钟前
端侧大模型落地新标杆:视程空间将GPT-OSS边缘AI深度导入NVIDIA Jetson平台
人工智能·gpt·边缘计算·nvidia·ai算力·gpt-oss·视程空间
1892280486123 分钟前
NY379固态MT29F32T08GSLBHL8-36QA:B
大数据·服务器·人工智能·科技·缓存
Adair_z23 分钟前
[SEO艺术重读] 第9篇 熊猫算法、企鹅算法和惩罚机制
人工智能·熊猫算法·企鹅算法·谷歌算法恢复·网站seo诊断·高质量内容创作·e-e-a-t原则
ZZH_AI项目交付25 分钟前
我把 AI 最容易改坏真实 App 的地方,整理成了 skills
人工智能·ios·app
忆~遂愿26 分钟前
从文字应答到具象共情:Agent 交互的底层革新
人工智能·深度学习·目标检测·microsoft·机器学习·ar·交互
Ai.den27 分钟前
Windows 安装 MinerU 3.x 实现本地批量解析 PDF
人工智能·windows·ai
枫叶林FYL33 分钟前
【强化学习】长上下文可验证奖励强化学习:原理推导与系统架构
人工智能·系统架构
Teable任意门互动33 分钟前
深度解析:AI 赋能开源多维表格,实现企业全场景数据整合与高效应用
数据库·人工智能·低代码·信息可视化·开源·数据库开发
沪漂阿龙36 分钟前
Hermes Agent 安全边界全解析:让 AI Agent 敢执行、可控制、能回滚
人工智能·安全