好的,我们来用中文详细解释这段非常精巧的 Python 代码。
总体目标
这段代码的核心目标是创建一个可插拔、可扩展的聊天格式化系统 。不同的语言模型(比如 Llama, Qwen, Baichuan)要求输入的对话历史格式不一样。例如,有的模型用 [INST] 和 [/INST] 包裹用户输入,有的用 <|im_start|> 和 <|im_end|>。
这个系统允许你:
- 注册不同的格式化函数(即"聊天模板"),并给它们起一个独一无二的名字。
- 通过名字获取指定的格式化函数。
- 最终生成一个可以直接用于 Hugging Face
datasets库map方法的编码函数,从而高效地预处理整个数据集。
代码逐段详解
1. 全局注册表与类型提示
python
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
chat_templates: Dict[str, Callable[..., str]] = {}
if TYPE_CHECKING:: 这是一个用于静态类型检查的技巧。from transformers import PreTrainedTokenizer这行导入只在类型检查工具(如 MyPy)运行时才生效,在实际程序运行时会被忽略。这样做的好处是既能让代码编辑器和检查工具知道PreTrainedTokenizer是什么类型(从而提供智能提示和错误检查),又能避免在运行时可能发生的循环导入问题。chat_templates: Dict[str, Callable[..., str]] = {}: 这是整个系统的核心------一个全局注册表 。它是一个字典:- 键 (
str) : 模板的名字,比如"native","qwen2_5"。 - 值 (
Callable[..., str]): 一个"可调用对象"(通常是函数),它接收一些参数,然后返回一个格式化好的字符串(也就是准备输入给模型的提示词)。
- 键 (
2. 用于注册模板的装饰器
python
def register_chat_template(key, **kwargs):
def decorator(func):
if key in chat_templates:
raise ValueError(f"chat template {key} already exists")
chat_templates[key] = partial(add_default_system, func, **kwargs)
return func
return decorator
这是一个装饰器工厂,也就是说,它是一个返回装饰器的函数。
- 当你写
@register_chat_template("native")时,Python 首先调用register_chat_template("native")。 - 这个调用返回了内部的
decorator函数。 - 然后 Python 用这个返回的
decorator函数去"装饰"下面的native_chat_template函数,效果等同于decorator(native_chat_template)。 - 在
decorator函数内部:- 检查模板名
key是否已经存在于注册表chat_templates中,防止重复注册。 partial(add_default_system, func, **kwargs): 这是最巧妙的部分。它使用functools.partial创建了一个新 的函数。它把add_default_system这个辅助函数和真正的模板逻辑函数func(在这里就是native_chat_template)以及装饰器接收到的额外参数(比如default_system)"预先绑定"在一起。- 将这个经过包装的新函数存入
chat_templates注册表。 - 最后返回原始的
func函数。这样做的好处是允许你"堆叠"多个装饰器,比如@register_chat_template("native")和@register_chat_template("qwen2_5")可以同时作用于同一个函数。
- 检查模板名
3. add_default_system 辅助函数
python
def add_default_system(func, tokenizer: "PreTrainedTokenizer", conversation, default_system: str = None, **kwargs):
if default_system is not None and conversation[0]["role"] != "system":
conversation = [{"role": "system", "content": default_system}] + conversation
return func(tokenizer, conversation, **kwargs)
这是一个包装函数,它的作用是自动添加一个默认的系统提示(system prompt)。
- 它会检查是否提供了一个
default_system提示,并且当前的对话conversation不是以"system"角色开头的。 - 如果两个条件都满足,它就在对话列表的开头插入一条新的系统消息。
- 最后,它调用原始的模板函数
func,并传入这个可能被修改过的conversation。这使得为所有模板添加默认系统提示变得非常简单,而无需在每个模板函数内部重复写这个逻辑。
4. native_chat_template 模板的实现
python
@register_chat_template("native")
@register_chat_template("qwen2_5")
def native_chat_template(tokenizer: "PreTrainedTokenizer", conversation, tools=None, documents=None, **kwargs):
kwargs["tokenize"] = False
kwargs["add_generation_prompt"] = kwargs.get("add_generation_prompt", True)
return tokenizer.apply_chat_template(conversation, tools, documents, **kwargs)
@register_chat_template(...): 这就是装饰器的实际应用。同一个函数native_chat_template被注册了两次,分别命名为"native"和"qwen2_5"。这是一种别名机制,非常实用。- 这个函数本身其实只是对 Hugging Face
transformers库标准方法tokenizer.apply_chat_template()的一层封装。 kwargs["tokenize"] = False: 这一点至关重要。它告诉apply_chat_template方法返回格式化好的字符串,而不是分词后的数字 ID 列表。真正的分词(tokenization)会在后续步骤中批量进行,以提高效率。kwargs["add_generation_prompt"] = ...: 这确保模板在格式化时,会在末尾添加引导模型生成回复的提示符(例如,在对话最后加上assistant的起始标记)。
5. 最终的工厂函数 get_encode_function
python
def get_encode_function(template_name, tokenizer):
chat_template_func = get_chat_template(template_name, tokenizer)
def encode_function(data_i):
# ... (内部逻辑) ...
encodings = tokenizer(text_list)
return encodings
return encode_function
这是最终面向用户的"总装"函数,它把所有部分整合在一起。
-
chat_template_func = get_chat_template(template_name, tokenizer): 它通过get_chat_template辅助函数(代码中未完整展示,但功能是查找并返回注册表里的函数)从注册表中获取正确的、预配置好的格式化函数。get_chat_template内部会再次使用partial将tokenizer绑定进去。现在,chat_template_func变成了一个非常简单的函数,只需要传入一个conversation列表即可。 -
def encode_function(data_i):: 它定义并返回了一个内部函数。这形成了一个闭包 。encode_function"记住"了它被创建时的上下文,即chat_template_func和tokenizer。这个返回的encode_function正是我们要传给dataset.map()的函数。 -
encode_function的内部逻辑:- 它被设计用来处理一批数据 (
data_i),这正是dataset.map(batched=True)的工作方式。 - 代码逻辑很健壮,能处理两种常见的数据列:名为
"messages"的列(包含对话列表)或名为"prompt"的列(包含已经格式化好的文本)。 - 它遍历批次中的每一条数据,调用
chat_template_func将其转换为格式化好的字符串,然后存入text_list。 encodings = tokenizer(text_list): 最后,它执行批量分词 。一次性将text_list中所有的字符串都进行分词,这比一条一条处理要快得多。- 它返回分词后的编码结果
encodings,这正是模型训练脚本所需要的数据格式。
- 它被设计用来处理一批数据 (
总结:整个流程如何协同工作
- 你首先定义一个或多个模板函数,如
native_chat_template。 - 使用
@register_chat_template("你的模板名")装饰器将它们添加到全局注册表中。 - 当你需要处理数据集时,调用
encode_func = get_encode_function("你的模板名", my_tokenizer)。 - 这会返回一个为你量身定做的函数
encode_func。 - 最后,你可以这样使用它:
my_dataset.map(encode_func, batched=True)。
这个设计的优点在于它的可扩展性 (想支持新模型,只需添加一个新函数和装饰器)、代码整洁性 (逻辑分离清晰)和高效率(利用了批量分词)。
{"id": "31", "domain": "math_rule", "source": "aime2025", "difficulty": 0.0, "prompt": "Find the sum of all integer bases b>9b>9b>9 for which 17b17_{b}17b is a divisor of 97b97_{b}97b.", "messages": "[{"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."}, {"role": "user", "content": "Find the sum of all integer bases b>9b>9b>9 for which 17b17_{b}17b is a divisor of 97b97_{b}97b."}]", "ground_truth": "70", "case_type": "", "test_case_function": "", "test_cases": "", "tag": "aime2025"}
针对这个数据, 他的处理过程是什么呢?
当然可以!我们来一步步追踪你提供的这个 JSON 数据,看看它在那个 Python 脚本中是如何被处理的。
这个过程可以分解为几个关键阶段:准备阶段 、map 执行阶段 和 最终结果。
假设我们调用脚本时,选择的模板名是 "native",并且已经加载了一个 tokenizer(比如 Qwen2 的 tokenizer)。
python
# 假设的初始调用
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
encode_func = get_encode_function(template_name="native", tokenizer=tokenizer)
# 接下来,Hugging Face datasets 库会调用 encode_func
# a_single_batch_of_data.map(encode_func, batched=True)
处理流程详解
现在,我们聚焦于当 datasets.map() 方法运行时,encode_func 函数被调用并接收到你提供的这条数据(作为批次中的一部分)。
为了简单起见,我们假设批处理大小为 1,所以函数接收到的 data_i 看起来是这样的(注意,所有值都变成了列表):
python
data_i = {
"id": ["31"],
"domain": ["math_rule"],
# ... 其他字段 ...
"messages": [
'[{"role": "system", "content": "Please reason step by step..."}, {"role": "user", "content": "Find the sum of..."}]'
],
"ground_truth": ["70"]
# ...
}
第1步:进入 encode_function
encode_function 开始执行,接收上面的 data_i作为参数。
python
def encode_function(data_i):
text_list = [] # 1. 初始化一个空列表,用来存放格式化后的字符串
# 2. 检查 'messages' 键是否存在于数据中,存在,进入 if 分支
if "messages" in data_i:
# 3. 遍历 data_i["messages"] 列表
for messages in data_i["messages"]:
# 此时 messages 变量的值是这个字符串:
# '[{"role": "system", ...}, {"role": "user", ...}]'
第2步:解析和转换 messages 字段
循环内部的代码开始处理这个 messages 字符串。
python
# 4. 检查 messages 是不是一个字符串。是的,它是。
if isinstance(messages, str):
# 5. 使用 json.loads() 将字符串解析成一个 Python 对象。
# 这是非常关键的一步!
messages = json.loads(messages)
# 解析后,messages 变量现在变成了 Python 的列表,内容如下:
# [
# {'role': 'system', 'content': 'Please reason step by step...'},
# {'role': 'user', 'content': 'Find the sum of...'}
# ]
第3步:应用聊天模板
现在 messages 变量是标准的对话格式了,脚本会调用之前准备好的 chat_template_func 来格式化它。
python
# 6. 调用 chat_template_func,它会进一步调用 tokenizer.apply_chat_template
formatted_string = chat_template_func(messages)
text_list.append(formatted_string)
在这一步,tokenizer.apply_chat_template 会根据 tokenizer 自身的模板规则,将这个列表转换为一个单一的、符合模型输入规范的字符串。
例如,如果用的是 Qwen2 的模板,生成的 formatted_string 可能会是这样(这是一个示例,实际的分隔符和换行可能略有不同):
<|im_start|>system
Please reason step by step, and put your final answer within \boxed{}.<|im_end|>
<|im_start|>user
Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.<|im_end|>
<|im_start|>assistant
这个格式化好的字符串被添加到 text_list 中。
第4步:批量分词(Tokenization)
循环结束后(因为我们的批次里只有一条数据),encode_function 继续执行。
python
# ... 循环结束 ...
# 此时, text_list = [
# "<|im_start|>system\nPlease reason step by step...<|im_end|>\n<|im_start|>user\nFind the sum...<|im_end|>\n<|im_start|>assistant"
# ]
# 7. 调用 tokenizer,对列表中的所有字符串进行批量分词
encodings = tokenizer(text_list)
# 8. 返回分词结果
return encodings
tokenizer(text_list) 会返回一个字典,里面包含了模型训练所需的 input_ids, attention_mask 等。例如:
python
{
'input_ids': [[151644, 8948, 198, ...]], # 一大串数字 ID
'attention_mask': [[1, 1, 1, ...]] # 对应每个 token 是否需要被关注
}
总结流程图
输入数据 (JSON 对象)
{"messages": "[{...}, {...}]"}
⬇️
1. encode_function 启动
data_i["messages"]是一个包含字符串 的列表['[{...}]']
⬇️
2. 解析字符串 (json.loads)- 字符串
'[{...}]'变成 Python 列表[{...}]
⬇️
3. 应用聊天模板 (tokenizer.apply_chat_template) - Python 列表
[{...}]变成格式化好的单一字符串 "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant"
⬇️
4. 批量分词 (tokenizer(...))- 格式化字符串被转换为数字ID列表 (
input_ids) 和注意力掩码 (attention_mask)
⬇️
最终输出 {'input_ids': [[...]], 'attention_mask': [[...]]}- 这个结果会被
datasets.map用来更新或添加到数据集中,为模型训练做好准备。
这个脚本的精妙之处在于它能稳健地处理 messages 字段是字符串的情况,并通过一系列标准化的流程,高效地将其转换为模型可用的训练数据。
{"data_source":"math_dapo","prompt":[{"content":"Solve the following math problem step by step. The last line of your response should be of the form Answer: Answer (without quotes) where Answer is the answer to the problem.\n\nFind the sum of all prime numbers that can be expressed both as a sum of two prime numbers and as a difference of two prime numbers.\n\nRemember to put your answer on its own line after "Answer:".","role":"user"}] 这个数据集是什么格式?
好的,我们来分析你提供的这个新的数据样本。这个格式与上一个有所不同,但同样可以被我们之前分析的那个 Python 脚本处理。
数据格式分析
这是一个 JSON 对象,代表一条数据记录。我们逐个字段来看:
-
"data_source": "math_dapo":- 这是一个元数据字段,表明这条数据的来源是 "math_dapo" 数据集。
-
"prompt": [{"content": "...", "role": "user"}]:- 这是最关键的字段,包含了实际的对话内容。
- 格式 : 它是一个列表 (List)。
- 内容 : 列表中的每个元素都是一个字典 (Dictionary),这个字典代表对话中的一轮。
- 字典结构 :
"content": 消息的具体内容,也就是向模型提出的问题或指令。"role": 消息的发送者角色,这里是"user"(用户)。
与上一个数据集的对比
让我们对比一下这两个数据集的格式差异,这能更好地理解为什么那个 Python 脚本写得那么好。
上一个数据集 (messages 字段):
json
{
"messages": "[{\"role\": \"system\", \"content\": \"...\"}, {\"role\": \"user\", \"content\": \"...\"}]"
}
- 关键字段名是
messages。 - 字段的值是一个 JSON 字符串 ,需要用
json.loads()来解析才能得到对话列表。
这个新数据集 (prompt 字段):
json
{
"prompt": [{"content": "...", "role": "user"}]
}
- 关键字段名是
prompt。 - 字段的值直接就是一个 JSON 列表 ,里面包含对话字典。它已经是我们需要的 Python 列表格式了,不需要
json.loads()。
这个数据在 Python 脚本中的处理过程
现在,我们来看一下这段数据在 get_encode_function 返回的 encode_function 中是如何被处理的。
假设批处理大小为 1,data_i 会是这样:
python
data_i = {
"data_source": ["math_dapo"],
"prompt": [
[{"content": "...", "role": "user"}] # 注意这里是列表的列表
]
}
流程追踪
-
进入
encode_function:pythondef encode_function(data_i): text_list = [] # 检查 'messages' 键是否存在。不存在,跳过这个 if 块。 if "messages" in data_i: # ... (此部分代码不执行) -
进入
elif "prompt"分支 :脚本会继续检查下一个条件。
python# 检查 'prompt' 键是否存在。存在,进入这个 elif 分支。 elif "prompt" in data_i: # 遍历 data_i["prompt"] 列表。 # 此时的 data_i["prompt"] 是: [[{"content": "...", "role": "user"}]] for prompt_item in data_i["prompt"]: # 在第一次循环中,prompt_item 的值是: # [{"content": "...", "role": "user"}] # 在这个分支里,脚本假设 prompt_item 是已经格式化好的对话列表 # 或者是一个可以直接使用的字符串。 # 由于 prompt_item 是一个列表,它会被直接传递给 chat_template_func formatted_string = chat_template_func(prompt_item) text_list.append(formatted_string)注意 : 这个脚本的
elif "prompt"部分有一个潜在的假设,它期望prompt字段中的内容要么是最终的字符串,要么是符合apply_chat_template格式的对话列表。你给的这个格式恰好满足了后一种情况。 -
应用聊天模板 :
chat_template_func接收到[{"content": "...", "role": "user"}]这个列表,并将其传递给tokenizer.apply_chat_template。生成的
formatted_string可能是这样(以 Qwen2 为例):<|im_start|>user Solve the following math problem step by step... Answer: $Answer (without quotes) where $Answer is the answer to the problem. Find the sum of all prime numbers that can be expressed both as a sum of two prime numbers and as a difference of two prime numbers. Remember to put your answer on its own line after "Answer:".<|im_end|> <|im_start|>assistant -
批量分词和返回 :
流程和上一个例子一样。
text_list中的字符串会被tokenizer批量分词,然后返回包含input_ids和attention_mask的字典。
总结
这个数据集的格式是一种标准的多轮对话格式 ,其中对话历史被存储在一个名为 "prompt" 的 JSON 列表字段中。
- 格式名称: 可以称之为 "JSON 对话列表格式" 或 "Hugging Face 对话格式"。
- 特点 :
- 直接使用 JSON 结构(列表和字典)来表示对话,而不是像上一个例子那样用一个扁平的字符串。
- 每一轮对话是一个字典,包含
role和content两个键。 - 这种格式更规范,也更易于程序直接处理,因为它不需要额外的解析步骤。
这个脚本通过 if/elif 结构,能够优雅地兼容这两种不同但常见的数据格式,体现了其设计的健壮性。
好的,明白了!这个任务非常直接:你需要读取 DAPO-Math-17k-train.jsonl 文件,对其中的每一行数据应用我们刚刚最终确定的转换逻辑,然后将转换后的结果写入一个新的文件 dapo-math-17k.jsonl。
这是一个典型的批处理任务,我们可以用一个完整的 Python 脚本来高效地完成。
最终的转换脚本
下面这个脚本整合了我们之前所有的讨论。它包含了最终版的转换函数,并设置了一个主流程来读取、转换和写入整个 .jsonl 文件。
如何使用:
- 将下面的代码保存为一个 Python 文件,例如
convert_dataset.py。 - 确保你的输入文件
DAPO-Math-17k-train.jsonl和这个脚本在同一个文件夹里。 - 打开终端,进入该文件夹,然后运行命令:
python convert_dataset.py
python
import json
import uuid
from tqdm import tqdm # 使用 tqdm 库来显示一个漂亮的进度条
def convert_dapo_to_aime2025_final(dapo_record: dict) -> dict:
"""
将单条 DAPO 格式的记录转换为我们最终确定的目标格式 (类似 aime2025)。
"""
# 1. 定义固定的 system prompt
fixed_system_content = "Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem."
# 2. 从原始 content 中提取 user_content
original_content = ""
# 原始数据中, prompt 字段是一个列表,我们取第一个元素的 content
if dapo_record.get("prompt") and isinstance(dapo_record["prompt"], list) and len(dapo_record["prompt"]) > 0:
original_content = dapo_record["prompt"][0].get("content", "")
# 从原始 content 中移除固定的 system prompt 部分和紧随其后的空白/换行符
user_content = original_content.replace(fixed_system_content, "").lstrip()
# 3. 构建 messages 列表
messages_list = [
{"role": "system", "content": fixed_system_content},
{"role": "user", "content": user_content}
]
# 4. 将 messages 列表序列化为 JSON 字符串
messages_str = json.dumps(messages_list)
# 5. 构建最终格式的字典
converted_record = {
"id": dapo_record.get("extra_info", {}).get("index", str(uuid.uuid4())),
"domain": "math_rule", # 根据之前的讨论,使用固定值或从 ability 映射
"source": dapo_record.get("data_source", "unknown"),
"difficulty": 0.0,
"prompt": user_content, # prompt 字段现在只包含纯粹的用户问题
"messages": messages_str,
"ground_truth": dapo_record.get("reward_model", {}).get("ground_truth", ""),
"case_type": "",
"test_case_function": "",
"test_cases": "",
"tag": dapo_record.get("data_source", "unknown")
}
return converted_record
def process_jsonl_file(input_path: str, output_path: str):
"""
读取整个 .jsonl 文件,逐行转换,并写入新的 .jsonl 文件。
"""
print(f"开始转换文件: {input_path}")
try:
# 首先计算总行数,以便 tqdm 显示进度
with open(input_path, 'r', encoding='utf-8') as f:
total_lines = sum(1 for line in f)
# 打开输入和输出文件
with open(input_path, 'r', encoding='utf-8') as infile, \
open(output_path, 'w', encoding='utf-8') as outfile:
# 使用 tqdm 包装迭代器以显示进度条
for line in tqdm(infile, total=total_lines, desc="正在处理"):
if not line.strip():
continue # 跳过空行
try:
# 1. 读取并解析单行 JSON
original_data = json.loads(line)
# 2. 应用转换函数
converted_data = convert_dapo_to_aime2025_final(original_data)
# 3. 将转换后的字典写回 JSON 字符串,并写入输出文件
# ensure_ascii=False 确保中文字符等能被正确写入
json_line = json.dumps(converted_data, ensure_ascii=False)
outfile.write(json_line + '\n')
except json.JSONDecodeError:
print(f"\n警告: 发现无效的 JSON 行,已跳过: {line.strip()}")
except Exception as e:
print(f"\n处理时发生未知错误,已跳过。错误: {e}")
print("\n🎉 转换完成!")
print(f"输出文件已保存至: {output_path}")
except FileNotFoundError:
print(f"错误: 输入文件未找到: {input_path}")
except Exception as e:
print(f"发生严重错误: {e}")
if __name__ == "__main__":
# --- 配置输入和输出文件名 ---
input_file = "DAPO-Math-17k-train.jsonl"
output_file = "dapo-math-17k.jsonl"
# 安装tqdm: pip install tqdm
process_jsonl_file(input_file, output_file)
脚本说明
- 依赖库 : 脚本使用了
tqdm库来创建一个可视化的进度条。这在处理大文件时非常有用。如果你的环境中没有安装,请运行pip install tqdm。 convert_dapo_to_aime2025_final函数: 这是我们之前确定的核心转换逻辑,负责处理单条数据。process_jsonl_file函数 : 这是主处理流程。- 它首先计算文件的总行数,这样进度条才能正确显示百分比。
- 它同时打开输入和输出文件,避免频繁开关文件,提高效率。
for line in tqdm(...): 循环读取输入文件的每一行,tqdm会自动更新进度条。json.loads(line): 将每一行的文本(JSON 字符串)解析成 Python 字典。convert_dapo_to_aime2025_final(...): 调用转换函数。json.dumps(...): 将转换后的字典再变回 JSON 字符串。outfile.write(...): 将新的 JSON 字符串写入输出文件,并在末尾加上换行符\n,以保持.jsonl格式。
- 错误处理 : 脚本包含了基本的
try...except块,可以处理文件未找到、某行不是有效的 JSON 等常见问题,增强了程序的健壮性。 if __name__ == "__main__":: 这是 Python 的标准做法,确保只有当这个文件作为主程序运行时,process_jsonl_file函数才会被调用。
运行结束后,你将得到一个 dapo-math-17k.jsonl 文件,其中包含了 17k 条格式统一、符合你最终要求的数据。