Flow-GRPO vs Flow-Factory: SD3 GRPO 实现对比

1. 整体架构

维度 flow_grpo Flow-Factory
入口 scripts/train_sd3.py (967行单文件) ff-train examples/grpo/lora/sd3_5/default.yaml
架构 每个模型独立脚本,代码重复 三层抽象: Trainer → Adapter → Scheduler
核心循环 全在 train_sd3.py 内联 BaseTrainer.start() 统一调度,子类重写 sample/prepare_feedback/optimize

flow_grpotrain_sd3.py main() 包含所有逻辑:数据加载、采样、奖励、advantage、训练循环全部写在一个函数里。每个模型有独立的 train_*.py(如 train_flux.py, train_qwenimage.py),存在大量代码重复。

Flow-Factory 将训练流程拆分为:

复制代码
 Trainer (算法: grpo.py) → Adapter (模型: sd3_5.py) → Scheduler (动力学: flow_match_euler_discrete.py)

每层通过 filter_kwargs() 只传递各自需要的参数,模型和算法可以任意组合。


2. 配置系统

维度 flow_grpo Flow-Factory
配置格式 Python (ml_collections) YAML + dataclass
配置文件 config/base.py + config/grpo.py examples/grpo/lora/sd3_5/default.yaml
参数传递 全局 config 对象,. 访问 Arguments dataclass,强类型
覆盖方式 修改 Python 代码 YAML override / CLI

flow_grpo 配置示例 (config/grpo.py):

复制代码
 config.sample.num_steps = 10
 config.sample.noise_level = 0.7
 config.train.clip_range = 1e-4
 config.sample.global_std = False

Flow-Factory 配置示例 (default.yaml):

复制代码
 train:
   trainer_type: 'grpo'
   clip_range: 1.0e-4
   global_std: true
 scheduler:
   dynamics_type: "Flow-SDE"
   noise_level: 0.8
   num_sde_steps: 2

YAML 方式更易于版本管理和实验追踪,不需要修改代码即可切换配置。


3. 采样 / Rollout

flow_grpo 采样在 train_sd3.py 内联实现,Flow-Factory 通过 Adapter + Scheduler 分离。

3.1 推理 Pipeline

flow_grpo --- sd3_pipeline_with_logprob.py:

  • 直接修改 diffusers 的 SD3 pipeline,在其中加入 SDE noise 和 log_prob 计算

  • 返回 (images, all_latents, all_log_probs) --- 所有中间 latents 都存在显存里

  • CFG 在 pipeline 内部处理(拼接 negative+positive embedding 再做一次 forward)

复制代码
 # flow_grpo: 全量存储
 all_latents = [latents]  # 存储所有 timestep 的 latent
 all_log_probs = []
 for i, t in enumerate(timesteps):
     latent_model_input = torch.cat([latents] * 2)  # CFG 在内部
     noise_pred = transformer(...)
     latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
         self.scheduler, noise_pred, t, latents, noise_level=noise_level)
     all_latents.append(latents)
     all_log_probs.append(log_prob)

