R3GAN训练自己的数据集

简介

简介:这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。

论文题目:The GAN is dead; long live the GAN! A Modern Baseline GAN

会议:NeurIPS 2024

源码地址:https://www.github.com/brownvc/R3GAN

本文在调试代码的时候对代码做了一些修改,如果有遇到报错的问题可以直接复制我这篇博客修改后的代码:R3GAN利用配置好的Pytorch训练自己的数据集-CSDN博客这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。https://blog.csdn.net/LJ1147517021/article/details/148315781?fromshare=blogdetail&sharetype=blogdetail&sharerId=148315781&sharerefer=PC&sharesource=LJ1147517021&sharefrom=from_link

摘要:论文反驳了GANs难以训练的普遍观点,提出了一个理论有保障的现代GAN基线。首先,推导出一个良好行为的正则化相对论GAN损失函数,解决了模式丢弃和不收敛问题,并数学证明了其局部收敛性。其次,该损失函数允许丢弃所有经验性技巧,用现代架构替换常见GANs中的过时骨干网络。以StyleGAN2为例,展示了简化和现代化的路线图,产生了新的极简基线R3GAN。尽管简单,该方法在FFHQ、ImageNet、CIFAR和Stacked MNIST数据集上超越了StyleGAN2,与最先进的GANs和扩散模型相比表现优异。

模型结构

生成器架构

核心设计原则:

  • 基于现代化ResNet架构,摒弃VGG-like设计
  • 每个分辨率阶段包含一个过渡层和两个残差块
  • 采用分组卷积和倒置瓶颈设计

关键特性:

  • 无归一化层:避免批量归一化等数据相关的归一化
  • Fix-up初始化:零初始化每个残差块的最后一层卷积
  • 双线性插值:用于上采样,避免棋盘效应

鉴别器架构

设计特点:

  • 与生成器完全对称的架构
  • 相同的残差块结构和过渡层设计
  • 分类器头:全局4×4深度卷积 + 线性层

损失函数

相对论配对GAN损失 (RpGAN):

L(θ,ψ) = E[f(D_ψ(G_θ(z)) - D_ψ(x))]

R1正则化:

R1(ψ) = (γ/2) * E[||∇_x D_ψ(x)||²] (x~p_D)

R2正则化:

R2(θ,ψ) = (γ/2) * E[||∇_x D_ψ(x)||²] (x~p_θ)

训练自己的数据集

1. 准备数据集

首先使用 dataset_tool.py 将您的图像数据转换为适合训练的格式:

复制代码
# 从文件夹创建数据集
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip

# 如果需要调整分辨率和裁剪
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip \
    --resolution=256x256 --transform=center-crop

数据集要求:

  • 图像必须是正方形(如256x256, 512x512)
  • 分辨率必须是2的幂次(64, 128, 256, 512, 1024等)
  • 支持RGB或灰度图像
  • 可以是文件夹或ZIP格式

2. 创建自定义训练配置

train.py 中添加您自己的预设配置。参考现有预设,在 main() 函数中添加:

复制代码
if opts.preset == 'YOUR_DATASET':
    # 网络架构参数
    WidthPerStage = [768, 768, 768, 512, 256]  # 每阶段宽度
    BlocksPerStage = [2, 2, 2, 2, 2]           # 每阶段块数
    CardinalityPerStage = [96, 96, 96, 48, 24] # 每阶段基数
    FP16Stages = [-1, -2, -3, -4]              # FP16优化的阶段
    NoiseDimension = 64                         # 噪声维度
    
    # 如果是条件生成(有类别标签)
    if opts.cond:
        c.G_kwargs.ConditionEmbeddingDimension = NoiseDimension
        c.D_kwargs.ConditionEmbeddingDimension = WidthPerStage[0]
    
    # 训练调度参数
    ema_nimg = 500 * 1000      # EMA开始的图像数
    decay_nimg = 2e7           # 总衰减图像数
    
    # 各种调度器
    c.ema_scheduler = { 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg }
    c.aug_scheduler = { 'base_value': 0, 'final_value': 0.3, 'total_nimg': decay_nimg }
    c.lr_scheduler = { 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg }
    c.gamma_scheduler = { 'base_value': 2, 'final_value': 0.2, 'total_nimg': decay_nimg }
    c.beta2_scheduler = { 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg }

