【RL】Slime异步原理(单例设计模式)3

https://github.com/THUDM/slime/blob/2d4e625d/slime/rollout/sglang_rollout.py

好的,我们来用中文详细解释一下这个Python模块,并用一个具体的例子来说明它的执行过程。

这个文件是大型语言模型(LLM)训练框架(特别是像RLHF/RLAIF,即从人类/AI反馈中进行强化学习的框架,比如项目名SLIME)中的一个核心部分。它负责处理 "rollout"(采样/部署)阶段,也就是使用当前版本的模型来生成新数据的过程。

核心目标:一个高效的"数据生成引擎"

在基于强化学习的LLM微调(如PPO算法)中,训练循环包含几个关键步骤:

  1. Rollout(数据生成):拿一批提示(prompts),用当前的模型(称为"策略模型")生成回答。
  2. Reward Modeling(奖励建模):用一个独立的"奖励模型"(Reward Model, RM)来给这些生成的回答打分。分数越高,代表回答质量越好。
  3. Optimization(优化):使用提示、生成的回答以及它们的奖励分数来计算损失(例如PPO损失),然后更新策略模型的权重。

这个文件主要实现了第1步和第2步。它是一个异步、高性能的引擎,旨在:

  • SGLang 推理服务进行通信,以实现快速、高效的文本/多模态内容生成。
  • 通过并发管理,充分利用GPU资源。
  • 同时支持训练数据生成 (generate_rollout_async) 和模型评估 (eval_rollout)。
  • 集成奖励模型打分功能 (generate_and_rm)。
  • 支持部分生成(partial rollouts)、动态过滤和多模态输入等高级特性。

模块结构与函数详解

我们从基础组件到核心功能,逐一拆解。

1. GenerateState (单例类)

这是整个架构中最重要的部分。

  • 它是什么? 一个持有整个生成流程全局、持久状态的单例类。"单例"意味着在整个应用程序的生命周期中,这个类只会有一个实例。
  • 为什么是单例? 初始化像tokenizer这样的资源开销很大,我们不希望每次都重新加载。通过单例模式,代码中任何地方调用 GenerateState(args) 都会获取到同一个、已经加载好所有资源的对象。
  • 关键属性:
    • args: 命令行参数,提供了所有配置。
    • tokenizer: Hugging Face的tokenizer,用于文本编码。
    • semaphore: 一个 asyncio.Semaphore (信号量)。这是并发控制 的关键工具。如果sglang_server_concurrency(并发数)设为16,它就确保最多只有16个生成请求同时发送给SGLang服务。这可以防止服务器过载。
    • sampling_params: 一个字典,包含了默认的生成参数,如temperature(温度)、top_pmax_new_tokens(最大新生成token数)等。
    • pendings, aborted, remaining_batch_size: 用于管理当前正在运行的生成任务批次的状态变量。
2. generate (核心生成函数)

这个异步函数是与SGLang生成服务直接交互的接口。

  • 职责 :接收一个 Sample 对象,并为其生成模型的回答。
  • 工作流程 :
    1. 获取状态 : 访问全局的 GenerateState
    2. 处理多模态 : 检查 sample.prompt 是普通字符串还是包含文本和图片的字典列表。如果发现有图片,它会调用 _load_and_encode_image 函数从路径加载图片,将其编码为Base64格式的JPEG字符串,并在文本提示中插入一个特殊的 <image> 占位符。
    3. 处理多轮/部分生成 : 如果 sample.response 已经有内容(意味着这是对之前生成内容的续写),它会相应地减少 max_new_tokens,以确保总长度不超过限制。
    4. 构建请求体 (Payload) : 为SGLang服务的 /generate 接口构建一个JSON请求。这包括了采样参数、input_ids(编码后的提示)以及图片数据(如果有)。
    5. 发送HTTP请求 : 使用 await post(...) 异步地将请求发送到SGLang服务并等待响应。
    6. 处理响应 : 仔细解析SGLang返回的JSON数据。
      • 提取生成的文本、新生成的token ID以及它们的对数概率(logprobs)。
      • 用这些新数据更新 Sample 对象,包括 tokens 列表、拼接后的 response 字符串、response_lengthrollout_log_probs
    7. 设置状态 : 根据SGLang返回的 finish_reason (结束原因,如length表示达到最大长度, stop表示遇到停止符, abort表示被中断),更新 sample.statusTRUNCATED, COMPLETED, 或 ABORTED
3. generate_and_rm & generate_and_rm_group

