flow_grpo vs Flow-Factory GRPO 实现对比

总览

Flow-Factory 是 flow_grpo 的工程化重写和扩展。核心 GRPO 数学公式相同,但架构上从「每个模型一个脚本」重构为「模型/算法/奖励/调度器完全解耦的统一框架」。


1. 架构与代码组织

维度 flow_grpo Flow-Factory
入口 每个模型独立 scripts/train_*.py 统一 CLI ff-train config.yaml
训练脚本 单文件 ~960 行,含所有逻辑 分层: BaseTrainer(abc.py)GRPOTrainer(grpo.py),~400 行
模型管理 直接操作 pipeline.transformer BaseAdapter(abc.py,~1800行) 统一管理所有组件
样本表示 dict of tensors BaseSample dataclass + extra_kwargs
算法注册 无,硬编码 Registry 模式,trainer_type 字符串查找类
模型注册 无,每个脚本硬编码 Registry 模式,model_type 字符串查找适配器
调度器 直接使用 FlowMatchEulerDiscreteScheduler SDESchedulerMixin 抽象 + 4 种动力学实现
奖励系统 rewards.py 闭包工厂函数 BaseRewardModel ABC + RewardProcessor + RewardBuffer
配置 Python ml_collections.ConfigDict YAML + dataclass 层级 (Arguments, TrainingArguments 等)

2. SDE Step 核心数学

2.1 Flow-SDE 公式

两者公式完全相同

复制代码
 # std_dev_t
 flow_grpo:  sqrt(sigma / (1 - where(sigma==1, sigma_max, sigma))) * noise_level
 Flow-Factory: sqrt(sigma / (1 - where(sigma==1.0, sigma_max, sigma))) * noise_level  # 相同
 ​
 # next_latents_mean
 # 两者公式一致:
 # latents*(1 + std²/(2*sigma)*dt) + noise_pred*(1 + std²*(1-sigma)/(2*sigma))*dt
 ​
 # log_prob
 # 两者公式一致:
 # -((prev_sample - mean)²) / (2*variance²) - log(variance) - log(sqrt(2*pi))
 # 最后 mean(dim=(1,...,ndim-1))

2.2 CPS 公式

完全相同

复制代码
 # 两者:
 std_dev_t = sigma_prev * sin(noise_level * pi / 2)
 x0 = latents - sigma * noise_pred
 x1 = latents + noise_pred * (1 - sigma)
 mean = x0*(1-sigma_prev) + x1*sqrt(sigma_prev² - std_dev_t²)
 next_latents = mean + std_dev_t * noise     # 注意: CPS无 sqrt(-dt) 因子
 log_prob = -(next_latents - mean)²           # 去掉了常数项

2.3 SDE Step 差异

项目 flow_grpo Flow-Factory
动力学类型 sde, cps (字符串) Flow-SDE, Dance-SDE, CPS, ODE (枚举风格)
Dance-SDE 不支持 新增 : pred_original_sample = latents - sigma * noise_pred
ODE 显式模式 noise_level=0 时隐式 新增 : dynamics_type='ODE' 明确路径
dtype 一致 无处理 next_latents.to(_input_dtype).float() round-trip 确保采样/训练精度一致
sigma_max 处理 硬编码 self.sigmas[1].item() 参数化 sigma_max=sigma_max or self.sigmas[1].item()
broadcast 方式 .view(-1, *([1]*(len(sample.shape)-1))) 手动 to_broadcast_tensor() 工具函数
输出 返回 tuple SDESchedulerOutput dataclass + return_kwargs 按需返回
噪声级别控制 noise_level 固定值 get_noise_level_for_sigma() 逐步动态查询,支持窗口 SDE

3. 训练循环 (optimize)

3.1 PPO Clipped Loss

核心公式完全相同

复制代码
 # 两者:
 ratio = exp(log_prob - old_log_prob)
 unclipped_loss = -advantages * ratio
 clipped_loss = -advantages * clip(ratio, 1-ε, 1+ε)   # flow_grpo: symmetric; FF: asymmetric
 policy_loss = mean(max(unclipped_loss, clipped_loss))

3.2 clip_range 语义

项目 flow_grpo Flow-Factory
clip_range 类型 标量 float (对称) 二元组 [low, high] (支持非对称)
示例 clip_range=1e-4[1-1e-4, 1+1e-4] clip_range=[-1e-4, 1e-4][1-1e-4, 1+1e-4]
adv_clip_range 标量 (对称) 二元组 (支持非对称)

3.3 KL 散度

