https://github.com/THUDM/slime/blob/2d4e625d/slime/rollout/sglang_rollout.py
好的,我们来用中文详细解释一下这个Python模块,并用一个具体的例子来说明它的执行过程。
这个文件是大型语言模型(LLM)训练框架(特别是像RLHF/RLAIF,即从人类/AI反馈中进行强化学习的框架,比如项目名SLIME)中的一个核心部分。它负责处理 "rollout"(采样/部署)阶段,也就是使用当前版本的模型来生成新数据的过程。
核心目标:一个高效的"数据生成引擎"
在基于强化学习的LLM微调(如PPO算法)中,训练循环包含几个关键步骤:
- Rollout(数据生成):拿一批提示(prompts),用当前的模型(称为"策略模型")生成回答。
- Reward Modeling(奖励建模):用一个独立的"奖励模型"(Reward Model, RM)来给这些生成的回答打分。分数越高,代表回答质量越好。
- Optimization(优化):使用提示、生成的回答以及它们的奖励分数来计算损失(例如PPO损失),然后更新策略模型的权重。
这个文件主要实现了第1步和第2步。它是一个异步、高性能的引擎,旨在:
- 与
SGLang推理服务进行通信,以实现快速、高效的文本/多模态内容生成。 - 通过并发管理,充分利用GPU资源。
- 同时支持训练数据生成 (
generate_rollout_async) 和模型评估 (eval_rollout)。 - 集成奖励模型打分功能 (
generate_and_rm)。 - 支持部分生成(partial rollouts)、动态过滤和多模态输入等高级特性。
模块结构与函数详解
我们从基础组件到核心功能,逐一拆解。
1. GenerateState (单例类)
这是整个架构中最重要的部分。
- 它是什么? 一个持有整个生成流程全局、持久状态的单例类。"单例"意味着在整个应用程序的生命周期中,这个类只会有一个实例。
- 为什么是单例? 初始化像
tokenizer这样的资源开销很大,我们不希望每次都重新加载。通过单例模式,代码中任何地方调用GenerateState(args)都会获取到同一个、已经加载好所有资源的对象。 - 关键属性:
args: 命令行参数,提供了所有配置。tokenizer: Hugging Face的tokenizer,用于文本编码。semaphore: 一个asyncio.Semaphore(信号量)。这是并发控制 的关键工具。如果sglang_server_concurrency(并发数)设为16,它就确保最多只有16个生成请求同时发送给SGLang服务。这可以防止服务器过载。sampling_params: 一个字典,包含了默认的生成参数,如temperature(温度)、top_p、max_new_tokens(最大新生成token数)等。pendings,aborted,remaining_batch_size: 用于管理当前正在运行的生成任务批次的状态变量。
2. generate (核心生成函数)
这个异步函数是与SGLang生成服务直接交互的接口。
- 职责 :接收一个
Sample对象,并为其生成模型的回答。 - 工作流程 :
- 获取状态 : 访问全局的
GenerateState。 - 处理多模态 : 检查
sample.prompt是普通字符串还是包含文本和图片的字典列表。如果发现有图片,它会调用_load_and_encode_image函数从路径加载图片,将其编码为Base64格式的JPEG字符串,并在文本提示中插入一个特殊的<image>占位符。 - 处理多轮/部分生成 : 如果
sample.response已经有内容(意味着这是对之前生成内容的续写),它会相应地减少max_new_tokens,以确保总长度不超过限制。 - 构建请求体 (Payload) : 为SGLang服务的
/generate接口构建一个JSON请求。这包括了采样参数、input_ids(编码后的提示)以及图片数据(如果有)。 - 发送HTTP请求 : 使用
await post(...)异步地将请求发送到SGLang服务并等待响应。 - 处理响应 : 仔细解析SGLang返回的JSON数据。
- 提取生成的文本、新生成的token ID以及它们的对数概率(
logprobs)。 - 用这些新数据更新
Sample对象,包括tokens列表、拼接后的response字符串、response_length和rollout_log_probs。
- 提取生成的文本、新生成的token ID以及它们的对数概率(
- 设置状态 : 根据SGLang返回的
finish_reason(结束原因,如length表示达到最大长度,stop表示遇到停止符,abort表示被中断),更新sample.status为TRUNCATED,COMPLETED, 或ABORTED。
- 获取状态 : 访问全局的
3. generate_and_rm & generate_and_rm_group
这两个函数包装了 generate,为其增加了奖励模型(RM)打分的步骤。
-
generate_and_rm:- 首先,它调用
await generate(...)来获取模型的回答。它使用state.semaphore来确保并发请求不会超限。 - 生成完成后(且样本未被中止),它调用
await async_rm(args, sample)来为生成的回答获取一个奖励分数。 - 它将这个
reward附加到Sample对象上。
- 这个函数封装了"先生成,后打分"的核心模式。
- 首先,它调用
-
generate_and_rm_group:- 这个函数用于处理一个提示需要生成多个回答的场景 (
n_samples_per_prompt)。 - 它使用
asyncio.gather并发地创建并运行多个generate_and_rm任务。 - 它有针对
args.group_rm的特殊逻辑。如果group_rm为真,意味着奖励模型需要看到一组中所有 的回答才能打分(例如,选出最好一个)。在这种情况下,它会在组内所有生成任务都完成后,调用batched_async_rm进行批量打分。
- 这个函数用于处理一个提示需要生成多个回答的场景 (
4. generate_rollout_async (训练数据生成管理器)
这是生成一批训练数据的主函数。
- 目标 : 准确地生成
args.rollout_batch_size组高质量的样本。 - 工作流程 :
- 初始化 : 获取全局状态,实例化动态过滤器,并设置一个进度条 (
tqdm)。 - 主循环 (
while len(data) < target_data_size): 这个循环会一直运行,直到收集到足够多的有效数据。 - 提交任务 : 在循环内部,它首先从
data_source(数据源) 获取新的提示,并使用state.submit_generate_tasks将它们作为生成任务提交。这会将asyncio.Task对象填充到state.pendings集合中。 - 等待与处理 : 它使用
await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)。这是一个非常高效的模式:它不等待所有任务完成,而是只要有任何一个任务完成,就立即处理其结果,然后继续循环。这使得任务管道始终保持满负荷运行。 - 动态过滤 : 当一组样本生成完成后,它会被传递给一个
dynamic_filter。这个过滤器可以根据某些标准(例如,所有回答都太短、奖励分数太低)决定丢弃这组样本。这可以防止低质量数据进入训练缓冲区。 - 收集数据 : 如果过滤器决定保留 (
keep=True),这组样本就被添加到data列表中。 - 中止与清理 : 一旦收集到
target_data_size的数据,循环就退出。但此时可能还有很多正在进行中的生成请求。await abort(args, rollout_id)函数会被调用,来取消所有这些在SGLang worker上挂起的请求。这节省了大量的计算资源。如果partial_rollout(部分生成)被启用,它甚至会收集这些被中断的、未完成的样本,以便后续使用。 - 返回: 最后,它返回收集到的数据和任何被中止的样本。
- 初始化 : 获取全局状态,实例化动态过滤器,并设置一个进度条 (
5. eval_rollout & eval_rollout_single_dataset
这两个函数用于评估,而非训练。
- 目的: 在一个固定的评估数据集上运行模型,并衡量其性能(主要通过奖励分数)。
- 与训练Rollout的关键区别 :
- 它们使用特定的评估数据集 (
EvalDatasetConfig)。 - 它们可能使用不同的采样参数(例如,
eval_temperature通常更低或为0,以获得更确定的输出)。 - 它们为每个提示生成固定数量的样本 (
n_samples_per_eval_prompt)。 - 最终输出的不是样本本身,而是一个包含指标(如
rewards、truncated状态)的字典,按数据集名称组织。这些指标会被记录到像W&B这样的工具中,用于追踪模型性能随时间的变化。
- 它们使用特定的评估数据集 (
6. generate_rollout & generate_abortable_samples (同步封装)
generate_rollout: 这就是你提问的那个函数。它是一个简单的、同步的入口点。训练框架的其他部分(很可能是同步代码)就调用这个函数。generate_abortable_samples: 这个函数是同步世界和异步世界之间的桥梁。- 它检查
evaluation标志。 - 如果为真,它调用
run(eval_rollout(...))。 - 如果为假,它调用
run(generate_rollout_async(...))。 - 这里的
run函数(很可能是asyncio.run)会启动asyncio事件循环,运行给定的异步函数直到它完成,然后关闭事件循环并返回最终结果。
- 它检查
执行过程举例(以训练为例)
假设我们有以下配置:
rollout_batch_size= 64 (目标是生成64组样本)n_samples_per_prompt= 4 (每个提示生成4个回答)sglang_server_concurrency= 16 (最多16个并发请求)
执行流程如下:
-
外部调用 : 训练框架的主循环调用
generate_rollout(args, rollout_id=1, data_buffer=...)。 -
进入异步世界 :
generate_rollout调用generate_abortable_samples,后者又调用asyncio.run(generate_rollout_async(...))。此时,程序进入了asyncio的事件循环。 -
generate_rollout_async开始工作:- 第一次填充任务 :
while len(data) < 64循环开始。data列表当前为空。 while state.remaining_batch_size < 64循环也开始。- 它调用
data_source从数据缓冲区获取一批提示,比如拿了20个提示。 state.submit_generate_tasks被调用。它为这20个提示创建了20个generate_and_rm_group任务。每个任务都包含了对同一个提示的4个Sample对象的拷贝。- 这20个任务被添加到
state.pendings集合中。state.remaining_batch_size变为20。
- 第一次填充任务 :
-
并发执行与处理:
- 程序执行到
await asyncio.wait(...)。asyncio事件循环开始调度这20个任务。 - 由于并发限制是16,事件循环会立即启动前16个任务。每个任务开始向SGLang发送生成请求。
- 假设任务#3最先完成。它生成了4个回答,并为它们打好了分。
asyncio.wait返回,done集合里包含了完成的任务#3。for task in done:循环开始处理任务#3的结果(一个包含4个Sample对象的列表)。- 这组样本被送入
dynamic_filter。假设过滤器认为这组样本质量合格。 - 这组样本被添加到
data列表中。现在len(data)是1。进度条更新。state.remaining_batch_size减1,变为19。 while len(data) < 64循环继续。state.remaining_batch_size(19) 小于64,所以再次调用data_source获取更多提示,提交新任务,直到remaining_batch_size达到或超过64。- 这个"提交-等待任一完成-处理"的过程不断重复。由于请求和处理是并行的,GPU的利用率会非常高。
- 程序执行到
-
任务完成与中止:
- 当
len(data)终于达到64时,外层while循环结束。 - 此时,
state.pendings中可能还有很多正在运行或排队中的任务(因为我们总是超量提交任务以保持流水线满载)。 await abort(...)被调用。它向所有SGLang worker发送一个abort_requestHTTP请求,强制取消所有挂起的生成任务。abort函数还会等待所有被取消的任务真正结束,并收集它们已经生成的部分结果(如果partial_rollout开启)。
- 当
-
返回结果:
- 收集到的64组完整样本被排序后返回。
generate_abortable_samples将这些样本和被中止的部分样本返回给generate_rollout。generate_rollout将部分样本存入data_buffer以备后用,并将完整的64组样本返回给外部的训练循环。
-
训练: 外部训练循环拿到这批新鲜出炉、带有奖励分数的数据,进行PPO等算法的优化步骤。
这个设计通过异步IO和并发控制,将数据生成和打分这两个耗时的步骤高效地流水线化,是现代LLM训练框架中实现高性能数据吞吐的关键。
当然可以,这是一个非常好的问题,它触及了软件设计中的一个核心权衡。
如果 GenerateState 不 设置成单例(Singleton),会带来以下几个主要问题,这些问题主要关乎性能、资源管理和状态一致性。
1. 性能开销:昂贵的重复初始化
GenerateState 的 __init__ 方法中做了两件比较耗时的事情:
self.tokenizer = AutoTokenizer.from_pretrained(...): 加载一个分词器(tokenizer)需要从磁盘读取多个文件(词汇表、配置文件等),解析它们,并构建相应的对象。虽然这通常比加载整个模型快得多,但对于一个频繁调用的函数来说,重复执行这个操作仍然会带来不可忽视的性能开销。self.semaphore = asyncio.Semaphore(...): 创建信号量本身很快,但关键在于它的作用域。信号量的目的是在整个应用程序范围内限制对某一资源的并发访问。如果每次调用都创建一个新的信号量,它就失去了全局并发控制的能力。
不使用单例的场景下:
假设 generate_rollout 在一个训练循环中被反复调用。每次调用 generate_rollout 都会启动一个新的异步事件循环,并多次间接调用 GenerateState(args)。
python
# 伪代码,模拟非单例的情况
class GenerateState:
def __init__(self, args):
print("正在加载 Tokenizer...") # 这会打印很多次
self.tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint)
self.semaphore = asyncio.Semaphore(args.sglang_server_concurrency)
# ... 其他初始化 ...
# 在训练循环中
for epoch in range(num_epochs):
# 每次 rollout 都会创建一个新的 GenerateState 实例
train_data = generate_rollout(...)
# ... 进行训练 ...
每次调用 generate_rollout,内部的 generate, generate_and_rm 等函数都会创建一个全新的 GenerateState 实例。这意味着:
- 分词器会被反复加载,导致每次 rollout 开始时都有一个明显的延迟。
- 每个
GenerateState实例都有自己的semaphore,无法实现跨函数调用、跨任务的全局并发控制。
2. 资源浪费:内存占用增加
虽然分词器占用的内存远小于模型本身,但它也不是微不足道的。在内存敏感的环境中,如果代码的不同部分都独立地创建和持有一个 GenerateState 实例,就会导致内存中有多个分词器对象的副本,造成不必要的浪费。
3. 状态不一致与逻辑错误
这是最严重的问题。GenerateState 不仅仅是资源的容器,它还管理着动态状态。
-
self.semaphore: 它的核心价值在于全局性。我们希望控制的是"整个程序对SGLang服务的总并发请求数"。如果每个任务都有自己的信号量,那么并发控制就形同虚设。比如,并发限制是16,但如果有100个任务,每个任务都创建自己的信号量,它们就可以同时向服务器发送100个请求,导致服务器过载崩溃。 -
self.pendings: 这个集合记录了所有"正在进行中"的生成任务。这对于abort函数至关重要。abort函数需要知道当前有哪些任务在运行,以便能够正确地等待它们结束(即使是被取消后)。如果每个任务都访问不同的GenerateState实例,abort函数将无法找到并管理其他任务,导致清理工作失败。 -
self.aborted: 这是一个全局"中止"标志。一旦abort函数被调用,它会将state.aborted设为True。所有新创建的或正在等待信号量的任务在开始执行前,都会检查这个标志。如果为True,它们会立即放弃执行,从而快速停止整个 rollout 过程。如果aborted标志不是全局共享的,那么一个地方设置的中止信号无法被其他地方感知,abort的效果将大打折扣。
总结:为什么单例在这里是正确的选择
| 问题 | 非单例的后果 | 使用单例的好处 |
|---|---|---|
| 性能 | 每次需要状态时都重新加载分词器等资源,导致显著的性能开销。 | "一次加载,到处使用"。资源在程序启动时加载一次,之后所有调用都直接复用,性能高。 |
| 资源 | 内存中存在多个不必要的资源副本(如多个分词器对象)。 | 节省内存。整个应用程序只维护一份资源。 |
| 并发控制 | 每个任务有自己的信号量,无法实现全局并发限制,可能导致服务过载。 | 全局并发控制。所有任务共享同一个信号量,确保对SGLang的请求总数不会超过预设值。 |
| 状态一致性 | abort 信号无法全局传播,挂起任务列表 (pendings) 分散且不完整,导致中止和清理逻辑失效。 |
统一的状态管理 。所有任务都操作同一个 pendings 集合和 aborted 标志,使得中止和恢复等复杂逻辑能够正确、可靠地工作。 |
因此,在这种场景下,将 GenerateState 设计成单例模式,不仅仅是一种"优化",更是确保程序正确性 和健壮性的关键设计决策。它保证了所有与生成相关的操作都在一个统一、共享的上下文中进行,避免了混乱和不可预测的行为。
这是一个非常棒的问题,它直击强化学习(RL)训练循环的核心机制。之所以要反复、周期性地调用 generate_rollout,是因为RL训练方法(如PPO)的本质所决定的。
简单来说,这是一种 "边玩边学、迭代优化" 的过程。
让我们把LLM的训练想象成教一个学生写作文。
- 学生(当前模型):你有一个版本的语言模型,就像一个特定水平的学生。
- 作文题目(Prompts) :你有一批作文题目,也就是
prompts。 - 写作文(Rollout) :你让学生根据这些题目写一批作文。这个过程就是
generate_rollout。学生会用他当前的知识和风格来写。 - 老师打分(Reward Model):你(或者一个奖励模型老师)来阅读这些作文,并给每一篇打分。分数高的作文就是好作文。
- 复盘和学习(Optimization):你拿着这些打了分的作文,和学生一起复盘。告诉他:"你看,这几篇得了高分,是因为你用了这个句式、那个论点... 而那几篇分数低,是因为..." 通过这次复盘,学生学到了新的写作技巧,水平得到了提升。这个过程就是模型的优化步骤(梯度更新)。
- 进入下一轮 :现在你有了一个更强 的学生(更新后 的模型)。为了让他继续进步,你不能老是让他用旧的、已经复盘过的作文来学习。你需要给他新的题目 ,或者让他用新的水平 去写旧的题目,然后重复第3、4、5步。
所以,generate_rollout 的反复调用对应了这个过程中的第3步,它是整个迭代学习循环中不可或缺的一环。
从技术角度深入解释
在基于策略梯度(Policy Gradient)的强化学习算法(如PPO)中,这个循环有更精确的术语和原因:
1. On-Policy vs. Off-Policy(在策略 vs. 离策略)
- PPO 是一种典型的 On-Policy(在策略)算法。
- "On-Policy" 的核心思想是:用来学习和更新模型的数据,必须是由当前最新版本的模型(策略)自己生成(采样)的。
- 为什么?因为PPO的优化目标是最大化一个基于当前策略的期望奖励。如果你用一个很老的模型生成的数据来训练一个很新的模型,数据和模型之间就存在"分布不匹配"(distribution mismatch)的问题。这就好比用小学生写的范文去指导一个大学生,效果会很差,甚至可能让大学生的写作水平倒退。
所以,训练流程必须是:
generate_rollout:用当前策略π_k 生成一批数据(经验)。optimize:用这批数据来计算梯度,更新模型,得到一个新的策略π_{k+1}。- 丢弃旧数据:从步骤1生成的旧数据通常会被丢弃(或者在经验回放池中被标记为旧的)。
- 重复 :用新的策略π_{k+1} 再次调用
generate_rollout来生成新的一批数据,然后进行下一次优化。
generate_rollout 的每一次调用,都是为了获取与当前模型能力相匹配的、新鲜的、高质量的训练数据。
2. 探索与利用(Exploration and Exploitation)
- 探索 :模型需要尝试生成各种不同风格、不同内容的回答,去发现哪些类型的回答能够获得高分。
generate_rollout通过设置一定的temperature(温度)或top_p采样,鼓励模型进行探索。 - 利用:当模型发现某种回答模式能稳定获得高分时,它会倾向于更多地生成这种回答。
随着模型不断被优化,它对"好回答"的定义会越来越精确。因此,它需要不断地在新学到的"好"的方向上进行更深入的探索。反复调用 generate_rollout 允许模型在每一次迭代后,都在一个新的、更有潜力的区域进行探索。
3. 避免模型"自娱自乐"
如果只用初始模型生成一次数据,然后反复用这批数据进行训练,模型很快就会在这批固定的数据上"过拟合"。它会学会如何完美地应对这批数据,但在遇到新的、从未见过的问题时表现会很差。
通过周期性地生成新数据,我们可以确保模型总是在面对新的挑战,迫使它学习到更具泛化能力的知识,而不是仅仅记住特定样本的答案。
总结
为什么要每次调用 generate_rollout?
- 遵从On-Policy算法的要求:确保用于训练的数据是由当前最新模型生成的,避免数据分布不匹配导致训练不稳定。
- 实现"探索-利用"的循环:让模型在更新后,能基于新的认知水平去探索更高奖励的回答空间。
- 保证训练数据的"新鲜度":防止模型在固定的旧数据上过拟合,提高模型的泛化能力。
- 迭代提升 :这是强化学习"试错-学习-改进"这一核心哲学的直接体现。每一次
rollout+optimize的循环,都是模型能力的一次小幅但确切的提升。
所以,这个反复调用的过程,正是驱动模型在与奖励信号的互动中不断进化、变得越来越强大的引擎。