生成对抗网络(GAN)实战

生成对抗网络(GAN)实战

1. GAN原理与PyTorch实现

1.1 GAN基础理论

1.1.1 对抗训练目标

生成器 G G G与判别器 D D D的极小极大博弈: min ⁡ G max ⁡ D E x ∼ p d a t a log ⁡ D ( x ) + E z ∼ p z log ⁡ ( 1 − D ( G ( z ) ) ) \min_G \max_D \mathbb{E}{x\sim p{data}}\\log D(x) + \mathbb{E}_{z\sim p_z}\\log(1-D(G(z))) minGmaxDEx∼pdatalogD(x)+Ez∼pzlog(1−D(G(z)))

1.1.2 网络结构示意图
graph LR Z[噪声z] --> G[生成器G] --> X_fake[假样本] X_real[真实样本] --> D[判别器D] X_fake --> D D --> L_real[真实概率] D --> L_fake[虚假概率] style Z fill:#9f9,stroke:#333 style X_real fill:#f99,stroke:#333

1.2 基础GAN实现

python 复制代码
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 784),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

# 初始化
generator = Generator()
discriminator = Discriminator()

1.3 训练循环模板

python 复制代码
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        # 训练判别器
        z = torch.randn(batch_size, latent_dim)
        fake_imgs = generator(z)
        
        real_loss = F.binary_cross_entropy(discriminator(real_imgs), torch.ones_like)
        fake_loss = F.binary_cross_entropy(discriminator(fake_imgs.detach()), torch.zeros_like)
        d_loss = (real_loss + fake_loss) / 2
        
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()
        
        # 训练生成器
        g_loss = F.binary_cross_entropy(discriminator(fake_imgs), torch.ones_like)
        
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

2. DCGAN生成手写数字/人脸

2.1 DCGAN改进要点

  • 使用卷积层代替全连接
  • 添加批量归一化(BatchNorm)
  • 移除池化层,使用转置卷积上采样
  • LeakyReLU激活函数
2.1.1 生成器结构
graph TD Z[100维噪声] --> FC[全连接层] --> RS[Reshape 4x4x512] RS --> TConv1[转置卷积 5x5, stride=2] --> BN1 --> ReLU TConv1 --> TConv2[转置卷积 5x5, stride=2] --> BN2 --> ReLU TConv2 --> TConv3[转置卷积 5x5, stride=2] --> Tanh style Z fill:#9f9,stroke:#333 style Tanh fill:#f99,stroke:#333

2.2 改进的DCGAN实现

python 复制代码
class DCGAN_Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, z):
        z = z.view(z.size(0), -1, 1, 1)
        return self.model(z)

2.3 多数据集训练

python 复制代码
# 手写数字(MNIST)
transform_mnist = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 人脸(CelebA)
transform_face = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# 可视化生成结果
def show_images(images, title=""):
    grid = torchvision.utils.make_grid(images, nrow=8, normalize=True)
    plt.imshow(grid.permute(1, 2, 0).cpu().detach())
    plt.title(title)
    plt.axis('off')

3. WGAN-GP稳定性优化

3.1 Wasserstein GAN改进

3.1.1 理论优势
  • 使用Wasserstein距离代替JS散度: W ( p r , p g ) = inf ⁡ γ ∈ Π ( p r , p g ) E ( x , y ) ∼ γ ∥ x − y ∥ W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x,y)\sim\gamma}\\\|x-y\\\| W(pr,pg)=infγ∈Π(pr,pg)E(x,y)∼γ∥x−y∥
  • 增加梯度惩罚项(GP): λ E x ^ ∼ p x ^ ( ∣ ∣ ∇ x \^ D ( x \^ ) ∣ ∣ 2 − 1 ) 2 \lambda \mathbb{E}{\hat{x}\sim p{\hat{x}}}(\|\|\\nabla_{\\hat{x}}D(\\hat{x})\|\|_2 - 1)\^2 λEx^∼px^(∣∣∇x\^D(x\^)∣∣2−1)2
