PyTorch生成式人工智能——PatchGAN详解与实现

PyTorch生成式人工智能------PatchGAN详解与实现

    • [0. 前言](#0. 前言)
    • [1. PatchGAN 核心原理](#1. PatchGAN 核心原理)
    • [2. 网络架构](#2. 网络架构)
    • [3. 实现 PatchGAN](#3. 实现 PatchGAN)
      • [3.1 数据集加载](#3.1 数据集加载)
      • [3.2 模型构建](#3.2 模型构建)
      • [3.3 模型训练](#3.3 模型训练)
    • 相关链接

0. 前言

生成对抗网络 (Generative Adversarial Network, GAN)的发展历程中,如何提升生成图像的细节质量一直是研究者关注的核心问题。传统的全局判别器往往过于关注图像的整体结构,而忽略了局部细节的真实性,导致生成的面部图像缺乏细腻的纹理特征。正是在这样的技术背景下,PatchGAN 作为一种创新的判别器架构脱颖而出。它通过将图像分割为多个局部区域并进行独立判别,极大地增强了对高频细节的感知能力。本节将深入探讨 PatchGAN 相关原理,并使用 PyTorch 从零开始实现 PatchGAN,并使用 huggan/selfie2anime 数据集训练模型生成动漫人脸图像。

1. PatchGAN 核心原理

PatchGAN 的设计是为了解决图像生成任务中一个常见问题:传统的判别器 (Discriminator) 需要将一整张图片输入,然后输出单一标量值用于预测输入是真实/虚假。这可能导致生成图片在整体结构上看起来合理,但局部细节模糊或缺乏高频信息。
PatchGAN 判别器不是将整张图像分类为真或假,而是将图像分割成重叠的 N x N 图像块 (Patche),并尝试对每个图像块的真假进行分类,如下图所示。最终,判别器的输出不是一个值,而是一个矩阵(或特征图),其中的每个元素代表对应图像块为真实图像的概率。

PatchGAN 提供了全局图像块级约束,为了进一步强化关键细节可以采用局部判别 (Local Discriminator)。全局判别容易忽略微小纹理(如眼睛高光、睫毛、皮肤纹理),局部判别补强这些高频信息。选取图像中的局部裁剪(例如 128×128 )(随机或其它感兴趣区域),专门训练一个判别器区分该局部图像的真实/虚假。

2. 网络架构

模型采用一个生成器与两个判别器,两个分别提供全局图像块级约束(判别器 D_patch )与局部关键细节约束(判别器 D_local),判别器 D_patch 接受生成器产生的图像与真实图像作为输入,并输出 一个 M x M 的矩阵(例如 30x30, 70x70 等,取决于架构深度和输入大小),矩阵中的每个元素 (i, j) 对应于输入图像中的一个感受野 (Receptive Field),即一个图像块 Patch。计算这个 M x M 输出矩阵与一个全为 1 (真实)或 0 (虚假)的同样大小矩阵之间的损失,使用 BCEWithLogitsLoss (或 hinge loss)。

判别器 D_local 与原始生成对抗网络 (Generative Adversarial Network, GAN)类似,不同之处在于接收图像中的局部裁剪区域作为输入,用于强化细节纹理与局部结构的真实感。

3. 实现 PatchGAN

接下来,我们使用 Hugging Facehuggan/selfie2anime 数据集(仅动漫脸部分 imageB )训练 PatchGAN,并采用 PatchGAN + 局部判别器的复合判别策略提升局部细节质量。

3.1 数据集加载

(1) 首先,导入所需库,并定义超参数与设备:

python 复制代码
import os
import math
import random
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from PIL import Image
from datasets import load_dataset

IMG_SIZE = 256
BATCH_SIZE = 16
Z_DIM = 256
LR = 2e-4
BETA1, BETA2 = 0.5, 0.999
EPOCHS = 200
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLES_DIR = Path("samples")
CHECKPOINT_DIR = Path("checkpoints")
SAMPLES_DIR.mkdir(exist_ok=True, parents=True)
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)

(2) 下载huggan/selfie2anime 数据集并定义数据集类 HFImageDataset()

python 复制代码
def prepare_hf_dataset(image_field="imageB", split="train"):
    print("Loading dataset from Hugging Face: huggan/selfie2anime ...")
    hf = load_dataset("huggan/selfie2anime", split=split)  # split='train'
    print("Dataset loaded. Rows:", len(hf))
    return hf

class HFImageDataset(Dataset):
    def __init__(self, hf_dataset, image_field="imageB", transform=None):
        self.ds = hf_dataset
        self.image_field = image_field
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        img = item[self.image_field]
        if isinstance(img, (list, tuple)):
            img = img[0]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        if self.transform:
            img = self.transform(img)
        return img

(3) 预处理使用 RandomCrop 缩放后裁剪,并归一化到 [-1,1],便于在生成器中使用 tanh 输出:

python 复制代码
train_transform = transforms.Compose([
    transforms.Resize(int(IMG_SIZE * 1.12), interpolation=Image.BICUBIC),
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])

(4) 加载数据集,且仅使用动漫脸图像 (imageB):

python 复制代码
hf = prepare_hf_dataset(image_field="imageB", split="train")
dataset = HFImageDataset(hf, image_field="imageB", transform=train_transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

3.2 模型构建

(1) 定义生成器,先将长度为 Z_DIM 的噪声通过全连接映射到 ngf x 4 x 4,再通过一系列 ConvTranspose2d 逐步放大到 256×256,输出用 tanh 映射到 [-1, 1],与数据归一化一致:

python 复制代码
class Generator(nn.Module):
    def __init__(self, z_dim=Z_DIM, ngf=1024, out_channels=3):
        super().__init__()
        # 从 z 映射到 4x4 特征图
        self.fc = nn.Linear(z_dim, ngf * 4 * 4)
        self.net = nn.Sequential(
            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf, ngf//2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf//2),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf//2, ngf//4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf//4),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(ngf//4, ngf//8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf//8),
            nn.ReLU(True),
            # 32x32 -> 64x64
            nn.ConvTranspose2d(ngf//8, ngf//16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf//16),
            nn.ReLU(True),
            # 64x64 -> 128x128
            nn.ConvTranspose2d(ngf//16, ngf//32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf//32),
            nn.ReLU(True),
            # 128x128 -> 256x256
            nn.ConvTranspose2d(ngf//32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 最后一层输出 3 通道
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        x = self.fc(z).view(z.size(0), -1, 4, 4)  # B x ngf x 4 x 4
        return self.net(x)

(2) 定义 PatchGAN 判别器,输出为特征图(每个位置对应一个 patch 的判别),在训练过程中,我们可以对这个输出使用 BCEWithLogitsLoss 损失,标签 1 对应真实 patch0 对应虚假 patch

python 复制代码
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=3, ndf=64):
        super().__init__()
        # 多层卷积下采样,最后输出单通道特征图(判别 map)
        def conv_block(in_c, out_c, stride, use_bn=True):
            layers = [nn.Conv2d(in_c, out_c, kernel_size=4, stride=stride, padding=1, bias=not use_bn)]
            if use_bn:
                layers.append(nn.BatchNorm2d(out_c))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        layers += [nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]
        layers += conv_block(ndf, ndf*2, stride=2)
        layers += conv_block(ndf*2, ndf*4, stride=2)
        layers += conv_block(ndf*4, ndf*8, stride=1)  # keep stride=1 to keep patch granularity
        layers += [nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=1)]  # 单通道 logits map
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)  # 输出 shape: (B,1,H_p,W_p)

(3) 定义局部判别器,最终输出一个标量 logit (每个局部 patch 一个判别),训练过程中我们从真实图像与生成图像中随机裁剪局部 patch (例如 128×128) 传递给局部判别器:

python 复制代码
class LocalDiscriminator(nn.Module):
    def __init__(self, in_channels=3, ndf=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, ndf, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf*2, 4, 2, 1), nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1), nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1), nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, True),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ndf*8, 1, kernel_size=1)  # 输出 1x1 logit
        )

    def forward(self, x):
        # x 形状 (B,3,H_loc,W_loc)(例如 128x128)
        out = self.net(x)  # B x 1 x 1 x 1
        return out.view(x.size(0), 1)  # (B,1)

3.3 模型训练

(1) 初始化模型、优化器以及损失:

python 复制代码
G = Generator(z_dim=Z_DIM).to(DEVICE)
D_patch = PatchDiscriminator().to(DEVICE)
D_local = LocalDiscriminator().to(DEVICE)

# 初始化权重
def init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

G.apply(init_weights)
D_patch.apply(init_weights)
D_local.apply(init_weights)

# 优化器
opt_G = optim.Adam(G.parameters(), lr=LR, betas=(BETA1, BETA2))
opt_D = optim.Adam(list(D_patch.parameters()) + list(D_local.parameters()), lr=LR, betas=(BETA1, BETA2))

# 损失
bce_loss = nn.BCEWithLogitsLoss().to(DEVICE)

(2) 定义辅助函数,random_local_patches() 用于在每张图像上随机裁剪同一大小的局部 patch,训练局部判别器,save_samples() 把生成图像从 [-1,1] 映射回 [0,1] 并保存:

python 复制代码
def random_local_patches(imgs, loc_size=128):
    B, C, H, W = imgs.shape
    patches = []
    for i in range(B):
        top = random.randint(0, H - loc_size)
        left = random.randint(0, W - loc_size)
        patches.append(imgs[i:i+1, :, top:top+loc_size, left:left+loc_size])
    return torch.cat(patches, dim=0).to(imgs.device)  # (B,3,loc_size,loc_size)

def save_samples(fake_imgs, step, nrow=4):
    grid = utils.make_grid((fake_imgs + 1) / 2.0, nrow=nrow, padding=2, normalize=False)
    save_path = SAMPLES_DIR / f"sample_{step:06d}.png"
    utils.save_image(grid, save_path)
    print("Saved samples to", save_path)

(3) 定义训练流程,判别器训练阶段同时训练 Patch 判别器和局部判别器;Patch 判别器的损失与局部判别器按比例合并,生成器阶段以欺骗两个判别器为目标:

python 复制代码
global_step = 0
sample_z = torch.randn(16, Z_DIM, device=DEVICE)  # 固定噪声用于可视化

for epoch in range(EPOCHS):
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for real in loop:
        real = real.to(DEVICE)  # (B,3,256,256)
        B = real.size(0)
        
        # Train Discriminators (patch + local)
        opt_D.zero_grad()

        # Patch discriminator on real
        logits_real_patch = D_patch(real)
        # create labels (same shape as logits)
        labels_real_patch = torch.ones_like(logits_real_patch, device=DEVICE) * 0.9  # label smoothing
        loss_D_patch_real = bce_loss(logits_real_patch, labels_real_patch)

        # Generate fake images
        z = torch.randn(B, Z_DIM, device=DEVICE)
        fake = G(z)
        logits_fake_patch = D_patch(fake.detach())
        labels_fake_patch = torch.zeros_like(logits_fake_patch, device=DEVICE)
        loss_D_patch_fake = bce_loss(logits_fake_patch, labels_fake_patch)

        loss_D_patch = (loss_D_patch_real + loss_D_patch_fake) * 0.5

        # Local discriminator on random patches
        real_local = random_local_patches(real, loc_size=128)  # (B,3,128,128)
        fake_local = random_local_patches(fake.detach(), loc_size=128)

        logits_real_local = D_local(real_local)  # (B,1)
        logits_fake_local = D_local(fake_local)  # (B,1)
        labels_real_local = torch.ones_like(logits_real_local, device=DEVICE) * 0.9
        labels_fake_local = torch.zeros_like(logits_fake_local, device=DEVICE)

        loss_D_local = 0.5 * (bce_loss(logits_real_local, labels_real_local) +
                              bce_loss(logits_fake_local, labels_fake_local))

        # Total discriminator loss
        loss_D = loss_D_patch + 0.7 * loss_D_local
        loss_D.backward()
        opt_D.step()

        # Train Generator
        opt_G.zero_grad()
        z2 = torch.randn(B, Z_DIM, device=DEVICE)
        fake2 = G(z2)
        # Try to fool both discriminators

        # Patch GAN objective (we want D_patch(fake) -> 1)
        logits_fake_patch_forG = D_patch(fake2)
        labels_G_patch = torch.ones_like(logits_fake_patch_forG, device=DEVICE)
        loss_G_patch = bce_loss(logits_fake_patch_forG, labels_G_patch)

        # Local objective
        fake_local2 = random_local_patches(fake2, loc_size=128)
        logits_fake_local_forG = D_local(fake_local2)
        labels_G_local = torch.ones_like(logits_fake_local_forG, device=DEVICE)
        loss_G_local = bce_loss(logits_fake_local_forG, labels_G_local)

        loss_G = loss_G_patch + 0.7 * loss_G_local
        loss_G.backward()
        opt_G.step()
        global_step += 1
        if global_step % 10 == 0:
            loop.set_postfix({
                "loss_D": f"{loss_D.item():.4f}",
                "loss_G": f"{loss_G.item():.4f}"
            })
        if global_step % 500 == 0:
            with torch.no_grad():
                samples = G(sample_z).cpu()
                save_samples(samples, global_step, nrow=4)
    torch.save({
        'epoch': epoch,
        'G_state_dict': G.state_dict(),
        'D_patch_state_dict': D_patch.state_dict(),
        'D_local_state_dict': D_local.state_dict(),
        'opt_G': opt_G.state_dict(),
        'opt_D': opt_D.state_dict()
    }, CHECKPOINT_DIR / f"ckpt_epoch_{epoch:03d}.pth")
    print(f"Saved checkpoint for epoch {epoch}")

生成结果如下所示,可以看到随着训练生成的图像越来越逼真:

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)------神经网络与模型训练过程详解
PyTorch生成式人工智能(2)------PyTorch基础
PyTorch生成式人工智能(3)------使用PyTorch构建神经网络
PyTorch生成式人工智能(4)------卷积神经网络详解
PyTorch生成式人工智能(5)------分类任务详解
PyTorch生成式人工智能(6)------生成模型(Generative Model)详解
PyTorch生成式人工智能(7)------生成对抗网络实践详解
PyTorch生成式人工智能(8)------深度卷积生成对抗网络
PyTorch生成式人工智能(9)------Pix2Pix详解与实现
PyTorch生成式人工智能(10)------CyclelGAN详解与实现
PyTorch生成式人工智能(11)------神经风格迁移
PyTorch生成式人工智能(12)------StyleGAN详解与实现
PyTorch生成式人工智能(13)------WGAN详解与实现
PyTorch生成式人工智能(14)------条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)------自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)------自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)------变分自编码器详解与实现
PyTorch生成式人工智能(18)------循环神经网络详解与实现
PyTorch生成式人工智能(19)------自回归模型详解与实现
PyTorch生成式人工智能(20)------像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(24)------使用PyTorch构建Transformer模型
PyTorch生成式人工智能(25)------基于Transformer实现机器翻译
PyTorch生成式人工智能(26)------使用PyTorch构建GPT模型
PyTorch生成式人工智能(27)------从零开始训练GPT模型
PyTorch生成式人工智能(28)------MuseGAN详解与实现

