DanceGRPO+FLUX:多模态生成强化学习模型的高效

一、背景介绍

Flux 模型小模型高效生成高质量图像的基础

Flux 虽是百亿级参数的大模型家族,但其中的轻量化变体(如 Flux.1 (schnell))以及核心技术,为小尺寸模型提供了高效生成的范式。其关键技术优势适配小模型的优化需求,具体体现在两点。一是采用 Rectified Flow(校正流)技术,拉直了传统扩散模型从噪声到图像的生成路径,将生成过程优化为近似直线的最短路径,大幅减少采样步数。像 Flux.1 (schnell) 仅需 4 步左右采样就能生成合理图像,这对小模型而言,意味着在降低计算成本的同时,还能避免多步迭代带来的精度损耗。二是创新的多模态融合架构,通过双文本编码器(CLIP+T5)精准解析文本语义,再结合双流转单流的 Transformer 注意力机制,实现文本与图像特征的深度交互。这种设计让小模型无需复杂结构,就能高效捕捉图文关联,提升生成图像的内容一致性。

DanceGRPO 框架:通过强化学习进一步提升小模型性能

DanceGRPO 是专门针对视觉生成领域 RLHF 方案不成熟的问题设计,能精准解决小模型训练中质量提升的核心痛点,具体优势有三。其一,兼容性强,适配 Flux 的核心范式。该框架创新性地将扩散模型和校正流模型(如 Flux)统一视为随机插值的特殊情况,二者的采样过程均可通过 SDE 实现,这让它能无缝对接 Flux 模型,针对性地开展强化学习优化,无需对 Flux 的基础架构做大幅修改,降低了小模型适配强化学习的成本。其二,显存压力低,适配小模型训练资源限制。此前 ReFL 等强化学习方案需对奖励模型和 VAE 解码特征反向传播,在视频生成等场景中显存压力极大,根本不适合小模型。而 DanceGRPO 通过采样部分时间步加速训练、去除作用不大的 KL 散度正则项等设计,大幅降低了计算和显存开销,同时还能让小模型在更多提示词样本上学习,提升泛化能力。其三,强化学习效果显著,精准优化核心指标。该框架通过多奖励模型叠加(图像美感、图文匹配等五类指标),让小模型能针对性提升薄弱项;同时通过固定初始化噪声、控制梯度更新频率等优化手段,避免训练中的奖励作弊和多样性下降问题。

强化学习框架对比

二、环境依赖

三、DanceGRPO+FLUX 整体流程

推理阶段(去噪生成图片,用于训练过程观察)

  1. 加载文本信息: 获取初始数据,将数据复制成 N 份作为输入。
  2. 去噪: 生成初始噪声,input 和当前噪音输入到 policy mode 预测噪声成分,去噪生成 latents。
  3. 图片生成保存: 基于推理阶段输出 latents,经过 vae mode 解码成 image,保存为文件用于观察过程。

关键点

  • 推理去噪生成图像:该模型中,默认一组生成 12 哥样本,即一个 prompt 会生成 12 哥大体相似而细节不同的图像,每个图像默认经过 16 步迭代去噪生成。
  • 去噪步长:步长随时间步长从大到小,因为初始噪声成分较多,相当于勾勒轮廓去噪步长可以大些,后面要收敛到正确终点,相当于描绘细节,需要慢慢去噪。
  • 图像多样性:去噪过程会加入随机扰动,局部优化,因此会有一组默认 12 张图片,每张整体相似而细节有差异的图片;一组内会进行对比,提升优势动作的概率。

Reward 阶段(jisaunq reward 值)

  1. 计算奖励值: image 和 prompt 输入到 reward model,计算得到 reward 值。
  2. 计算相对优势值: 计算 reward 的组内平均值,每个 reward 和平均值比较,得到 advantage(组内相对优势)。

reward 详细流程

  • 计算 reward 值:基于 prompt,图例阶段得到的完全去噪的 image 值,输入到 reward mode 中,经过一系列计算得到每个 image 的 reward 值。
  • 计算 advantage 值:reward 值经过组内平均得到平均值,再用每个 iamge 的 reward 值和平均值对比,得到 advantage (相对优势值)。