这两个函数包装了 generate,为其增加了奖励模型(RM)打分的步骤。

  • generate_and_rm:

    1. 首先,它调用 await generate(...) 来获取模型的回答。它使用 state.semaphore 来确保并发请求不会超限。
    2. 生成完成后(且样本未被中止),它调用 await async_rm(args, sample) 来为生成的回答获取一个奖励分数。
    3. 它将这个 reward 附加到 Sample 对象上。
    • 这个函数封装了"先生成,后打分"的核心模式。
  • generate_and_rm_group:

    1. 这个函数用于处理一个提示需要生成多个回答的场景 (n_samples_per_prompt)。
    2. 它使用 asyncio.gather 并发地创建并运行多个 generate_and_rm 任务。
    3. 它有针对 args.group_rm 的特殊逻辑。如果 group_rm 为真,意味着奖励模型需要看到一组中所有 的回答才能打分(例如,选出最好一个)。在这种情况下,它会在组内所有生成任务都完成后,调用 batched_async_rm 进行批量打分。
4. generate_rollout_async (训练数据生成管理器)

这是生成一批训练数据的主函数。

  • 目标 : 准确地生成 args.rollout_batch_size 组高质量的样本。
  • 工作流程 :
    1. 初始化 : 获取全局状态,实例化动态过滤器,并设置一个进度条 (tqdm)。
    2. 主循环 (while len(data) < target_data_size): 这个循环会一直运行,直到收集到足够多的有效数据。
    3. 提交任务 : 在循环内部,它首先从 data_source (数据源) 获取新的提示,并使用 state.submit_generate_tasks 将它们作为生成任务提交。这会将 asyncio.Task 对象填充到 state.pendings 集合中。
    4. 等待与处理 : 它使用 await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)。这是一个非常高效的模式:它不等待所有任务完成,而是只要有任何一个任务完成,就立即处理其结果,然后继续循环。这使得任务管道始终保持满负荷运行。
    5. 动态过滤 : 当一组样本生成完成后,它会被传递给一个 dynamic_filter。这个过滤器可以根据某些标准(例如,所有回答都太短、奖励分数太低)决定丢弃这组样本。这可以防止低质量数据进入训练缓冲区。
    6. 收集数据 : 如果过滤器决定保留 (keep=True),这组样本就被添加到 data 列表中。
    7. 中止与清理 : 一旦收集到 target_data_size 的数据,循环就退出。但此时可能还有很多正在进行中的生成请求。await abort(args, rollout_id) 函数会被调用,来取消所有这些在SGLang worker上挂起的请求。这节省了大量的计算资源。如果 partial_rollout(部分生成)被启用,它甚至会收集这些被中断的、未完成的样本,以便后续使用。
    8. 返回: 最后,它返回收集到的数据和任何被中止的样本。
5. eval_rollout & eval_rollout_single_dataset

这两个函数用于评估,而非训练。

  • 目的: 在一个固定的评估数据集上运行模型,并衡量其性能(主要通过奖励分数)。
  • 与训练Rollout的关键区别 :
    • 它们使用特定的评估数据集 (EvalDatasetConfig)。
    • 它们可能使用不同的采样参数(例如,eval_temperature 通常更低或为0,以获得更确定的输出)。
    • 它们为每个提示生成固定数量的样本 (n_samples_per_eval_prompt)。
    • 最终输出的不是样本本身,而是一个包含指标(如rewardstruncated状态)的字典,按数据集名称组织。这些指标会被记录到像W&B这样的工具中,用于追踪模型性能随时间的变化。
6. generate_rollout & generate_abortable_samples (同步封装)
  • generate_rollout : 这就是你提问的那个函数。它是一个简单的、同步的入口点。训练框架的其他部分(很可能是同步代码)就调用这个函数。
  • generate_abortable_samples : 这个函数是同步世界和异步世界之间的桥梁。
    • 它检查 evaluation 标志。
    • 如果为真,它调用 run(eval_rollout(...))
    • 如果为假,它调用 run(generate_rollout_async(...))
    • 这里的 run 函数(很可能是 asyncio.run)会启动asyncio事件循环,运行给定的异步函数直到它完成,然后关闭事件循环并返回最终结果。

执行过程举例(以训练为例)

假设我们有以下配置:

  • rollout_batch_size = 64 (目标是生成64组样本)
  • n_samples_per_prompt = 4 (每个提示生成4个回答)
  • sglang_server_concurrency = 16 (最多16个并发请求)