Flow-Factory --- SD3_5Adapter.inference():

  • 通过 trajectory_indices 选择性存储中间 latents(只存需要训练的 timestep)

  • CFG 在 forward() 内部统一处理

  • 返回 List[SD3_5Sample] --- 每个样本是独立的 dataclass 对象

  • 支持 extra_call_back_kwargs 按需存储额外信息(如 GRPO-Guard 需要的 next_latents_mean

复制代码
 # Flow-Factory: 选择性存储
 latent_collector = create_trajectory_collector(trajectory_indices, num_inference_steps)
 latent_collector.collect(latents, step_idx=0)
 for i, t in enumerate(timesteps):
     output = self.forward(t=t, t_next=t_next, ...)
     latents = output.next_latents
     latent_collector.collect(latents, i + 1)  # 只存 trajectory_indices 对应的

3.2 前向传播位置

flow_grpo : transformer forward 在 pipeline 内部(sd3_pipeline_with_logprob.py:147),采样和训练各自独立调用 transformer。

Flow-Factory : transformer forward 在 SD3_5Adapter.forward() 中统一实现,采样和训练共用同一段代码。


4. SDE 与 Log-Probability 计算

4.1 实现位置

维度 flow_grpo Flow-Factory
SDE 实现 独立函数 sde_step_with_logprob() Scheduler 类 FlowMatchEulerDiscreteSDEScheduler.step()
文件 flow_grpo/diffusers_patch/sd3_sde_with_logprob.py src/flow_factory/scheduler/flow_match_euler_discrete.py
与 diffusers 关系 猴子补丁(独立函数传入 scheduler) 继承 FlowMatchEulerDiscreteScheduler

4.2 动力学类型

flow_grpo: 仅支持 2 种

  • sde: Flow-GRPO 原始 SDE, std_dev_t = sqrt(sigma/(1-sigma)) * noise_level

  • cps: Coefficients-Preserving Sampling

Flow-Factory : 支持 4 种(通过 dynamics_type 参数切换)

  • Flow-SDE: 与 flow_grpo 的 sde 等效

  • Dance-SDE: Dance-GRPO 变体(新增)

  • CPS: 与 flow_grpo 的 cps 等效

  • ODE: 确定性采样(eval 时 noise_level=0)

4.3 精度处理

flow_grpo:

复制代码
 # 所有变量强制转 float32 避免 bf16 溢出
 model_output = model_output.float()
 sample = sample.float()

Flow-Factory : 同样的 float32 转换,但额外记录了输入 dtype,在采样后将 next_latents round-trip 回原始精度,确保 log_prob 在训练和采样阶段使用相同精度:

复制代码
 _input_dtype = latents.dtype  # 记录输入精度
 noise_pred = noise_pred.float()
 latents = latents.float()
 # ... SDE 计算 ...
 next_latents = next_latents.to(_input_dtype).float()  # round-trip

4.4 窗口 SDE(Fast 模式)

flow_grpo : 单独的文件 sd3_pipeline_with_logprob_fast.py + train_sd3_fast.py

Flow-Factory : 集成在同一个 Scheduler 中,通过 sde_steps + num_sde_steps + seed 参数配置:

复制代码
 # 从 sde_steps 中随机选择 num_sde_steps 个步骤注入噪声
 @property
 def current_sde_steps(self):
     generator = torch.Generator().manual_seed(self.seed)
     selected = torch.randperm(len(self.sde_steps), generator=generator)[:self.num_sde_steps]
     return self.sde_steps[selected]

这种方式避免了维护两套代码。

4.5 Log-Probability 公式

两者数学上等价。flow_grpo 的 CPS log_prob 去掉了常数项(只有分子),Flow-Factory 保留完整公式。Flow-SDE 的 log_prob 完全一致:

复制代码
log_prob = -((next_latents - prev_sample_mean)^2) / (2 * std_dev_t^2 * |dt|)
           - log(std_dev_t * sqrt(|dt|))
           - log(sqrt(2*pi))
# 沿非 batch 维度取 mean
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

5. Advantage 计算

5.1 实现方式

flow_grpo --- PerPromptStatTracker (stat_tracking.py):

  • 维护 per-prompt 的历史 rewards 队列

  • 支持 grpo / rwr / sft / dpo 四种 advantage 类型

  • global_std=True 时用全局 std,False 时用组内 std

  • 需要手动 gather prompt_ids 和 rewards,跨 rank 通信由 train_sd3.py 内联处理

复制代码
# train_sd3.py 内联 gather + advantage 计算
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
prompts = tokenizer.batch_decode(prompt_ids)
advantages = stat_tracker.update(prompts, gathered_rewards['avg'])
# ungather
advantages = advantages.reshape(num_processes, -1)[process_index]

Flow-Factory --- AdvantageProcessor:

  • 独立模块,支持两种聚合策略:

    • sum (GRPO): 加权求和 → 组内归一化(等效 flow_grpo)

    • gdpo (GDPO): 每个奖励独立组内归一化 → 加权求和 → 全局 batch norm

  • 通信感知: 自动区分 distributed_k_repeat(需 gather)和 group_contiguous(无需通信)

  • 返回详细的统计指标(per-reward mean/std,group std distribution,zero_std_ratio)

复制代码
# Flow-Factory: 一行调用
rewards = self.reward_buffer.finalize(store_to_samples=True, split='all')
self.compute_advantages(samples, rewards, store_to_samples=True)
# advantage 存入 sample.extra_kwargs['advantage']

5.2 关键差异

维度 flow_grpo Flow-Factory
历史追踪 维护 per-prompt 历史 stats(跨 epoch) 仅当前 batch 内统计
多奖励聚合 不支持(手动加权后传入) sum/gdpo 两种策略,自动处理
通信策略 硬编码 gather→compute→scatter 自动检测 sampler_type 选择路径
组大小校验 校验 group_size 一致性

flow_grpo 的 PerPromptStatTracker 会跨 epoch 累积历史 rewards,mean 基于历史而非仅当前 batch。Flow-Factory 默认只使用当前 batch。


6. PPO 训练循环与 Loss

6.1 训练循环结构

flow_grpo --- 内联在 main() 中:

复制代码
for inner_epoch in range(num_inner_epochs):
    perm = torch.randperm(total_batch_size)
    samples = {k: v[perm] for k, v in samples.items()}
    for j in train_timesteps:
        with accelerator.accumulate(transformer):
            # 1. 计算当前 log_prob
            prev_sample, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(...)
            # 2. 参考模型 log_prob (KL)
            if beta > 0:
                with transformer.module.disable_adapter():
                    _, _, prev_sample_mean_ref, _ = compute_log_prob(...)
            # 3. PPO loss
            ratio = exp(log_prob - old_log_prob)
            policy_loss = mean(max(-A*ratio, -A*clip(ratio, 1-ε, 1+ε)))
            # 4. KL loss (x-based only)
            kl_loss = ((prev_sample_mean - prev_sample_mean_ref)^2).mean() / (2 * std_dev_t^2)
            loss = policy_loss + beta * kl_loss

Flow-Factory --- 封装在 GRPOTrainer.optimize():

复制代码
for inner_epoch in range(num_inner_epochs):
    shuffled_samples = [samples[i] for i in randperm]
    for batch_idx in range(num_batches):
        batch = BaseSample.stack(batch_samples)
        for timestep_index in train_timesteps:
            with accelerator.accumulate(*trainable_components):
                # 1. 当前模型 forward(复用 Adapter.forward)
                output = self.adapter.forward(t, t_next, latents, next_latents, ...)
                # 2. 参考模型 forward
                if enable_kl_loss:
                    with self.adapter.use_ref_parameters():
                        ref_output = self.adapter.forward(...)
                # 3. PPO loss (相同公式)
                ratio = exp(output.log_prob - old_log_prob)
                policy_loss = mean(max(-A*ratio, -A*clip(ratio, 1-ε, 1+ε)))
                # 4. KL loss (支持 v-based + x-based)
                if kl_type == 'v-based':
                    kl_div = mean((noise_pred - ref_noise_pred)^2)
                elif kl_type == 'x-based':
                    kl_div = mean((next_latents_mean - ref_next_latents_mean)^2)

6.2 PPO Loss 差异

维度 flow_grpo Flow-Factory
ratio 计算 exp(log_prob - old_log_prob)
clip_range 标量 1e-4 支持非对称 [low, high]
adv_clip 标量 adv_clip_max 支持非对称 [adv_low, adv_high]
KL 类型 仅 x-based v-based + x-based
KL 公式 (x-based) mean((mean - ref_mean)^2) / (2 * std^2) mean((mean - ref_mean)^2) (无 1/(2σ²))
梯度累积 手动设置 gradient_accumulation_steps * num_train_timesteps accelerator.accumulate(*trainable_components)
数据组织 dict of tensors (手动索引) BaseSample.stack() → 统一 batch 操作

6.3 训练组件管理

flow_grpo : 直接操作 pipeline.transformer,手动 accelerator.prepare(transformer, optimizer)

Flow-Factory : 通过 target_module_map 管理多个可训练组件(如 Wan 模型有多个 transformer),get_trainable_parameters() / trainable_components 统一接口。


7. KL 散度

flow_grpo: 只支持 x-based KL

复制代码
# 参考模型: disable_adapter (LoRA) 
with transformer.module.disable_adapter():
    _, _, prev_sample_mean_ref, _ = compute_log_prob(...)
kl_loss = ((prev_sample_mean - prev_sample_mean_ref) ** 2).mean(dim=(1,2,3), keepdim=True) / (2 * std_dev_t ** 2)

Flow-Factory: 支持两种 KL 类型

复制代码
# v-based: KL in velocity (noise_pred) space
kl_div = mean((output.noise_pred - ref_output.noise_pred) ** 2)

# x-based: KL in latent space  
kl_div = mean((output.next_latents_mean - ref_output.next_latents_mean) ** 2)

参考模型实现也不同:

  • flow_grpo: 使用 LoRA 的 disable_adapter() 关掉 LoRA 权重

  • Flow-Factory: 通过 use_ref_parameters() 上下文管理器,支持 EMA 包装器(全参数微调)和 disable_adapter()(LoRA)


8. 奖励系统

维度 flow_grpo Flow-Factory
实现 单文件函数 rewards.py:multi_score() 模块化: Pointwise/Groupwise/Globalwise 基类
调用方式 executor.submit(reward_fn, images, prompts, metadata) RewardProcessor + RewardBuffer 管理
异步支持 手动 ThreadPoolExecutor RewardBuffer 内建同步/异步两种模式
多奖励 score_dict = {"pickscore": 1.0, "aesthetic": 0.5} YAML 列表: rewards: [{name, reward_model, weight}]
额外处理 RewardBuffer.finalize(store_to_samples=True) 自动存储到样本

flow_grpo 的奖励通过远程服务器模式运行(避免依赖冲突,继承 ddpo-pytorch):

复制代码
rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
time.sleep(0)  # yield to start computation
rewards, reward_metadata = rewards.result()

Flow-FactoryRewardBuffer 封装了同步/异步逻辑:

复制代码
self.reward_buffer.add_samples(samples)          # 可能异步
rewards = self.reward_buffer.finalize(...)        # 等待完成,收集结果

9. 数据加载与分布式采样

flow_grpo --- DistributedKRepeatSampler:

  • 内联在 train_sd3.py 中定义

  • 手动实现跨 rank 同步的确定性采样

  • 无限循环 while True

  • 只支持 distributed_k_repeat 模式

复制代码
class DistributedKRepeatSampler(Sampler):
    def __iter__(self):
        while True:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(dataset), generator=g)[:self.m]
            repeated = [idx for idx in indices for _ in range(self.k)]
            shuffled = torch.randperm(len(repeated), generator=g)
            # 分片到各 rank
            yield per_card_samples[self.rank]

