Generative Adversarial Networks (GANs) in Practice

1. GAN Theory and PyTorch Implementation

1.1 Basic GAN Theory

1.1.1 The Adversarial Training Objective

The generator $G$ and the discriminator $D$ play a minimax game:

$$\min_G \max_D \; \mathbb{E}_{x\sim p_{data}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]$$
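
One practical detail: the generator term $\log(1 - D(G(z)))$ saturates when $D$ confidently rejects fakes early in training, so in practice the generator instead maximizes $\log D(G(z))$ (this is what the training loop in §1.3 does). A small sketch of the correspondence, using a random stand-in for the discriminator's outputs:

```python
import torch
import torch.nn.functional as F

# Stand-in for D(G(z)) probabilities on a batch of fakes (assumption for illustration)
d_fake_prob = torch.rand(16, 1)

# Saturating form, straight from the minimax objective: minimize log(1 - D(G(z)))
g_loss_saturating = torch.log(1 - d_fake_prob).mean()

# Non-saturating form used in practice: maximize log D(G(z)),
# i.e. binary cross-entropy against "real" labels
g_loss_nonsaturating = F.binary_cross_entropy(d_fake_prob, torch.ones_like(d_fake_prob))
```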

1.1.2 Network Architecture Diagram

```mermaid
graph LR
    Z[Noise z] --> G[Generator G] --> X_fake[Fake samples]
    X_real[Real samples] --> D[Discriminator D]
    X_fake --> D
    D --> L_real[Probability real]
    D --> L_fake[Probability fake]
    style Z fill:#9f9,stroke:#333
    style X_real fill:#f99,stroke:#333
```

1.2 Basic GAN Implementation

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    # Maps a latent vector z to a 28x28 image in [-1, 1] (Tanh output)
    def __init__(self, latent_dim=100):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 784),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    # Maps a 28x28 image to a single real/fake probability
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

# Instantiate the networks
generator = Generator()
discriminator = Discriminator()
```
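
A quick shape check of the two networks (a usage sketch):

```python
z = torch.randn(16, 100)       # a batch of 16 latent vectors
imgs = generator(z)            # -> (16, 1, 28, 28), values in [-1, 1] from Tanh
probs = discriminator(imgs)    # -> (16, 1), probabilities from Sigmoid
print(imgs.shape, probs.shape)
```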

1.3 Training Loop Template

```python
import torch.nn.functional as F

for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        # Train the discriminator
        z = torch.randn(real_imgs.size(0), latent_dim)
        fake_imgs = generator(z)

        real_pred = discriminator(real_imgs)
        fake_pred = discriminator(fake_imgs.detach())  # detach: no gradient into G here
        real_loss = F.binary_cross_entropy(real_pred, torch.ones_like(real_pred))
        fake_loss = F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred))
        d_loss = (real_loss + fake_loss) / 2
        
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()
        
        # Train the generator: make D predict "real" on fakes
        fake_pred = discriminator(fake_imgs)
        g_loss = F.binary_cross_entropy(fake_pred, torch.ones_like(fake_pred))

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()
```
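
The loop assumes optimizers, a dataloader, and a few hyperparameters that the template leaves undefined; a minimal setup sketch (the values are common defaults, not from the original):

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

latent_dim, batch_size, epochs = 100, 64, 50
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

dataset = datasets.MNIST("data", train=True, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize([0.5], [0.5]),  # match the Tanh output range
                         ]))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
```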

2. DCGAN: Generating Handwritten Digits and Faces

2.1 Key DCGAN Improvements

  • Replace fully connected layers with convolutions
  • Add batch normalization (BatchNorm)
  • Remove pooling layers; upsample with transposed convolutions
  • Use LeakyReLU activations in the discriminator (ReLU in the generator)
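
One more detail from the DCGAN paper worth adding to this list: all convolutional and BatchNorm weights are initialized from $N(0, 0.02)$. A sketch of the usual helper:

```python
import torch.nn as nn

def weights_init(m):
    # DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for BatchNorm scales
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

# Usage: apply to the networks from §2.2, e.g. DCGAN_Generator().apply(weights_init)
```
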
2.1.1 Generator Architecture

```mermaid
graph TD
    Z[100-dim noise] --> FC[Fully connected] --> RS[Reshape 4x4x512]
    RS --> TConv1[Transposed conv 5x5, stride=2] --> BN1 --> ReLU1[ReLU]
    TConv1 --> TConv2[Transposed conv 5x5, stride=2] --> BN2 --> ReLU2[ReLU]
    TConv2 --> TConv3[Transposed conv 5x5, stride=2] --> Tanh
    style Z fill:#9f9,stroke:#333
    style Tanh fill:#f99,stroke:#333
```

2.2 Improved DCGAN Implementation

```python
class DCGAN_Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),  # 1x1 -> 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),         # 4x4 -> 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),         # 8x8 -> 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),           # 16x16 -> 32x32; use out_channels=3 for RGB faces
            nn.Tanh()
        )
    
    def forward(self, z):
        z = z.view(z.size(0), -1, 1, 1)  # reshape latent vectors to (N, latent_dim, 1, 1) feature maps
        return self.model(z)
```
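
The original shows only the generator; a matching convolutional discriminator is a straightforward mirror image, downsampling with strided convolutions (this class is a sketch I am adding, not part of the original):

```python
import torch.nn as nn

class DCGAN_Discriminator(nn.Module):
    # Mirrors the generator: strided convs downsample 32x32 -> 1x1
    def __init__(self, channels=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 128, 4, 2, 1, bias=False),  # 32x32 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),       # 16x16 -> 8x8
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),       # 8x8 -> 4x4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),         # 4x4 -> 1x1
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img).view(-1, 1)
```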

2.3 Training on Multiple Datasets

```python
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms

# Handwritten digits (MNIST); resize to match the generator's 32x32 output
transform_mnist = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Faces (CelebA); 3-channel normalization (use out_channels=3 in the generator)
transform_face = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Visualize generated results as an 8-per-row grid
def show_images(images, title=""):
    grid = torchvision.utils.make_grid(images, nrow=8, normalize=True)
    plt.imshow(grid.permute(1, 2, 0).cpu().detach())
    plt.title(title)
    plt.axis('off')
    plt.show()
```
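
To sanity-check the input pipeline before training, load a batch and display it (the "data" path is an assumption; torchvision's CelebA loader may additionally require a manual download):

```python
from torch.utils.data import DataLoader
from torchvision import datasets

mnist = datasets.MNIST("data", train=True, download=True, transform=transform_mnist)
loader = DataLoader(mnist, batch_size=64, shuffle=True)

real_batch, _ = next(iter(loader))
show_images(real_batch, "Real MNIST samples")
```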

3. WGAN-GP for Training Stability

3.1 Wasserstein GAN Improvements

3.1.1 Theoretical Advantages
  • Uses the Wasserstein distance instead of the JS divergence: $W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x,y)\sim\gamma}[\|x-y\|]$
  • Adds a gradient penalty (GP) term: $\lambda\,\mathbb{E}_{\hat{x}\sim p_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2]$
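
Combining the two, the critic minimizes the objective below (the standard WGAN-GP formulation, implemented in §3.2), where $\hat{x}$ is drawn uniformly along straight lines between real and generated samples:

$$L_{critic} = \mathbb{E}_{\tilde{x}\sim p_g}[D(\tilde{x})] - \mathbb{E}_{x\sim p_r}[D(x)] + \lambda\,\mathbb{E}_{\hat{x}\sim p_{\hat{x}}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big]$$
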
3.1.2 Network Adjustments
  • Remove the Sigmoid from the discriminator
  • Use a linear output layer (the "critic" scores samples rather than classifying them)
  • Add the gradient-penalty computation

3.2 WGAN-GP Implementation

```python
def compute_gradient_penalty(critic, real_samples, fake_samples):
    # Sample points uniformly along straight lines between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = critic(interpolates)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True
    )[0]
    
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Adjusted training loop: several critic updates per generator update
for _ in range(critic_iters):
    # Train the critic
    z = torch.randn(batch_size, latent_dim)
    fake_imgs = generator(z)

    real_validity = critic(real_imgs)
    fake_validity = critic(fake_imgs.detach())  # detach: the critic step must not backprop into G
    gp = compute_gradient_penalty(critic, real_imgs, fake_imgs.detach())
    
    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
    
    optimizer_critic.zero_grad()
    d_loss.backward()
    optimizer_critic.step()

# Train the generator: maximize the critic's score on generated samples
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
g_loss = -torch.mean(critic(fake_imgs))

optimizer_generator.zero_grad()
g_loss.backward()
optimizer_generator.step()
```
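
The loop above assumes the hyperparameters and optimizers below; the values are the WGAN-GP paper's defaults (and match the Adam betas recommended in the closing note):

```python
lambda_gp = 10       # gradient-penalty coefficient λ
critic_iters = 5     # critic updates per generator update
optimizer_critic = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
```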

3.3 Training Stability Comparison

| Method | Convergence speed | Mode-collapse risk | Sample quality |
| --- | --- | --- | --- |
| Vanilla GAN | Medium | High | Medium |
| DCGAN | Faster | Medium | Better |
| WGAN-GP | Slower per step, but stable | Low | Excellent |

Appendix: GAN Training Tricks

Feature Matching Loss

Rather than scoring individual samples, match the batch-mean activations of an intermediate discriminator layer on real versus generated data:

```python
def feature_loss(real_features, fake_features):
    # Match the batch-mean statistics of an intermediate discriminator layer
    return F.mse_loss(real_features.detach().mean(0), fake_features.mean(0))
```
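
feature_loss needs intermediate discriminator activations, which the code above never extracts; one way is a forward hook (hooking index 4 of the MLP discriminator from §1.2 is an arbitrary choice for illustration):

```python
features = {}

def save_features(module, inp, out):
    # Store the hooked layer's activation on every forward pass
    features["h"] = out

discriminator.model[4].register_forward_hook(save_features)

discriminator(real_imgs)
real_features = features["h"]
discriminator(fake_imgs)
fake_features = features["h"]
fm_loss = feature_loss(real_features, fake_features)
```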

Adaptive Balancing of the Training Ratio

```python
# Automatically adjust the critic/generator update ratio
# (a rough heuristic: give extra updates to whichever network is lagging behind)
if d_loss.item() > 2 * g_loss.item():
    critic_iters += 1
elif d_loss.item() < 0.5 * g_loss.item():
    critic_iters = max(1, critic_iters - 1)
```

Sample Quality Evaluation (FID)

```python
# Compute the Fréchet Inception Distance (calculate_fid is left undefined here)
fid = calculate_fid(real_features, fake_features)
print(f"FID Score: {fid:.2f}")
```

Visualization Case Study: Evolution of the Generation Process

```python
# Fix a latent vector to watch how its generations evolve over training
# (device is assumed defined, e.g. torch.device("cuda"))
fixed_z = torch.randn(64, latent_dim).to(device)

for epoch in range(epochs):
    generator.train()
    # ...training steps...
    
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_imgs = generator(fixed_z)
        show_images(sample_imgs, f"Epoch {epoch}")
```
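
TensorBoard (recommended in the note below) pairs naturally with this loop; a minimal logging sketch, with calls placed inside the epoch loop (run and tag names are arbitrary):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/gan")
writer.add_scalar("loss/generator", g_loss.item(), epoch)
writer.add_scalar("loss/discriminator", d_loss.item(), epoch)
writer.add_image("samples", torchvision.utils.make_grid(sample_imgs, normalize=True), epoch)
```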

Note: the code in this article has been verified on PyTorch 2.1 + CUDA 11.8. For WGAN-GP training, Adam with β1=0, β2=0.9 is recommended, and TensorBoard is recommended for monitoring training. The next chapter dives into natural language processing applications! 🚀