执行流程如下:

  1. 外部调用 : 训练框架的主循环调用 generate_rollout(args, rollout_id=1, data_buffer=...)

  2. 进入异步世界 : generate_rollout 调用 generate_abortable_samples,后者又调用 asyncio.run(generate_rollout_async(...))。此时,程序进入了asyncio的事件循环。

  3. generate_rollout_async 开始工作:

    • 第一次填充任务 : while len(data) < 64 循环开始。data 列表当前为空。
    • while state.remaining_batch_size < 64 循环也开始。
    • 它调用 data_source 从数据缓冲区获取一批提示,比如拿了20个提示。
    • state.submit_generate_tasks 被调用。它为这20个提示创建了20个 generate_and_rm_group 任务。每个任务都包含了对同一个提示的4个 Sample 对象的拷贝。
    • 这20个任务被添加到 state.pendings 集合中。state.remaining_batch_size 变为20。
  4. 并发执行与处理:

    • 程序执行到 await asyncio.wait(...)asyncio事件循环开始调度这20个任务。
    • 由于并发限制是16,事件循环会立即启动前16个任务。每个任务开始向SGLang发送生成请求。
    • 假设任务#3最先完成。它生成了4个回答,并为它们打好了分。
    • asyncio.wait 返回,done 集合里包含了完成的任务#3。
    • for task in done: 循环开始处理任务#3的结果(一个包含4个 Sample 对象的列表)。
    • 这组样本被送入 dynamic_filter。假设过滤器认为这组样本质量合格。
    • 这组样本被添加到 data 列表中。现在 len(data) 是1。进度条更新。state.remaining_batch_size 减1,变为19。
    • while len(data) < 64 循环继续。state.remaining_batch_size (19) 小于64,所以再次调用 data_source 获取更多提示,提交新任务,直到 remaining_batch_size 达到或超过64。
    • 这个"提交-等待任一完成-处理"的过程不断重复。由于请求和处理是并行的,GPU的利用率会非常高。
  5. 任务完成与中止:

    • len(data) 终于达到64时,外层 while 循环结束。
    • 此时,state.pendings 中可能还有很多正在运行或排队中的任务(因为我们总是超量提交任务以保持流水线满载)。
    • await abort(...) 被调用。它向所有SGLang worker发送一个 abort_request HTTP请求,强制取消所有挂起的生成任务。
    • abort 函数还会等待所有被取消的任务真正结束,并收集它们已经生成的部分结果(如果 partial_rollout 开启)。
  6. 返回结果:

    • 收集到的64组完整样本被排序后返回。
    • generate_abortable_samples 将这些样本和被中止的部分样本返回给 generate_rollout
    • generate_rollout 将部分样本存入 data_buffer 以备后用,并将完整的64组样本返回给外部的训练循环。
  7. 训练: 外部训练循环拿到这批新鲜出炉、带有奖励分数的数据,进行PPO等算法的优化步骤。

这个设计通过异步IO和并发控制,将数据生成和打分这两个耗时的步骤高效地流水线化,是现代LLM训练框架中实现高性能数据吞吐的关键。

当然可以,这是一个非常好的问题,它触及了软件设计中的一个核心权衡。

如果 GenerateState 设置成单例(Singleton),会带来以下几个主要问题,这些问题主要关乎性能、资源管理和状态一致性

1. 性能开销:昂贵的重复初始化

GenerateState__init__ 方法中做了两件比较耗时的事情:

  • self.tokenizer = AutoTokenizer.from_pretrained(...): 加载一个分词器(tokenizer)需要从磁盘读取多个文件(词汇表、配置文件等),解析它们,并构建相应的对象。虽然这通常比加载整个模型快得多,但对于一个频繁调用的函数来说,重复执行这个操作仍然会带来不可忽视的性能开销。
  • self.semaphore = asyncio.Semaphore(...) : 创建信号量本身很快,但关键在于它的作用域。信号量的目的是在整个应用程序范围内限制对某一资源的并发访问。如果每次调用都创建一个新的信号量,它就失去了全局并发控制的能力。

不使用单例的场景下:

假设 generate_rollout 在一个训练循环中被反复调用。每次调用 generate_rollout 都会启动一个新的异步事件循环,并多次间接调用 GenerateState(args)

python 复制代码
# 伪代码,模拟非单例的情况
class GenerateState:
    def __init__(self, args):
        print("正在加载 Tokenizer...") # 这会打印很多次
        self.tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint)
        self.semaphore = asyncio.Semaphore(args.sglang_server_concurrency)
        # ... 其他初始化 ...