训练阶段(计算 loss,更新梯度)

  1. 记录去噪过程: 前面步骤会记录每个样本的去噪过程状态,包括 reward 值,advantag 值,log_p 值(代表当时策略的对数)。
  2. 计算新策略对数: 此时 policy model 会生成新预测值,根据新预测值计算出 new_log_p 值(代表新策略的对数)。
  3. 计算旧策略比率: f(new_log, old_log) = ratio,代表某行为在新旧策略的概率比。
  4. 计算 loss 值: 基于 ratio 和 advantage 计算出 loss 值。

训练详细流程

  • loss 是基于 advantage 和 ratio 计算得出的,当 advantage 和 ratio 处于不同值时代表不同的含义
ADVANTAGE RATIO 含义
>0 >1 该动作为优势动作,且新策略该动作概率更大,新策略正确的提升了该动作的概率,新策略更优
>0 <1 该动作为优势动作,且新策略该动作概率更小,新策略错误的抑制了优势动作,后续需要提高 ratio
  • 第一次计算 loss 时,policy mode 还没有更新权重,此时 new_log_p 和 old_log_p 实际上是一样的,就是虽然定义上是新旧策略,但实际上新旧策略的权重一样。
  • loss 值会基于 advantage 和 ratio 一并计算,所以开始的 loss 值依赖于样本的 advantage 值,默认的梯度更新频率为 4 哥样品一次,当处理本组第四个样品后 ratio 就会开始变化了。

存在两个 loss 值,clipped_loss 和 unclipped_loss,都是基于 advantage 和 ratio 计算得到的,但是 clipped_loss 的计算中加入了 clip_range,约束了最终计算值的范围,防止局部过度优化。

四、模型部署流程

  1. 拉取代码:GitHub - XueZeyue/DanceGRPO: An official implementation of DanceGRPO: Unleashing GRPO on Visual Generation
bash 复制代码
git clone https://github.com/XueZeyue/DanceGRPO.git
  1. 下载权重

FLUX:huggingface.co/black-fores...

HPS:huggingface.co/xswu/HPSv2/...

open_clip:huggingface.co/laion/CLIP-...

  1. 其它依赖安装
  1. 仓库未实现懒加载,所以会导入许多用不到的三方库,可以直接注释,避免引入太多无用的依赖,耗费开发时间。
  2. 一些为调用的接口也可以进行规避,例如 flashatth 三方库接口等。
python 复制代码
# DanceGRPO/fastvideo/models/mochi_hf/modeling_mochi.py
# 注释掉以下
from liger_kernel.ops.swiglu import LigerSiLUMulFunction;flash_attn_no_pad.py
​
# flash_attn_no_pad.py
# 注释掉flash_attn的导包,flash_attn_no_pad注释掉中间逻辑,直接return;

执行安装脚本:

bash 复制代码
./env_setup.sh fastvideo
  1. 修改<font style="color:rgb(37, 43, 58);background-color:rgb(246, 247, 249);">preprocess_flux_embedding.py</font>
python 复制代码
# # 引入torch_npu
import torch_npu
from torch_npu.contrib import transfer_to_npu
​
# "./data/flux"写死的路径改成参数
# 原 : pipe = FluxPipeline.from_pretrained("./data/flux", torch_dtype=torch.bfloat16).to(device)
pipe = FluxPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16).to(device)
  1. 修改<font style="color:rgb(37, 43, 58);background-color:rgb(246, 247, 249);">train_grpo_flux.py</font>
python 复制代码
# # 引入torch_npu
import torch_npu
from torch_npu.contrib import transfer_to_npu
  1. 执行 Flux GRPO 脚本:
bash 复制代码
bash ./scripts/finetune/finetune_flux_grpo.sh

五、模型验证

验证流程将 GRPO 的推理、reward、训练三个阶段单独抽离对齐,再进行全流程验证,采用"分 - 合" 验证策略:

  • 单独阶段对齐能隔离不同模型和框架的差异,聚焦每个环节的前向计算准确性(比如推理阶段的动作生成、reward 阶段的评分计算、训练阶段的梯度更新),避免因单个阶段误差累积掩盖问题。
  • 全流程对齐则能验证阶段间数据传递的一致性,尤其要关注跨框架交互时的数据格式、精度损失等细节。