相关推荐
LLM精进之路6 分钟前
上海AI实验室突破扩散模型!GetMesh融合点云与三平面,重塑3D内容创作
人工智能·深度学习·机器学习·语言模型·transformer
肥仔哥哥193030 分钟前
基于OpenCv做照片分析(Java)
java·人工智能·opencv·图像原理
2501_924879261 小时前
客流特征识别误报率↓76%!陌讯多模态时序融合算法在智慧零售的实战解析
大数据·人工智能·算法·目标检测·计算机视觉·视觉检测·零售
说私域1 小时前
消费、渠道与技术变革下新零售的崛起与开源AI大模型AI智能名片S2B2C商城小程序的融合发展
人工智能·开源·零售
北京地铁1号线1 小时前
广告推荐模型2:因子分解机(Factorization Machines, FM)
人工智能·算法·推荐算法
做一个快乐的小傻瓜1 小时前
机器学习笔记
人工智能·决策树·机器学习
居然JuRan1 小时前
MCP:基础概念、快速应用和背后原理
人工智能
1ucency2 小时前
Dify插件“Database”安装及配置
人工智能
eqwaak02 小时前
科技信息差(8.26)
大数据·开发语言·人工智能·编辑器
念夏沫2 小时前
“华生科技杯”2025年全国青少年龙舟锦标赛在海宁举行
大数据·人工智能·科技