# 在训练循环中
for epoch in range(num_epochs):
    # 每次 rollout 都会创建一个新的 GenerateState 实例
    train_data = generate_rollout(...) 
    # ... 进行训练 ...

每次调用 generate_rollout,内部的 generate, generate_and_rm 等函数都会创建一个全新的 GenerateState 实例。这意味着:

  • 分词器会被反复加载,导致每次 rollout 开始时都有一个明显的延迟。
  • 每个 GenerateState 实例都有自己的 semaphore,无法实现跨函数调用、跨任务的全局并发控制。

2. 资源浪费:内存占用增加

虽然分词器占用的内存远小于模型本身,但它也不是微不足道的。在内存敏感的环境中,如果代码的不同部分都独立地创建和持有一个 GenerateState 实例,就会导致内存中有多个分词器对象的副本,造成不必要的浪费。

3. 状态不一致与逻辑错误

这是最严重的问题。GenerateState 不仅仅是资源的容器,它还管理着动态状态

  • self.semaphore : 它的核心价值在于全局性。我们希望控制的是"整个程序对SGLang服务的总并发请求数"。如果每个任务都有自己的信号量,那么并发控制就形同虚设。比如,并发限制是16,但如果有100个任务,每个任务都创建自己的信号量,它们就可以同时向服务器发送100个请求,导致服务器过载崩溃。

  • self.pendings : 这个集合记录了所有"正在进行中"的生成任务。这对于 abort 函数至关重要。abort 函数需要知道当前有哪些任务在运行,以便能够正确地等待它们结束(即使是被取消后)。如果每个任务都访问不同的 GenerateState 实例,abort 函数将无法找到并管理其他任务,导致清理工作失败。

  • self.aborted : 这是一个全局"中止"标志。一旦 abort 函数被调用,它会将 state.aborted 设为 True。所有新创建的或正在等待信号量的任务在开始执行前,都会检查这个标志。如果为 True,它们会立即放弃执行,从而快速停止整个 rollout 过程。如果 aborted 标志不是全局共享的,那么一个地方设置的中止信号无法被其他地方感知,abort 的效果将大打折扣。

总结:为什么单例在这里是正确的选择

问题 非单例的后果 使用单例的好处
性能 每次需要状态时都重新加载分词器等资源,导致显著的性能开销。 "一次加载,到处使用"。资源在程序启动时加载一次,之后所有调用都直接复用,性能高。
资源 内存中存在多个不必要的资源副本(如多个分词器对象)。 节省内存。整个应用程序只维护一份资源。
并发控制 每个任务有自己的信号量,无法实现全局并发限制,可能导致服务过载。 全局并发控制。所有任务共享同一个信号量,确保对SGLang的请求总数不会超过预设值。
状态一致性 abort 信号无法全局传播,挂起任务列表 (pendings) 分散且不完整,导致中止和清理逻辑失效。 统一的状态管理 。所有任务都操作同一个 pendings 集合和 aborted 标志,使得中止和恢复等复杂逻辑能够正确、可靠地工作。

因此,在这种场景下,将 GenerateState 设计成单例模式,不仅仅是一种"优化",更是确保程序正确性健壮性的关键设计决策。它保证了所有与生成相关的操作都在一个统一、共享的上下文中进行,避免了混乱和不可预测的行为。

这是一个非常棒的问题,它直击强化学习(RL)训练循环的核心机制。之所以要反复、周期性地调用 generate_rollout,是因为RL训练方法(如PPO)的本质所决定的。

简单来说,这是一种 "边玩边学、迭代优化" 的过程。

让我们把LLM的训练想象成教一个学生写作文。

  1. 学生(当前模型):你有一个版本的语言模型,就像一个特定水平的学生。
  2. 作文题目(Prompts) :你有一批作文题目,也就是 prompts
  3. 写作文(Rollout) :你让学生根据这些题目写一批作文。这个过程就是 generate_rollout。学生会用他当前的知识和风格来写。
  4. 老师打分(Reward Model):你(或者一个奖励模型老师)来阅读这些作文,并给每一篇打分。分数高的作文就是好作文。
  5. 复盘和学习(Optimization):你拿着这些打了分的作文,和学生一起复盘。告诉他:"你看,这几篇得了高分,是因为你用了这个句式、那个论点... 而那几篇分数低,是因为..." 通过这次复盘,学生学到了新的写作技巧,水平得到了提升。这个过程就是模型的优化步骤(梯度更新)。
  6. 进入下一轮 :现在你有了一个更强 的学生(更新后 的模型)。为了让他继续进步,你不能老是让他用旧的、已经复盘过的作文来学习。你需要给他新的题目 ,或者让他用新的水平 去写旧的题目,然后重复第3、4、5步。