记录关键节点的对齐数据(如中间特征、概率分布、loss 值、梯度等),既能作为阶段验证的基准,也能在全流程中快速定位误差来源。

随机性固定

load 版本(准确但麻烦)

通过torch.savetorch.load的方式将程序中涉及随机性的变量,在 NPU 和 GPU 上保持一致。

  1. 关闭shuffle,固定训练的数据顺序
ini 复制代码
# fastvideo/train_grpo_flux.py中,shuffle设为false
​
sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=args.sampler_seed)
  1. prev_sample 固定
markdown 复制代码
0.  GPU代码修改如下,在GPU上运行后保存下来
ini 复制代码
#  1. 添加全局变量COFF_STEP,控制coff生成的step数  
COFF_STEP = 0  
​
def flux_step():  
    global COFF_STEP  
    ......  
    if grpo and prev_sample is None:  
    coff = torch.randn_like(prev_sample_mean)  
    torch.save(coff, f"saves/coff_{COFF_STEP}_{torch.distributed.get_rank()}.pt")  
    prev_sample = prev_sample_mean + coff * std_dev_t  
    COFF_STEP += 1
markdown 复制代码
2. NPU 上加载
python 复制代码
coff = torch.load(f"saves/coff_{COFF_STEP}_{torch.distributed.get_rank()}.pt", map_location=f"cuda:{torch.cuda.current_device()}")
  1. input_latents 固定
markdown 复制代码
0.  GPU代码修改如下,在GPU上运行后保存下来
python 复制代码
def sample_reference_model(  
    args,  
    device,  
    transformer,  
    vae,  
    encoder_hidden_states,  
    pooled_prompt_embeds,  
    text_ids,  
    reward_model,  
    tokenizer,  
    caption,  
    preprocess_val,  
    step,  # # # 增加参数输入,用于序列文件记录,找到相关调用处,加上该入参  
)  

def train_one_step(  
    args,  
    device,  
    transformer,  
    vae,  
    reward_model,  
    tokenizer,  
    optimizer,  
    lr_scheduler,  
    loader,  
    noise_scheduler,  
    max_grad_norm,  
    preprocess_val,  
    step,  # # # 增加参数输入,用于序列文件记录,找到相关调用处,加上该入参  
) 

def sample_reference_model();  
     ......  
    if args.init_same_noise:  
        input_latents = torch.randn(  
                (1, IN_CHANNELS, latent_h, latent_w),  # (c,t,h,w)  
                device=device,  
                dtype=torch.bfloat16,  
            )  
        torch.save(input_latents, f"saves/input_latents_{step}_{torch.distributed.get_rank()}.pt")
markdown 复制代码
2. NPU上加载
python 复制代码
input_latents = torch.load(f"saves/input_latents_{step}_{torch.distributed.get_rank()}.pt", map_location=f'cuda:{device}')
  1. perms 固定
markdown 复制代码
0.  GPU代码修改如下,在GPU上运行后保存下来
scss 复制代码
def train_one_step():  
    ......  
    perms = torch.stack(  
            [  
                torch.randperm(len(samples["timesteps"][0]))  
                for _ in range(batch_size)  
            ]  
        ).to(device)  
    torch.save(perms, f"saves/perms_{step}_{torch.distributed.get_rank()}.pt")
css 复制代码
2. <font style="color:rgb(37, 43, 58);">NPU上加载</font>
python 复制代码
perms = torch.load(f"saves/perms_{step}_{torch.distributed.get_rank()}.pt", map_location=f'{device}')

使用 CPU 进行随机性固定

固定seed可用于模型训练复现,但是不同的设备如GPU和NPU在同样的seed下生成的值也是不一样的,但是不同设备上都有CPU,因此可以固定seed后使用CPU生成张量,以此让GPU和NPU上生成的张量输入保持相同

  1. fastvideo/train_grpo_flux.py:91修改为
css 复制代码
if grpo and prev_sample is None:  
	prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean.cpu()).to(  
		prev_sample_mean.device) * std_dev_t
  1. <font style="color:rgb(59, 62, 85);">fastvideo/train_grpo_flux.py:270</font>修改为