3. 开始训练

复制代码
# 无条件生成(如人脸、风景等)
python train.py \
    --outdir=./training-runs \
    --data=./datasets/your_dataset.zip \
    --gpus=4 \
    --batch=256 \
    --mirror=1 \
    --aug=1 \
    --preset=YOUR_DATASET \
    --tick=1 \
    --snap=200

# 条件生成(有类别标签)
python train.py \
    --outdir=./training-runs \
    --data=./datasets/your_dataset.zip \
    --gpus=4 \
    --batch=256 \
    --mirror=1 \
    --aug=1 \
    --cond=1 \
    --preset=YOUR_DATASET \
    --tick=1 \
    --snap=200

4. 参数说明

  • --gpus: GPU数量
  • --batch: 总批次大小
  • --mirror: 是否启用水平翻转增强
  • --aug: 是否启用数据增强
  • --cond: 是否训练条件模型(需要标签)
  • --tick: 多少kimg输出一次进度
  • --snap: 多少tick保存一次模型

5. 生成图像

训练完成后,使用保存的模型生成图像:

复制代码
# 生成8张图像
python gen_images.py \
    --seeds=0-7 \
    --outdir=generated_images \
    --network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

# 条件生成(指定类别)
python gen_images.py \
    --seeds=0-7 \
    --outdir=generated_images \
    --class=5 \
    --network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

6. 评估指标

复制代码
python calc_metrics.py \
    --metrics=fid50k_full,kid50k_full \
    --data=./datasets/your_dataset.zip \
    --network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

7.报错指南

1.UnboundLocalError: local variable 'NoiseDimension' referenced before assignment

解决办法:在 train.py 中,NoiseDimension 只在特定的预设配置块中定义(如 CIFAR10、FFHQ-64 等)。如果您使用的 --preset 参数不匹配任何现有预设,这个变量就不会被定义,导致使用时出错。可以使用作者定义好的预先设置。

复制代码
--preset=CIFAR10
--preset=FFHQ-64  
--preset=FFHQ-256
--preset=ImageNet-32
--preset=ImageNet-64

2.RuntimeError: Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "R3GAN\torch_utils\custom_ops.py".

解决办法:这个错误是因为R3GAN使用了自定义的CUDA操作符,需要C++编译器来编译。在Windows系统上缺少MSVC/GCC/CLANG编译器。

修改 torch_utils/custom_ops.py :找到 get_plugin 函数(大约第84行),在函数开头添加:

复制代码
def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    # 禁用所有自定义插件
    return None


def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):
    # 强制使用 'ref' 实现
    impl = 'ref'
相关推荐
星夜98220 分钟前
C++回顾 Day6
开发语言·数据结构·c++·算法
LaughingZhu1 小时前
PH热榜 | 2025-05-29
前端·人工智能·经验分享·搜索引擎·产品运营
视觉语言导航2 小时前
俄罗斯无人机自主任务规划!UAV-CodeAgents:基于多智能体ReAct和视觉语言推理的可扩展无人机任务规划
人工智能·深度学习·无人机·具身智能
世润2 小时前
深度学习-梯度消失和梯度爆炸
人工智能·深度学习
pen-ai2 小时前
【深度学习】11. Transformer解析: Self-Attention、ELMo、Bert、GPT
深度学习·bert·transformer
小彭律师2 小时前
LSTM+Transformer混合模型架构文档
人工智能·lstm·transformer
-曾牛2 小时前
使用Spring AI集成Perplexity AI实现智能对话(详细配置指南)
java·人工智能·后端·spring·llm·大模型应用·springai
归去_来兮4 小时前
长短期记忆(LSTM)网络模型
人工智能·深度学习·lstm·时序模型
asom224 小时前
LeetCode Hot100(矩阵)
算法·leetcode·矩阵
蒟蒻小袁4 小时前
力扣面试150题--二叉树的右视图
算法·leetcode·面试