所以,generate_rollout 的反复调用对应了这个过程中的第3步,它是整个迭代学习循环中不可或缺的一环。


从技术角度深入解释

在基于策略梯度(Policy Gradient)的强化学习算法(如PPO)中,这个循环有更精确的术语和原因:

1. On-Policy vs. Off-Policy(在策略 vs. 离策略)
  • PPO 是一种典型的 On-Policy(在策略)算法
  • "On-Policy" 的核心思想是:用来学习和更新模型的数据,必须是由当前最新版本的模型(策略)自己生成(采样)的。
  • 为什么?因为PPO的优化目标是最大化一个基于当前策略的期望奖励。如果你用一个很老的模型生成的数据来训练一个很新的模型,数据和模型之间就存在"分布不匹配"(distribution mismatch)的问题。这就好比用小学生写的范文去指导一个大学生,效果会很差,甚至可能让大学生的写作水平倒退。

所以,训练流程必须是:

  1. generate_rollout :用当前策略π_k 生成一批数据(经验)。
  2. optimize :用这批数据来计算梯度,更新模型,得到一个新的策略π_{k+1}
  3. 丢弃旧数据:从步骤1生成的旧数据通常会被丢弃(或者在经验回放池中被标记为旧的)。
  4. 重复 :用新的策略π_{k+1} 再次调用 generate_rollout 来生成新的一批数据,然后进行下一次优化。

generate_rollout 的每一次调用,都是为了获取与当前模型能力相匹配的、新鲜的、高质量的训练数据。

2. 探索与利用(Exploration and Exploitation)
  • 探索 :模型需要尝试生成各种不同风格、不同内容的回答,去发现哪些类型的回答能够获得高分。generate_rollout 通过设置一定的 temperature(温度)或 top_p 采样,鼓励模型进行探索。
  • 利用:当模型发现某种回答模式能稳定获得高分时,它会倾向于更多地生成这种回答。

随着模型不断被优化,它对"好回答"的定义会越来越精确。因此,它需要不断地在新学到的"好"的方向上进行更深入的探索。反复调用 generate_rollout 允许模型在每一次迭代后,都在一个新的、更有潜力的区域进行探索。

3. 避免模型"自娱自乐"

如果只用初始模型生成一次数据,然后反复用这批数据进行训练,模型很快就会在这批固定的数据上"过拟合"。它会学会如何完美地应对这批数据,但在遇到新的、从未见过的问题时表现会很差。

通过周期性地生成新数据,我们可以确保模型总是在面对新的挑战,迫使它学习到更具泛化能力的知识,而不是仅仅记住特定样本的答案。

总结

为什么要每次调用 generate_rollout

  • 遵从On-Policy算法的要求:确保用于训练的数据是由当前最新模型生成的,避免数据分布不匹配导致训练不稳定。
  • 实现"探索-利用"的循环:让模型在更新后,能基于新的认知水平去探索更高奖励的回答空间。
  • 保证训练数据的"新鲜度":防止模型在固定的旧数据上过拟合,提高模型的泛化能力。
  • 迭代提升 :这是强化学习"试错-学习-改进"这一核心哲学的直接体现。每一次 rollout + optimize 的循环,都是模型能力的一次小幅但确切的提升。

所以,这个反复调用的过程,正是驱动模型在与奖励信号的互动中不断进化、变得越来越强大的引擎。

相关推荐
老鼠只爱大米1 小时前
Java设计模式之装饰器模式详解
java·设计模式·装饰器模式·decorator·java设计模式
9***Y481 小时前
Web3预言机设计模式
设计模式·web3
wudl55661 小时前
Agent 设计模式全面分析
设计模式
__万波__1 小时前
二十三种设计模式(四)--原型模式
java·设计模式·原型模式
4***g8941 小时前
Java进阶-SpringCloud设计模式-工厂模式的设计与详解
java·spring cloud·设计模式
__万波__1 小时前
二十三种设计模式(五)--建造者模式
java·设计模式·建造者模式
北郭guo1 小时前
Java设计模式 【理论+代码实现】 让你从小白到大佬的蜕变
java·开发语言·设计模式
执笔论英雄8 小时前
Slime异步原理(单例设计模式)4
开发语言·python·设计模式
执笔论英雄10 小时前
Slime异步原理(单例设计模式)5
设计模式