ini 复制代码
if args.init_same_noise:  
	input_latents = torch.randn(  
		(1, IN_CHANNELS, latent_h, latent_w),  #  (c,t,h,w)  
		dtype=torch.bfloat16,  
	).to(device)
  1. fastvideo/train_grpo_flux.py:657修改为
ini 复制代码
sampler = DistributedSampler(  
		train_dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=args.sampler_seed  
	)
  1. fastvideo/train_grpo_flux.py:1061增加
python 复制代码
import random  
def seed_all_own(seed=1234, mode=True, is_gpu=True):  
	random.seed(seed)  
	os.environ['PYTHONHASHSEED'] = str(seed)  
	os.environ['GLOBAL_SEED'] = str(seed)  
	np.random.seed(seed)  
	torch.manual_seed(seed)  
	torch.use_deterministic_algorithms(mode)  
	if is_gpu:  
		os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  
		os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  
		torch.cuda.manual_seed_all(seed)  
		torch.cuda.manual_seed(seed)  
		torch.backends.cudnn.deterministic = True  
		torch.backends.cudnn.enable = False  
		torch.backends.cudnn.benchmark = False  
	else:  
		import torch_npu  
		os.environ['HCCL_DETERMINISTIC'] = 'true'  
		os.environ['CLOSE_MATMUL_K_SHIFT'] = '1'  
		torch_npu.npu.manual_seed_all(seed)  
		torch_npu.npu.manual_seed(seed)  
	print("====== seed all ========")  
seed_all_own(is_gpu=False)  
from msprobe.pytorch import seed_all  
seed_all(mode=True)

推理流程对齐

推理流程对齐的内容主要是 GRPO 去噪后生成的 latents,latents 解码成图片后对比:固定随机性,将GPU、NPU上使用相同noise的latents使用vae解码,再保存,此时只需要对比生成图片的差异。 关键代码:decoded_image[0].save(img_path),这里会保存训练过程中,模型在每个step,每次generation中生成的图片,可以直观的看到训练过程中的的变化。

ini 复制代码
# # # # sample_reference_model函数  
def sample_reference_model():  
    with torch.inference_mode():  
        with torch.autocast("cuda", dtype=torch.bfloat16):  
            latents = unpack_latents(latents, h, w, 8)  
                latents = (latents / 0.3611) + 0.1159  
                image = vae.decode(latents, return_dict=False)[0]  
                decoded_image = image_processor.postprocess(  
                image)  
        decoded_image[0].save(f"./images/flux_{step}_{rank}_{index}.png")

Reward Model 对齐

DanceGRPO 模型涉及多个 model,强化学习中需要对齐的主要是loss和reward值,这里讲的是如何对齐reward。

此处采取的方法是把reward model单独拿出来,for循环多步,对比GPU和NPU的值reward值,代码修改如下:

css 复制代码
for step in range(1, 1001):
             #  text = tokenizer([batch_caption[0]]).to(device=device, non_blocking=True)
         image = torch.load(f"/home/grpo/DanceGRPO/save/images-1/image_{step}_{rank}.pt")
         text = torch.load(f"/home/grpo/DanceGRPO/save/texts-1/text_{step}_{rank}.pt")
         
         #  torch.save(image, f"/home/GRPO/DanceGRPO/save/images-1/image_{step}_{rank}.pt")
         #  torch.save(text, f"/home/GRPO/DanceGRPO/save/texts-1/text_{step}_{rank}.pt")
         if rank == 0:
             print(f"image_{rank}_{step}: ", image, "\n\n")
             print(f"text_{rank}_{step}: ", text, "\n\n")
         with torch.no_grad():
             with torch.amp.autocast("cuda"):
                 outputs = reward_model(image, text)
         if rank == 0:
             print(f"output_{rank}_{step}: ", outputs, "\n\n")
         image_features, text_features = outputs["image_features"], outputs["text_features"]
         logits_per_image = image_features @ text_features.T
         hps_score = torch.diagonal(logits_per_image)
         all_rewards = []
         all_rewards.append(hps_score)
         all_rewards = torch.cat(all_rewards, dim=0)
         samples = {
             "rewards": all_rewards.to(torch.float32)
         }
         if rank == 0:
             print(f"samples_{rank}_{step}: ", samples, "\n\n")
         gathered_reward = gather_tensor(samples["rewards"])
         if rank == 0:
             print(f"gather_reward_{rank}_{step}: ", gathered_reward, "\n\n")
         if dist.get_rank() == 0:
             print("gathered_hps_reward", gathered_reward)
             with open('./hps_reward.txt', 'a') as f:
                 f.write(f"{gathered_reward.mean().item()}\n")
         samples_batched = {
             k: v.unsqueeze(1)
             for k, v in samples.items()
         }
         samples_batched_list = [             dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())         ]
         for i, sample in list(enumerate(samples_batched_list)):
             if rank == 0:
                 print(f"sample_{rank}_{step}: ", sample["rewards"], "\n\n")
             if dist.get_rank() % 8 == 0:
                 print("hps reward", sample["rewards"].item(), "\n\n\n\n\n")
             #  print("ratio", ratio)
             #  print("advantage", sample["advantages"].item())
             #  print("final loss", loss.item())