项目 flow_grpo Flow-Factory
KL 类型 仅有 x-based x-based + v-based (可配置)
x-based 公式 (mean - mean_ref)² / (2*std²) → mean (mean - mean_ref)² → mean(dim=(1..ndim)) → mean
v-based 公式 不支持 (noise_pred - noise_pred_ref)² → mean
参考模型 transformer.module.disable_adapter() (仅 LoRA) LoRA: disable_adapter() / Full: use_ref_parameters() (EMA 包装器)
KL 计算位置 torch.no_grad() 修复 : 在 no_grad() 外计算,确保正确梯度 (issue #122)

3.4 训练循环结构

项目 flow_grpo Flow-Factory
样本打乱 torch.randperm 无种子控制 create_generator(seed, epoch, inner_epoch) 确定性种子
批次构建 手动 reshape dict BaseSample.stack() + sample.to(device)
CPU offload 不支持 offload_samples_to_cpu 选项,采样后移到 CPU,训练时懒加载回 GPU
轨迹存储 存储所有 timestep 的 latents trajectory_indices 选择性存储 + latent_index_map / log_prob_index_map
梯度累积 accelerator.accumulate(transformer) 传单个模型 accelerator.accumulate(*trainable_components) 支持多组件
指标日志 wandb.log 直接调用 self.log_data() 抽象,支持 wandb/swanlab/tensorboard
指标前缀 无前缀 train/ 前缀 (如 train/policy_loss)

4. Advantage 计算

4.1 核心算法

两者默认都使用 GRPO 的 group-normalized advantage:

复制代码
 advantage = (reward - group_mean) / std

其中 std 可以是 global_std(全局)或逐组 std。

4.2 关键差异

项目 flow_grpo Flow-Factory
实现位置 PerPromptStatTracker (stat_tracking.py) AdvantageProcessor (advantage_processor.py)
历史追踪 有状态: 累积历史 rewards,跨 epoch 计算 mean/std 无状态: 每个 epoch 独立计算
多奖励聚合 multi_score() 中加权求和后在 rewards.py 层面聚合 compute_weighted_sum 策略,per-reward 分别记录统计
GDPO 策略 不支持 新增 : compute_gdpo --- 每个奖励独立组内归一化后再加权求和
通信感知 仅 all-gather 两种路径: group_contiguous (零通信) / distributed_k_repeat (all-gather)
group_size 校验 严格校验每组大小必须等于 group_size
advantage 类型扩展 grpo/rwr/sft/dpo 四种硬编码 aggregation_func 参数,支持自定义 callable
零 std 比例 calculate_zero_std_ratio 独立函数 内置于 AdvantageProcessor,支持分布式全局计算
统计日志 手动 wandb.log 自动构建 _pending_advantage_metrics,统一 log_data

5. 采样 (Sampling)

5.1 采样流程

相同点 : 都在推理时注入 SDE 噪声,记录 latentsnext_latentslog_probstimesteps

差异:

项目 flow_grpo Flow-Factory
CFG 嵌入 采样前预先 concat [neg, pos],传 prompt_embeds 时已拼接 通过 guidance_scale + do_classifier_free_guidance 参数传递,pipeline 内部处理
generator create_generator(prompts, seed) SHA256 hash create_generator_by_prompt(batch['prompt'], seed)
same_latent 支持 config.sample.same_latent 通过 scheduler seed 控制
pipeline 函数 模型特定的 pipeline_with_logprob() 统一的 adapter.inference() → 各模型 forward()
返回值 (images, latents, log_probs) List[BaseSample] 对象,包含所有字段
窗口 SDE (Fast) 单独的 pipeline_with_logprob_fast.py 统一在 scheduler 中通过 sde_steps/num_sde_steps 控制

5.2 数据加载

项目 flow_grpo Flow-Factory
采样器 DistributedKRepeatSampler (手动实现) 3 种: DistributedKRepeatSampler / GroupContiguousSampler / GroupDistributedSampler
数据集 TextPromptDataset / GenevalPromptDataset 硬编码 统一的 get_dataloader() + preprocess_func
预处理 运行时逐 batch 编码 prompt 数据集预处理阶段离线编码 + 缓存

6. 奖励系统

项目 flow_grpo Flow-Factory
架构 rewards.py 工厂函数,闭包返回 _fn BaseRewardModel ABC → PointwiseRewardModel / GroupwiseRewardModel
多奖励组合 multi_score() 在函数内加权求和 RewardProcessor.compute_rewards()AdvantageProcessor 策略聚合
远程奖励 手动 requests.Session async_reward 标志 + ThreadPoolExecutor 自动管理
奖励存储 sample["rewards"] dict sample.extra_kwargs['rewards']
评估奖励 与训练共用同一个 reward_fn 独立的 eval_reward_args + eval_reward_processor
batch_wise/group_wise 通过 only_strictref_images 参数区分 通过模型类型 Pointwise/Groupwise 自动分发
支持的奖励模型 PickScore, CLIPScore, Aesthetic, ImageReward, GenEval, OCR, DeQA, UnifiedReward, QwenVL, JPEG PickScore, PickScore_Rank, CLIP, OCR, vllm_evaluate, rational_rewards_t2i, rational_rewards_edit, CLAP, ImageBind

7. GRPO-Guard 实现

项目 flow_grpo Flow-Factory
文件 独立 train_sd3_GRPO_Guard.py (完整复制所有代码) GRPOGuardTrainer(GRPOTrainer) 继承,仅覆写 sample()optimize()
RatioNorm optimize() 中手写 同公式,封装在 optimize() 覆写中
额外存储 无显式说明 extra_call_back_kwargs=['next_latents_mean'] 采样时存储
scale_factor sqrt(-dt) * std_dev_t 相同
ratio 公式 exp((log_prob - old_log_prob) * scale_factor + mse / (2 * scale_factor)) 相同

8. EMA (指数移动平均)

项目 flow_grpo Flow-Factory
实现 EMAModuleWrapper (基础版) EMAModuleWrapper (增强版)
decay schedule 固定: min((1+step)/(10+step), decay) 多种: constant, power, linear, piecewise_linear, cosine, warmup_cosine
设备 自动 可配置 ema_device: cuda/cpu
上下文管理器 copy_ema_to(temp=True) / copy_temp_to() use_ema_parameters() 上下文管理器
参考模型 不适用 use_ref_parameters() --- 用 decay=0 的 EMA 存储原始权重

9. 配置系统

flow_grpo (config/grpo.py)

复制代码
 def pickscore_sd3_fast():
     config = compressibility()  # 继承基础配置
     config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
     config.sample.num_steps = 10
     config.sample.train_batch_size = 9
     # ... 约 30 行配置
     return config

Flow-Factory (examples/grpo/lora/sd3_5/default.yaml)

复制代码
 model:
   finetune_type: 'lora'
   model_name_or_path: "stabilityai/stable-diffusion-3.5-medium"
   model_type: "sd3-5"
 train:
   trainer_type: 'grpo'
   clip_range: 1.0e-4
 scheduler:
   dynamics_type: "Flow-SDE"
   noise_level: 0.8
   num_sde_steps: 2
 rewards:
   - name: "pick_score"
     reward_model: "PickScore"

差异 : flow_grpo 的 Python 配置是命令式编程(config.sample.train_batch_size = 9),功能强大但难以 diff;Flow-Factory 的 YAML 配置声明式,易于版本控制和实验管理。


10. 独有特性

flow_grpo 有而 Flow-Factory 没有

  • PerPromptStatTracker 的历史累积 advantage(对某些任务可能更稳定)

  • same_latent 选项(同 prompt 用相同噪声)

  • calculate_zero_std_ratio 在 stat_tracker 外独立调用

  • DistributedKRepeatSampler 中对 mask 的可整除性补齐逻辑(避免空 batch)

Flow-Factory 有而 flow_grpo 没有

  • Dance-SDE 动力学

  • v-based KL 散度

  • GDPO advantage 聚合策略

  • 非对称 clip_range

  • CPU offload 样本管道(节省 GPU 内存)

  • 异步奖励计算(API-based 奖励不阻塞采样)

  • 多组件 LoRA(同时训练多个 transformer)

  • Named Parameters 快照(用于算法如 DPO/GARDO 更新参考模型)

  • 3 种分布式采样器group_contiguous 零通信优势计算)

  • 多日志后端(wandb/swanlab/tensorboard)

  • 注意力后端 支持(flash_hub, xformers 等)

  • latent_storage_dtype 精度控制

  • FSDP2 支持

  • 视频/音频 模型支持(Wan, LTX-2)


