1. 整体架构
| 维度 | flow_grpo | Flow-Factory |
|---|---|---|
| 入口 | scripts/train_sd3.py (967行单文件) |
ff-train examples/grpo/lora/sd3_5/default.yaml |
| 架构 | 每个模型独立脚本,代码重复 | 三层抽象: Trainer → Adapter → Scheduler |
| 核心循环 | 全在 train_sd3.py 内联 | BaseTrainer.start() 统一调度,子类重写 sample/prepare_feedback/optimize |
flow_grpo 的 train_sd3.py main() 包含所有逻辑:数据加载、采样、奖励、advantage、训练循环全部写在一个函数里。每个模型有独立的 train_*.py(如 train_flux.py, train_qwenimage.py),存在大量代码重复。
Flow-Factory 将训练流程拆分为:
Trainer (算法: grpo.py) → Adapter (模型: sd3_5.py) → Scheduler (动力学: flow_match_euler_discrete.py)
每层通过 filter_kwargs() 只传递各自需要的参数,模型和算法可以任意组合。
2. 配置系统
| 维度 | flow_grpo | Flow-Factory |
|---|---|---|
| 配置格式 | Python (ml_collections) | YAML + dataclass |
| 配置文件 | config/base.py + config/grpo.py |
examples/grpo/lora/sd3_5/default.yaml |
| 参数传递 | 全局 config 对象,. 访问 |
Arguments dataclass,强类型 |
| 覆盖方式 | 修改 Python 代码 | YAML override / CLI |
flow_grpo 配置示例 (config/grpo.py):
config.sample.num_steps = 10
config.sample.noise_level = 0.7
config.train.clip_range = 1e-4
config.sample.global_std = False
Flow-Factory 配置示例 (default.yaml):
train:
trainer_type: 'grpo'
clip_range: 1.0e-4
global_std: true
scheduler:
dynamics_type: "Flow-SDE"
noise_level: 0.8
num_sde_steps: 2
YAML 方式更易于版本管理和实验追踪,不需要修改代码即可切换配置。
3. 采样 / Rollout
flow_grpo 采样在 train_sd3.py 内联实现,Flow-Factory 通过 Adapter + Scheduler 分离。
3.1 推理 Pipeline
flow_grpo --- sd3_pipeline_with_logprob.py:
-
直接修改 diffusers 的 SD3 pipeline,在其中加入 SDE noise 和 log_prob 计算
-
返回
(images, all_latents, all_log_probs)--- 所有中间 latents 都存在显存里 -
CFG 在 pipeline 内部处理(拼接 negative+positive embedding 再做一次 forward)
# flow_grpo: 全量存储
all_latents = [latents] # 存储所有 timestep 的 latent
all_log_probs = []
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) # CFG 在内部
noise_pred = transformer(...)
latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
self.scheduler, noise_pred, t, latents, noise_level=noise_level)
all_latents.append(latents)
all_log_probs.append(log_prob)
Flow-Factory --- SD3_5Adapter.inference():
-
通过
trajectory_indices选择性存储中间 latents(只存需要训练的 timestep) -
CFG 在
forward()内部统一处理 -
返回
List[SD3_5Sample]--- 每个样本是独立的 dataclass 对象 -
支持
extra_call_back_kwargs按需存储额外信息(如 GRPO-Guard 需要的next_latents_mean)
# Flow-Factory: 选择性存储
latent_collector = create_trajectory_collector(trajectory_indices, num_inference_steps)
latent_collector.collect(latents, step_idx=0)
for i, t in enumerate(timesteps):
output = self.forward(t=t, t_next=t_next, ...)
latents = output.next_latents
latent_collector.collect(latents, i + 1) # 只存 trajectory_indices 对应的
3.2 前向传播位置
flow_grpo : transformer forward 在 pipeline 内部(sd3_pipeline_with_logprob.py:147),采样和训练各自独立调用 transformer。
Flow-Factory : transformer forward 在 SD3_5Adapter.forward() 中统一实现,采样和训练共用同一段代码。
4. SDE 与 Log-Probability 计算
4.1 实现位置
| 维度 | flow_grpo | Flow-Factory |
|---|---|---|
| SDE 实现 | 独立函数 sde_step_with_logprob() |
Scheduler 类 FlowMatchEulerDiscreteSDEScheduler.step() |
| 文件 | flow_grpo/diffusers_patch/sd3_sde_with_logprob.py |
src/flow_factory/scheduler/flow_match_euler_discrete.py |
| 与 diffusers 关系 | 猴子补丁(独立函数传入 scheduler) | 继承 FlowMatchEulerDiscreteScheduler |
4.2 动力学类型
flow_grpo: 仅支持 2 种
-
sde: Flow-GRPO 原始 SDE,std_dev_t = sqrt(sigma/(1-sigma)) * noise_level -
cps: Coefficients-Preserving Sampling
Flow-Factory : 支持 4 种(通过 dynamics_type 参数切换)
-
Flow-SDE: 与 flow_grpo 的sde等效 -
Dance-SDE: Dance-GRPO 变体(新增) -
CPS: 与 flow_grpo 的cps等效 -
ODE: 确定性采样(eval 时 noise_level=0)
4.3 精度处理
flow_grpo:
# 所有变量强制转 float32 避免 bf16 溢出
model_output = model_output.float()
sample = sample.float()
Flow-Factory : 同样的 float32 转换,但额外记录了输入 dtype,在采样后将 next_latents round-trip 回原始精度,确保 log_prob 在训练和采样阶段使用相同精度:
_input_dtype = latents.dtype # 记录输入精度
noise_pred = noise_pred.float()
latents = latents.float()
# ... SDE 计算 ...
next_latents = next_latents.to(_input_dtype).float() # round-trip
4.4 窗口 SDE(Fast 模式)
flow_grpo : 单独的文件 sd3_pipeline_with_logprob_fast.py + train_sd3_fast.py
Flow-Factory : 集成在同一个 Scheduler 中,通过 sde_steps + num_sde_steps + seed 参数配置:
# 从 sde_steps 中随机选择 num_sde_steps 个步骤注入噪声
@property
def current_sde_steps(self):
generator = torch.Generator().manual_seed(self.seed)
selected = torch.randperm(len(self.sde_steps), generator=generator)[:self.num_sde_steps]
return self.sde_steps[selected]
这种方式避免了维护两套代码。
4.5 Log-Probability 公式
两者数学上等价。flow_grpo 的 CPS log_prob 去掉了常数项(只有分子),Flow-Factory 保留完整公式。Flow-SDE 的 log_prob 完全一致:
log_prob = -((next_latents - prev_sample_mean)^2) / (2 * std_dev_t^2 * |dt|)
- log(std_dev_t * sqrt(|dt|))
- log(sqrt(2*pi))
# 沿非 batch 维度取 mean
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
5. Advantage 计算
5.1 实现方式
flow_grpo --- PerPromptStatTracker (stat_tracking.py):
-
维护 per-prompt 的历史 rewards 队列
-
支持 grpo / rwr / sft / dpo 四种 advantage 类型
-
global_std=True时用全局 std,False时用组内 std -
需要手动 gather prompt_ids 和 rewards,跨 rank 通信由 train_sd3.py 内联处理
# train_sd3.py 内联 gather + advantage 计算
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
prompts = tokenizer.batch_decode(prompt_ids)
advantages = stat_tracker.update(prompts, gathered_rewards['avg'])
# ungather
advantages = advantages.reshape(num_processes, -1)[process_index]
Flow-Factory --- AdvantageProcessor:
-
独立模块,支持两种聚合策略:
-
sum(GRPO): 加权求和 → 组内归一化(等效 flow_grpo) -
gdpo(GDPO): 每个奖励独立组内归一化 → 加权求和 → 全局 batch norm
-
-
通信感知: 自动区分
distributed_k_repeat(需 gather)和group_contiguous(无需通信) -
返回详细的统计指标(per-reward mean/std,group std distribution,zero_std_ratio)
# Flow-Factory: 一行调用
rewards = self.reward_buffer.finalize(store_to_samples=True, split='all')
self.compute_advantages(samples, rewards, store_to_samples=True)
# advantage 存入 sample.extra_kwargs['advantage']
5.2 关键差异
| 维度 | flow_grpo | Flow-Factory |
|---|---|---|
| 历史追踪 | 维护 per-prompt 历史 stats(跨 epoch) | 仅当前 batch 内统计 |
| 多奖励聚合 | 不支持(手动加权后传入) | sum/gdpo 两种策略,自动处理 |
| 通信策略 | 硬编码 gather→compute→scatter | 自动检测 sampler_type 选择路径 |
| 组大小校验 | 无 | 校验 group_size 一致性 |
flow_grpo 的
PerPromptStatTracker会跨 epoch 累积历史 rewards,mean 基于历史而非仅当前 batch。Flow-Factory 默认只使用当前 batch。
6. PPO 训练循环与 Loss
6.1 训练循环结构
flow_grpo --- 内联在 main() 中:
for inner_epoch in range(num_inner_epochs):
perm = torch.randperm(total_batch_size)
samples = {k: v[perm] for k, v in samples.items()}
for j in train_timesteps:
with accelerator.accumulate(transformer):
# 1. 计算当前 log_prob
prev_sample, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(...)
# 2. 参考模型 log_prob (KL)
if beta > 0:
with transformer.module.disable_adapter():
_, _, prev_sample_mean_ref, _ = compute_log_prob(...)
# 3. PPO loss
ratio = exp(log_prob - old_log_prob)
policy_loss = mean(max(-A*ratio, -A*clip(ratio, 1-ε, 1+ε)))
# 4. KL loss (x-based only)
kl_loss = ((prev_sample_mean - prev_sample_mean_ref)^2).mean() / (2 * std_dev_t^2)
loss = policy_loss + beta * kl_loss
Flow-Factory --- 封装在 GRPOTrainer.optimize():
for inner_epoch in range(num_inner_epochs):
shuffled_samples = [samples[i] for i in randperm]
for batch_idx in range(num_batches):
batch = BaseSample.stack(batch_samples)
for timestep_index in train_timesteps:
with accelerator.accumulate(*trainable_components):
# 1. 当前模型 forward(复用 Adapter.forward)
output = self.adapter.forward(t, t_next, latents, next_latents, ...)
# 2. 参考模型 forward
if enable_kl_loss:
with self.adapter.use_ref_parameters():
ref_output = self.adapter.forward(...)
# 3. PPO loss (相同公式)
ratio = exp(output.log_prob - old_log_prob)
policy_loss = mean(max(-A*ratio, -A*clip(ratio, 1-ε, 1+ε)))
# 4. KL loss (支持 v-based + x-based)
if kl_type == 'v-based':
kl_div = mean((noise_pred - ref_noise_pred)^2)
elif kl_type == 'x-based':
kl_div = mean((next_latents_mean - ref_next_latents_mean)^2)
6.2 PPO Loss 差异
| 维度 | flow_grpo | Flow-Factory |
|---|---|---|
| ratio 计算 | exp(log_prob - old_log_prob) |
同 |
| clip_range | 标量 1e-4 |
支持非对称 [low, high] |
| adv_clip | 标量 adv_clip_max |
支持非对称 [adv_low, adv_high] |
| KL 类型 | 仅 x-based | v-based + x-based |
| KL 公式 (x-based) | mean((mean - ref_mean)^2) / (2 * std^2) |
mean((mean - ref_mean)^2) (无 1/(2σ²)) |
| 梯度累积 | 手动设置 gradient_accumulation_steps * num_train_timesteps |
accelerator.accumulate(*trainable_components) |
| 数据组织 | dict of tensors (手动索引) | BaseSample.stack() → 统一 batch 操作 |
6.3 训练组件管理
flow_grpo : 直接操作 pipeline.transformer,手动 accelerator.prepare(transformer, optimizer)
Flow-Factory : 通过 target_module_map 管理多个可训练组件(如 Wan 模型有多个 transformer),get_trainable_parameters() / trainable_components 统一接口。
7. KL 散度
flow_grpo: 只支持 x-based KL
# 参考模型: disable_adapter (LoRA)
with transformer.module.disable_adapter():
_, _, prev_sample_mean_ref, _ = compute_log_prob(...)
kl_loss = ((prev_sample_mean - prev_sample_mean_ref) ** 2).mean(dim=(1,2,3), keepdim=True) / (2 * std_dev_t ** 2)
Flow-Factory: 支持两种 KL 类型
# v-based: KL in velocity (noise_pred) space
kl_div = mean((output.noise_pred - ref_output.noise_pred) ** 2)
# x-based: KL in latent space
kl_div = mean((output.next_latents_mean - ref_output.next_latents_mean) ** 2)
参考模型实现也不同:
-
flow_grpo: 使用 LoRA 的
disable_adapter()关掉 LoRA 权重 -
Flow-Factory: 通过
use_ref_parameters()上下文管理器,支持 EMA 包装器(全参数微调)和disable_adapter()(LoRA)
8. 奖励系统
| 维度 | flow_grpo | Flow-Factory |
|---|---|---|
| 实现 | 单文件函数 rewards.py:multi_score() |
模块化: Pointwise/Groupwise/Globalwise 基类 |
| 调用方式 | executor.submit(reward_fn, images, prompts, metadata) |
RewardProcessor + RewardBuffer 管理 |
| 异步支持 | 手动 ThreadPoolExecutor |
RewardBuffer 内建同步/异步两种模式 |
| 多奖励 | score_dict = {"pickscore": 1.0, "aesthetic": 0.5} |
YAML 列表: rewards: [{name, reward_model, weight}] |
| 额外处理 | 无 | RewardBuffer.finalize(store_to_samples=True) 自动存储到样本 |
flow_grpo 的奖励通过远程服务器模式运行(避免依赖冲突,继承 ddpo-pytorch):
rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
time.sleep(0) # yield to start computation
rewards, reward_metadata = rewards.result()
Flow-Factory 的 RewardBuffer 封装了同步/异步逻辑:
self.reward_buffer.add_samples(samples) # 可能异步
rewards = self.reward_buffer.finalize(...) # 等待完成,收集结果
9. 数据加载与分布式采样
flow_grpo --- DistributedKRepeatSampler:
-
内联在 train_sd3.py 中定义
-
手动实现跨 rank 同步的确定性采样
-
无限循环
while True -
只支持
distributed_k_repeat模式
class DistributedKRepeatSampler(Sampler):
def __iter__(self):
while True:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(dataset), generator=g)[:self.m]
repeated = [idx for idx in indices for _ in range(self.k)]
shuffled = torch.randperm(len(repeated), generator=g)
# 分片到各 rank
yield per_card_samples[self.rank]
Flow-Factory --- data_utils/ 模块:
-
支持 3 种采样器:
distributed_k_repeat/group_contiguous/group_distributed -
group_contiguous: 所有 K 个副本在同一 rank,无需跨 rank 通信 -
通过
sampler_type: "auto"自动选择
10. EMA
flow_grpo --- EMAModuleWrapper:
ema = EMAModuleWrapper(
transformer_trainable_parameters,
decay=0.9,
update_step_interval=8,
device=accelerator.device
)
# 手动管理
ema.copy_ema_to(parameters, store_temp=True) # eval 前
ema.copy_temp_to(parameters) # eval 后
ema.step(parameters, global_step) # 训练后
Flow-Factory --- BaseAdapter.ema_step():
self.adapter.ema_step(step=self.epoch)
# eval 时自动通过 use_ema_parameters() 上下文切换
with self.adapter.use_ema_parameters():
...
集成在 Adapter 中,支持多种 decay schedule,不需要手动 copy/copy_temp/step。
11. 日志系统
flow_grpo : 直接使用 wandb.log(),手动构造 log dict
wandb.log({
"epoch": epoch,
"reward_pickscore": value.mean(),
"approx_kl": ...,
"clipfrac": ...,
}, step=global_step)
Flow-Factory : 通过 log_data() 统一接口,支持 wandb / swanlab / tensorboard
self.log_data({'train/policy_loss': loss, 'train/adv_abs_mean': ...}, step=self.step)
AdvantageProcessor.pop_advantage_metrics() 自动生成详细统计指标。
12. 其他差异
12.1 GRPO-Guard 实现
flow_grpo : 单独的 train_sd3_GRPO_Guard.py 文件,大量代码重复。
Flow-Factory : GRPOGuardTrainer 继承 GRPOTrainer,只重写 sample()(增加 extra_call_back_kwargs)和 optimize()(使用重加权 ratio),复用其他所有逻辑。
12.2 依赖版本
| 包 | flow_grpo | Flow-Factory |
|---|---|---|
| diffusers | 0.33.1 | >=0.36.0 |
| transformers | 4.40.0 | >=4.57.1 |
| accelerate | 1.4.0 | >=1.11.0 |
| peft | 0.10.0 | >=0.17.0 |
12.3 样本后处理
flow_grpo : 在 compute_log_prob() 中每次训练重算当前 log_prob 时也调 sde_step_with_logprob(),需要传入 prev_sample(即 sample["next_latents"])来计算 log_prob。
Flow-Factory : 采样时已通过 trajectory_indices 选择性存储了中间 latents 和 log_probs,训练时从 batch['all_latents'] 和 batch['log_probs'] 中按 index 取出,调用 adapter.forward() 重算当前 log_prob。
12.4 zero advantage 过滤
flow_grpo 有一个特殊处理:过滤掉 advantage 全为零的样本,并保证过滤后的 batch 能被 num_batches_per_epoch 整除:
mask = (samples["advantages"].abs().sum(dim=1) != 0)
# 如果不整除,随机补一些样本
if true_count % num_batches != 0:
random_indices = torch.randperm(len(false_indices))[:num_to_change]
mask[false_indices[random_indices]] = True
samples = {k: v[mask] for k, v in samples.items()}
Flow-Factory 没有这个逻辑。
13. 总结
Flow-Factory 是 flow_grpo 的全面工程化重构,核心算法等效但架构完全不同:
| 维度 | flow_grpo | Flow-Factory |
|---|---|---|
| 代码组织 | 单文件内联 | 多层抽象解耦 |
| 模型扩展 | 每模型复制脚本 | 实现 Adapter 即可 |
| 算法扩展 | 每算法复制脚本 | 继承 Trainer 即可 |
| 动力学扩展 | 修改 SDE 函数 | Scheduler 多态 |
| 配置管理 | Python config | YAML + dataclass |
| 分布式通信 | 手动 gather/scatter | 通信感知 AdvantageProcessor |
| 内存效率 | 存储所有 timestep | trajectory_indices 选择性存储 |
| 多奖励聚合 | 不支持 | sum / gdpo |
| KL 类型 | x-based | v-based + x-based |
| 日志后端 | wandb | wandb / swanlab / tensorboard |
| 代码量 (SD3 GRPO) | ~967行 train_sd3.py + patch文件 | ~450行 grpo.py + ~350行 adapter + ~440行 scheduler |