生成1000个reward值,其精度对比效果如下(绝对误差≈0.015%):

数据、图片来自昇腾官方数据。

端到端对齐

对齐标准

固定随机性后,需要按照如下标准关注对齐结果:

  • 关注推理阶段生成的图片,主观对齐
  • 关注训练过程中的loss(生成模型loss较小,参考价值有限)
  • 关注reward scores,200步误差5%以内

对齐步骤

端到端对齐流程主要关注两方面,一方面是综合度量模型训练的指标:推理阶段图片+loss+rward scores,另一方面是下游任务推理效果。

全流程对齐具体步骤:

  • 两边加载相同的预训练权重。
  • 固定随机性:整体随机性与确定性计算固定(seed_all,mode=True),noise在cpu侧生成。
  • 保存关键信息:推理阶段的图片、reward阶段的rewardvalues、训练阶段模型loss,同时保存权重,用于对齐推理效果,此处注意需要持续关注推理阶段生成图片的效果,具体例子为在替换rope融合算子时,loss结果与reward差异不大,但推理阶段出现了花图。

端到端流程结构

六、常见问题

如遇到ROPE部分不支持complex128计算问题,NPU场景需要适配修改___CODE_BLOCK_PLACEHOLDER___211250

ini 复制代码
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"  #增加改行

##下面增加is_npu判断
freqs_dtype = torch.float32 if is_mps or is_npu else torch.float64

七、总结

DanceGRPO+FLUX 模型在 AI 生图领域,解决 FLUX 在生成过程中与人类审美、语义对齐等方面的适配问题,大幅提升其生图质量与稳定性。展望未来,多模态生成强化学习模型有望在更多领域开花结果,如影视特效制作中实现更逼真的虚拟场景与角色创建,教育领域中打造沉浸式的学习环境,医疗领域辅助医生进行手术模拟与病情可视化分析等 。同时,随着技术发展,模型将不断优化,生成效率与质量进一步提升,在处理复杂任务、理解模糊指令等方面取得更大突破,为各行业数字化转型与创新发展注入强大动力 。

注明:昇腾PAE案例库对本文写作亦有帮助。

相关推荐
明洞日记2 小时前
【设计模式手册022】抽象工厂模式 - 创建产品家族
java·设计模式·抽象工厂模式
阿拉斯攀登3 小时前
设计模式:命令模式(Spring MVC中的实践)
设计模式·springmvc·命令模式
明洞日记3 小时前
【设计模式手册021】代理模式 - 如何控制对象访问
设计模式·系统安全·代理模式
山沐与山4 小时前
【设计模式】Python策略模式:从入门到实战
python·设计模式·策略模式
阿拉斯攀登4 小时前
设计模式:责任链模式(mybatis数据权限实现)
设计模式·mybatis·责任链模式
syt_10134 小时前
设计模式之-模板模式
设计模式
阿拉斯攀登4 小时前
设计模式:责任链模式(MyBatis)
设计模式·mybatis·责任链模式
崎岖Qiu5 小时前
【设计模式笔记19】:建造者模式
java·笔记·设计模式·建造者模式
syt_10135 小时前
设计模式之-享元模式
javascript·设计模式·享元模式