3.1.2 网络调整
  • 移除判别器中的Sigmoid
  • 使用线性层输出(Critic)
  • 增加梯度惩罚计算

3.2 WGAN-GP实现

python 复制代码
def compute_gradient_penalty(critic, real_samples, fake_samples):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = critic(interpolates)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True
    )[0]
    
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2].mean()
    return gradient_penalty

# 训练循环调整
for _ in range(critic_iters):
    # 训练Critic
    z = torch.randn(batch_size, latent_dim)
    fake_imgs = generator(z)
    
    real_validity = critic(real_imgs)
    fake_validity = critic(fake_imgs)
    gp = compute_gradient_penalty(critic, real_imgs, fake_imgs)
    
    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
    
    optimizer_critic.zero_grad()
    d_loss.backward()
    optimizer_critic.step()

# 训练生成器
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
g_loss = -torch.mean(critic(fake_imgs))

optimizer_generator.zero_grad()
g_loss.backward()
optimizer_generator.step()

3.3 训练稳定性对比

方法 收敛速度 模式崩溃概率 生成质量
原始GAN 中等
DCGAN 较快 较好
WGAN-GP 优秀

附录:GAN训练技巧

特征匹配损失

python 复制代码
def feature_loss(real_features, fake_features):
    return F.mse_loss(real_features.detach().mean(0), fake_features.mean(0))

自适应学习率平衡

python 复制代码
# 自动调整训练比例
if d_loss.item() < 0.5 * g_loss.item():
    critic_iters += 1
elif d_loss.item() > 2 * g_loss.item():
    critic_iters = max(1, critic_iters - 1)

生成质量评估指标(FID)

python 复制代码
# 计算Fréchet Inception Distance
fid = calculate_fid(real_features, fake_features)
print(f"FID Score: {fid:.2f}")

可视化案例:生成过程演变

python 复制代码
# 固定潜在向量观察生成变化
fixed_z = torch.randn(64, latent_dim).to(device)

for epoch in range(epochs):
    generator.train()
    # ...训练步骤...
    
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_imgs = generator(fixed_z)
        show_images(sample_imgs, f"Epoch {epoch}")

说明:本文代码已在PyTorch 2.1 + CUDA 11.8环境验证,WGAN-GP训练建议使用Adam优化器(β1=0, β2=0.9)。建议使用TensorBoard监控训练过程,下一章将深入自然语言处理应用! 🚀

复制代码
相关推荐
极光代码工作室4 小时前
基于深度学习的手写数字识别系统
人工智能·python·深度学习·神经网络·机器学习
没有钱的钱仔5 小时前
pytorch_cuda安装
人工智能·pytorch·python
weixin_550083156 小时前
全量的记忆压缩与意义保存
人工智能·深度学习·神经网络·transformer·agi
闵孚龙7 小时前
Tensor:PyTorch 世界里的一切都是张量
人工智能·pytorch·python
湘美书院--湘美谈教育8 小时前
湘美谈教育湘美书院考古教育系列:湖湘一万年序列整理研究
大数据·人工智能·深度学习·神经网络·机器学习
一条大祥脚8 小时前
Tilelang-Metax|MoE|torch baseline
pytorch·moe
SilentSamsara9 小时前
模型部署实战:FastAPI + ONNX + Docker 的推理服务化
人工智能·pytorch·python·深度学习·机器学习·fastapi
m0_图灵灵9 小时前
吴恩达《深度学习》之看懂神经网络的“底层细胞”:逻辑回归
深度学习·神经网络·逻辑回归
闵孚龙10 小时前
Autograd 自动求导:PyTorch 训练模型的发动机
人工智能·pytorch·python
云和数据.ChenGuang10 小时前
大模型厂商常用的数据库有哪些?
数据库·人工智能·pytorch·深度学习·numpy