Flow-Factory --- data_utils/ 模块:

  • 支持 3 种采样器: distributed_k_repeat / group_contiguous / group_distributed

  • group_contiguous: 所有 K 个副本在同一 rank,无需跨 rank 通信

  • 通过 sampler_type: "auto" 自动选择


10. EMA

flow_grpo --- EMAModuleWrapper:

复制代码
ema = EMAModuleWrapper(
    transformer_trainable_parameters, 
    decay=0.9, 
    update_step_interval=8, 
    device=accelerator.device
)
# 手动管理
ema.copy_ema_to(parameters, store_temp=True)  # eval 前
ema.copy_temp_to(parameters)                   # eval 后
ema.step(parameters, global_step)              # 训练后

Flow-Factory --- BaseAdapter.ema_step():

复制代码
self.adapter.ema_step(step=self.epoch)
# eval 时自动通过 use_ema_parameters() 上下文切换
with self.adapter.use_ema_parameters():
    ...

集成在 Adapter 中,支持多种 decay schedule,不需要手动 copy/copy_temp/step。


11. 日志系统

flow_grpo : 直接使用 wandb.log(),手动构造 log dict

复制代码
wandb.log({
    "epoch": epoch,
    "reward_pickscore": value.mean(),
    "approx_kl": ...,
    "clipfrac": ...,
}, step=global_step)