11. 总结

Flow-Factory 保留了 flow_grpo 的所有核心数学公式(Flow-SDE、CPS、PPO clipped loss、GRPO advantage normalization、GRPO-Guard ratio reweighting),但在软件工程层面做了彻底的重构:

  • 解耦: 模型/算法/调度器/奖励各自独立,通过抽象基类和 registry 模式组合

  • 扩展性 : 新增算法只需实现 BaseTrainer,新增模型只需实现 BaseAdapter

  • 配置化: YAML 驱动,实验可复现、可 diff

  • 性能优化: CPU offload、选择性轨迹存储、异步奖励、零通信 advantage

  • 功能增强: v-based KL、GDPO 多奖励聚合、Dance-SDE、非对称 clipping、多日志后端

两者关系可类比为「研究原型代码」与「生产级框架」。

相关推荐
石逸凡1 小时前
新时代的信息茧房
大数据·人工智能
Jay-r1 小时前
积极的断舍离:化解时代性焦虑的生活哲学
人工智能·科技·生活·感悟·哲学
闵孚龙1 小时前
Claude Code 沙箱系统全解析:Seatbelt、Bubblewrap、AI Agent 安全隔离、权限治理与企业级防护
人工智能·安全
:mnong1 小时前
MIT OpenCourseWare 25周年庆典与学习者故事
人工智能·mitocw
带娃的IT创业者1 小时前
Claude Code 源码泄露事件深度剖析:当 AI 编程工具不再“透明”
人工智能·ai编程·ai安全·源码泄露·claude code·工程伦理
ʜᴇɴʀʏ2 小时前
TPAMI 2026 | Semi-DETR++:基于检测 Transformer 的高效半监督目标检测
深度学习·目标检测·transformer
zxsz_com_cn2 小时前
设备预测性维护系统集成的关键技术与实践
人工智能·物联网
TheRouter2 小时前
AI Agent 工具数量超过 12 个后,选择准确率从 95% 拦腰跌到53%
人工智能
啦啦啦_99992 小时前
神经网络基础
人工智能·深度学习·神经网络