Flow-Factory : 通过 log_data() 统一接口,支持 wandb / swanlab / tensorboard

复制代码
self.log_data({'train/policy_loss': loss, 'train/adv_abs_mean': ...}, step=self.step)

AdvantageProcessor.pop_advantage_metrics() 自动生成详细统计指标。


12. 其他差异

12.1 GRPO-Guard 实现

flow_grpo : 单独的 train_sd3_GRPO_Guard.py 文件,大量代码重复。

Flow-Factory : GRPOGuardTrainer 继承 GRPOTrainer,只重写 sample()(增加 extra_call_back_kwargs)和 optimize()(使用重加权 ratio),复用其他所有逻辑。

12.2 依赖版本

flow_grpo Flow-Factory
diffusers 0.33.1 >=0.36.0
transformers 4.40.0 >=4.57.1
accelerate 1.4.0 >=1.11.0
peft 0.10.0 >=0.17.0

12.3 样本后处理

flow_grpo : 在 compute_log_prob() 中每次训练重算当前 log_prob 时也调 sde_step_with_logprob(),需要传入 prev_sample(即 sample["next_latents"])来计算 log_prob。

Flow-Factory : 采样时已通过 trajectory_indices 选择性存储了中间 latents 和 log_probs,训练时从 batch['all_latents']batch['log_probs'] 中按 index 取出,调用 adapter.forward() 重算当前 log_prob。

12.4 zero advantage 过滤

flow_grpo 有一个特殊处理:过滤掉 advantage 全为零的样本,并保证过滤后的 batch 能被 num_batches_per_epoch 整除:

复制代码
mask = (samples["advantages"].abs().sum(dim=1) != 0)
# 如果不整除,随机补一些样本
if true_count % num_batches != 0:
    random_indices = torch.randperm(len(false_indices))[:num_to_change]
    mask[false_indices[random_indices]] = True
samples = {k: v[mask] for k, v in samples.items()}

Flow-Factory 没有这个逻辑。


13. 总结

Flow-Factory 是 flow_grpo 的全面工程化重构,核心算法等效但架构完全不同:

维度 flow_grpo Flow-Factory
代码组织 单文件内联 多层抽象解耦
模型扩展 每模型复制脚本 实现 Adapter 即可
算法扩展 每算法复制脚本 继承 Trainer 即可
动力学扩展 修改 SDE 函数 Scheduler 多态
配置管理 Python config YAML + dataclass
分布式通信 手动 gather/scatter 通信感知 AdvantageProcessor
内存效率 存储所有 timestep trajectory_indices 选择性存储
多奖励聚合 不支持 sum / gdpo
KL 类型 x-based v-based + x-based
日志后端 wandb wandb / swanlab / tensorboard
代码量 (SD3 GRPO) ~967行 train_sd3.py + patch文件 ~450行 grpo.py + ~350行 adapter + ~440行 scheduler
相关推荐
rundreamsFly1 小时前
Dify 1.14.0 发布:从“单机玩具”到“工业级协作”的硬核进化
人工智能·dify
平行侠1 小时前
40希尔排序 - 以递减间距进行插入排序
java·算法·排序算法
__土块__1 小时前
RAG技术详解与应用实践
人工智能·技术分享·rag·ai技术·检索增强生成
多年小白1 小时前
A股算力租赁板块 深度分析
大数据·人工智能·ai·金融·区块链
林熙蕾LXL1 小时前
进程处理操作
开发语言·c++·算法
IT_陈寒1 小时前
Redis突然吃掉所有内存,我的服务差点挂了
前端·人工智能·后端
代码无bug抓狂人1 小时前
用回溯算法解决01背包
数据结构·算法
Risk Actuary1 小时前
快速傅里叶变换与聚合风险精算模型(二)
人工智能
Shan12051 小时前
二叉树的遍历算法之中序遍历
算法