1. 训练逻辑代码
python
# 假设以下是外部定义的依赖(方便理解)
# args: 训练参数配置(如设备、学习率、梯度累积步数等)
# optimizer: 优化器(如AdamW)
# model: 待训练的模型
# autocast_ctx: AMP自动混合精度上下文管理器
# scaler: GradScaler梯度缩放器(配合AMP使用)
# Logger: 自定义日志打印函数
# is_main_process: 判断是否为主进程(分布式训练用)
# lm_config: 语言模型配置(如hidden_size、是否用MoE等)
# lm_checkpoint: 自定义的检查点保存函数
# get_lr: 学习率调度函数(根据步数计算当前学习率)
# DistributedDataParallel: 分布式数据并行训练类
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
"""
执行单个Epoch的训练过程,包含完整的前向/反向传播、梯度处理、日志记录和模型保存逻辑
Args:
epoch (int): 当前训练的Epoch编号(从0开始)
loader (DataLoader): 训练数据加载器,输出(input_ids, labels)
iters (int): 当前Epoch的总迭代步数
start_step (int, optional): 起始迭代步数(用于断点续训),默认0
wandb (optional): Weights & Biases日志对象,用于可视化训练指标,默认None
"""
# 记录当前Epoch的训练开始时间
start_time = time.time()
# 遍历数据加载器,step从start_step+1开始计数(断点续训时保持步数连续),一般start_step几百几百的算
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1)
# -------------------------- 1. 数据预处理 --------------------------
# 将输入和标签移到指定设备(GPU/CPU)
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
# -------------------------- 2. 学习率调度 --------------------------
# 计算当前全局步数(所有Epoch累计),并获取对应的学习率
global_step = epoch * iters + step
total_steps = args.epochs * iters # 训练总步数
lr = get_lr(global_step, total_steps, args.learning_rate) # 余弦退火算法,更新学习率
# 更新优化器的学习率
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# -------------------------- 3. 前向传播(混合精度) --------------------------
# 启用AMP自动混合精度上下文,自动切换FP16/FP32精度
with autocast_ctx:
# 模型前向传播,输入input_ids并传入labels计算损失
res = model(input_ids, labels=labels)
# 总损失 = 主损失 + 辅助损失(如MoE模型的负载均衡损失、对齐损失等)
loss = res.loss + res.aux_loss
# 梯度累积:将损失除以累积步数,保证最终梯度等价于大批次训练
loss = loss / args.accumulation_steps
# -------------------------- 4. 反向传播(梯度缩放) --------------------------
# AMP梯度缩放:防止FP16梯度下溢,先缩放损失再反向传播
scaler.scale(loss).backward()
# -------------------------- 5. 梯度裁剪 & 优化器更新(梯度累积到位后) --------------------------
# 当累积步数达到设定值时,执行参数更新
if (step + 1) % args.accumulation_steps == 0:
# 取消梯度缩放:将优化器的梯度值恢复到原始尺度(用于梯度裁剪)
scaler.unscale_(optimizer)
# 梯度裁剪:防止梯度爆炸,限制参数梯度的L2范数不超过grad_clip
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
# 执行优化器步进(参数更新),scaler会自动处理梯度缩放的恢复
scaler.step(optimizer)
# 更新scaler的缩放因子(AMP自适应调整)
scaler.update()
# 清空梯度(set_to_none=True比zero()更节省显存)
optimizer.zero_grad(set_to_none=True)
# -------------------------- 6. 日志记录 --------------------------
# 达到日志打印间隔,或当前Epoch最后一步时,打印并记录训练指标
if step % args.log_interval == 0 or step == iters - 1:
# 计算已消耗的训练时间
spend_time = time.time() - start_time
# 恢复真实损失值(反向传播时除以了累积步数,这里乘回来)
current_loss = loss.item() * args.accumulation_steps
# 提取辅助损失值(无辅助损失则为0)
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
# 主损失 = 总损失 - 辅助损失
current_logits_loss = current_loss - current_aux_loss
# 获取当前优化器的学习率(取最后一个param_group的lr,适配多参数组场景)
current_lr = optimizer.param_groups[-1]['lr']
# 计算当前Epoch剩余时间(分钟)
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
# 打印训练日志
Logger(
f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
f'loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, '
f'aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, '
f'epoch_time: {eta_min:.1f}min'
)
# 如果启用wandb,将指标写入wandb可视化
if wandb:
wandb.log({
"loss": current_loss,
"logits_loss": current_logits_loss,
"aux_loss": current_aux_loss,
"learning_rate": current_lr,
"epoch_time": eta_min
})
# -------------------------- 7. 模型保存 --------------------------
# 达到保存间隔,或当前Epoch最后一步,且为主进程时(分布式训练避免重复保存)
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
# 切换模型到评估模式(防止BN/Dropout等层影响权重)
model.eval()
# 生成MoE模型的后缀(区分普通模型和MoE模型)
moe_suffix = '_moe' if lm_config.use_moe else ''
# 构建模型权重保存路径
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 处理分布式模型:如果是DDP包装的模型,取原始模型(module)
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
# 处理可能的模型包装(如torch.compile后的_orig_mod)
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 获取模型状态字典
state_dict = raw_model.state_dict()
# 保存权重:转为FP16并移到CPU,节省磁盘空间
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 调用自定义检查点函数,保存更完整的训练状态(模型、优化器、epoch/step、scaler等)
lm_checkpoint(
lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler
)
# 切换回训练模式
model.train()
# 删除状态字典,释放显存
del state_dict
# -------------------------- 8. 清理显存 --------------------------
# 删除当前迭代的张量,避免显存泄漏
del input_ids, labels, res, loss
1.1 数据预处理:
python
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1)
# -------------------------- 1. 数据预处理 --------------------------
# 将输入和标签移到指定设备(GPU/CPU)
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
1.2 学习率调度:
python
# -------------------------- 2. 学习率调度 --------------------------
# 计算当前全局步数(所有Epoch累计),并获取对应的学习率
global_step = epoch * iters + step
total_steps = args.epochs * iters # 训练总步数
lr = get_lr(global_step, total_steps, args.learning_rate) # 余弦退火算法,更新学习率
# 更新优化器的学习率
for param_group in optimizer.param_groups:
param_group['lr'] = lr
其中核心就在于 get_lr() 函数,它根据 当前训练步数 以及 总步数 动态调整学习率大小,其实现过程如下:
python
import math
def get_lr(current_step, total_steps, lr):
"""
基于余弦退火的学习率调度函数
:param current_step: 当前训练步数(从0开始,到total_steps结束)
:param total_steps: 训练的总步数
:param lr: 初始学习率(基准学习率)
:return: 当前步数对应的动态学习率
"""
# 核心计算逻辑:余弦退火公式 + 系数调整
return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * current_step / total_steps)))
- 余弦项核心 math.cos(math.pi * current_step / total_steps): 将当前步数归一化到 0 ~ 1 之间,然后映射到 0 ~ π 弧度(余弦函数在 0~π 区间是从 1 平滑下降到 -1)
目的:通过余弦函数让学习率随步数平滑下降,防止学习率下降过快或过低。
1.3 AMP 自动混合精度前向传播
核心内容在于通过 with autocast_ctx: 实现 AMP 自动混合精度上下文,管理前向传播的精度切换:
python
# -------------------------- 3. 前向传播(混合精度) --------------------------
# 启用AMP自动混合精度上下文,自动切换FP16/FP32精度
with autocast_ctx:
# 模型前向传播,输入input_ids并传入labels计算损失
res = model(input_ids, labels=labels)
# 总损失 = 主损失 + 辅助损失(如MoE模型的负载均衡损失、对齐损失等)
loss = res.loss + res.aux_loss
# 梯度累积:将损失除以累积步数,保证最终梯度等价于大批次训练
loss = loss / args.accumulation_steps
AMP 会通过智能判断的方式,对 精度敏感的运算(如梯度累加、BN 层)保留 FP32(单精度),对 精度不敏感的运算(如卷积、矩阵乘法)自动转成 FP16/BF16(半精度)。
- 框架如何 "智能判断" 精度? PyTorch 框架给所有内置算子标注 "精度敏感度",比如卷积 / 矩阵乘法属于 "精度不敏感型",BN / 损失计算属于 "精度敏感型。然后,通过上下文管理器,在模型前向传播时自动识别算子类型,按规则切换精度;
- 目的: 在反向传播时,对梯度做特殊处理(高精度),避免半精度梯度下溢(梯度为 0),保证参数更新的准确性。
FP16 的数值范围小,梯度值太小时会变成 0(下溢),导致模型无法收敛;
1.3 反向传播
GradScaler 处理反向传播时,会自动把损失 / 梯度的计算切换回 FP32:
python
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) # 梯度缩放器
# -------------------------- 4. 反向传播(梯度缩放) --------------------------
# AMP梯度缩放:防止FP16梯度下溢,先缩放损失再反向传播
scaler.scale(loss).backward()
- scaler.scale(loss): 缩放loss避免低精度梯度下溢,这也就是scale up。防止下溢的具体实现为:在反向传播之前,它将 loss 乘以一个很大的缩放因子(例如 216=655362^{16} = 65536216=65536)。当 backward() 结束,所有的梯度已经安全地存储在每个参数的 .grad 属性里了。
- scaler.scale(loss).backward(): 返回 None,核心是生成「放大的 FP32 梯度」存到参数 .grad 中,避免 FP16 下溢;
这些「放大的梯度」会直接写入模型参数的 .grad 属性中(比如 model.conv1.weight.grad),优化器此时还没接触到梯度,只是后续会读取这个 .grad。
AMP 完整逻辑: autocast(前向提速) + GradScaler(反向保精度),二者配合才是完整的自动混合精度,而非 autocast 单独搞定。
1.4 梯度更新
python
# 梯度累积步数达标后,执行参数更新
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer) # 取消梯度缩放,把它们的 .grad管理的参数全部除以 65536
# 梯度裁剪:限制梯度L2范数,防止梯度爆炸(大模型必做)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer) # 更新模型参数
scaler.update() # 更新混合精度缩放器状态
optimizer.zero_grad(set_to_none=True) # 清空梯度(set_to_none更省显存)
- scaler.unscale_(optimizer) : 自动扫描优化器管理的所有参数(本质还是模型的参数),把它们的 .grad 全部除以 65536。
关键:操作的依然是「模型参数的 .grad」,优化器只是 "桥梁"------ 告诉 scaler 要处理哪些参数的 .grad。梯度是 "挂" 在模型参数上的,优化器只是 "工具人"------ 帮你找到这些参数,让 scaler 能对参数的梯度做缩放 / 还原操作。
例子:autocast 实现前向传播混合精度计算 + gradscaler 实现 scale up
python
import torch
import torch.nn as nn
# 适配新版本:从 torch.amp 导入(替代 torch.cuda.amp)
from torch.amp import GradScaler, autocast
# 初始化模型、优化器、缩放器
model = nn.Linear(10, 1).cuda() # 模型是 nn.Linear(10, 1),那么model.weight 的形状是 (1, 10)(输出维度 1,输入维度 10),model.weight.grad 的形状和参数完全一致,也是 (1, 10)(1 行、10 列);
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 新版本 GradScaler:指定 device_type(必填)
scaler = GradScaler()
# 模拟前向传播(约束输入/损失范围,避免梯度inf/nan)
x = torch.randn(2, 10).cuda() * 0.1 # 缩小输入范围
with autocast(device_type="cuda", dtype=torch.float16): # 显式指定设备和精度
output = model(x)
loss = nn.MSELoss()(output, torch.zeros_like(output)) # 用MSELoss约束loss
# ========== 关键修正:先执行反向传播,再打印梯度 ==========
# 1. 缩放loss并反向传播(此时才会计算梯度,写入model.weight.grad)
scaler.scale(loss).backward()
# 查看模型参数的梯度(此时grad已存在,可切片)
print("缩放后模型参数的grad:", model.weight.grad)
print("缩放后模型参数的grad shape:", model.weight.grad.shape)
# 查看优化器是否有grad(优化器无grad属性)
print("优化器的grad:", hasattr(optimizer, 'grad')) # 输出 False
# 2. 手动unscale_(还原梯度)
scaler.unscale_(optimizer)
# 查看还原后的真实梯度
print("还原后模型参数的grad:", model.weight.grad[:2, :2])
print("还原后模型参数的grad shape:", model.weight.grad.shape)
# 补充:完整流程需执行step+update+清空梯度
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
输出内容:
python
缩放后模型参数的grad: tensor([[1415.0000, 961.5000, 352.0000, 2096.0000, -848.0000, 1648.0000,
-569.5000, 1281.0000, 378.7500, 1107.0000]], device='cuda:0')
缩放后模型参数的grad shape: torch.Size([1, 10])
优化器的grad: False
还原后模型参数的grad: tensor([[0.0216, 0.0147]], device='cuda:0')
还原后模型参数的grad shape: torch.Size([1, 10])
- model = nn.Linear(10, 1).cuda(): 模型是 nn.Linear(10, 1),那么model.weight 的形状是 (1, 10)(输出维度 1,输入维度 10),model.weight.grad 的形状和参数完全一致,也是 (1, 10)(1 行、10 列);
print("还原后模型参数的grad shape:", model.weight.grad.shape),所以为 torch.Size([1, 10])
- 切片 [:2, :2] 的执行逻辑: PyTorch 张量切片遵循「取到维度的最后一个元素为止,不会越界报错」的规则。
2. SFT 和 PreTrain 在 DataSet 的区别
python
import torch
from torch.utils.data import Dataset
from datasets import load_dataset # 注意:需要安装datasets库
class SFTDataset(Dataset):
"""
用于大模型SFT(监督微调)的自定义Dataset类
功能:加载JSONL格式的对话数据,生成符合ChatML格式的input_ids,并为labels做掩码(仅回答部分保留真实值,其余为-100)
"""
def __init__(self, jsonl_path, tokenizer, max_length=1024):
"""
初始化数据集
Args:
jsonl_path: str - JSONL格式的数据文件路径(每行是一个对话样本)
tokenizer: PreTrainedTokenizer - 模型对应的tokenizer(如LlamaTokenizer)
max_length: int - 输入序列的最大长度,超出截断,不足补pad
"""
super().__init__()
self.tokenizer = tokenizer # 保存tokenizer
self.max_length = max_length # 保存最大序列长度
# 加载JSONL数据集(使用huggingface的datasets库,加载train拆分)
self.samples = load_dataset('json', data_files=jsonl_path, split='train')
# 预计算assistant回复的起始标记id(bos_id)和结束标记id(eos_id)
# 例如:tokenizer.bos_token是<s>,则bos_id对应<s>assistant\n的token id列表
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
# eos_id对应</s>\n的token id列表
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
def __len__(self):
"""返回数据集总样本数(必须实现的Dataset方法)"""
return len(self.samples)
def create_chat_prompt(self, cs):
"""
将对话列表转换成符合ChatML格式的文本prompt
Args:
cs: list - 单条样本的对话列表,格式如[{"role":"user","content":"你好"}, {"role":"assistant","content":"您好!"}]
Returns:
str - 格式化后的对话文本(如"<s>user\n你好</s>assistant\n您好!</s>")
"""
# 复制对话列表,避免修改原数据
messages = cs.copy()
# 处理工具调用场景:如果第一条是system消息且包含functions(工具定义),则传入tools参数
tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
# 使用tokenizer的chat_template生成标准化的对话prompt(不分词,不添加生成提示)
return self.tokenizer.apply_chat_template(
messages,
tokenize=False, # 只返回文本,不做分词
add_generation_prompt=False, # 不添加"assistant\n"的生成提示(因为数据中已有)
tools=tools # 传入工具定义(如有)
)
def generate_labels(self, input_ids):
"""
生成标签:仅保留assistant回复部分的token id,其余部分设为-100(PyTorch中-100会被CrossEntropyLoss忽略)
Args:
input_ids: list - 整个对话的token id列表
Returns:
list - 与input_ids等长的labels列表,非回答部分为-100,回答部分为对应input_ids的值
"""
# 初始化labels为全-100(默认所有位置都不计算损失)
labels = [-100] * len(input_ids)
i = 0 # 遍历input_ids的指针
# 遍历整个input_ids,找到所有assistant回复的区间
while i < len(input_ids):
# 找到assistant回复的起始位置(匹配bos_id:<s>assistant\n)
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
# 回复起始位置 = bos_id结束的位置
start = i + len(self.bos_id)
end = start # 回复结束位置的指针
# 从start开始,找到eos_id(</s>\n)的位置,确定回复的结束边界
while end < len(input_ids):
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
break
end += 1
# 将assistant回复区间(start到end+eos_id长度)的labels设为对应input_ids的值
# 注意:不超过max_length,避免越界
for j in range(start, min(end + len(self.eos_id), self.max_length)):
labels[j] = input_ids[j]
# 指针跳到当前回复的结束位置,继续找下一个回复(如有)
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:
# 未找到bos_id,指针后移一位
i += 1
return labels
def __getitem__(self, index):
"""
加载单条样本(必须实现的Dataset方法)
Args:
index: int - 样本索引
Returns:
tuple - (input_ids_tensor, labels_tensor),均为long型tensor,长度=max_length
"""
# 获取单条样本(对话数据)
sample = self.samples[index]
# 1. 将对话列表转换成ChatML格式的文本prompt
prompt = self.create_chat_prompt(sample['conversations'])
# 2. 对prompt分词,生成input_ids,并截断到max_length
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
# 3. 不足max_length的部分用pad_token_id补齐
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
# 4. 生成labels:仅保留assistant回复部分,其余为-100
labels = self.generate_labels(input_ids)
# # === 调试打印(可选)===
# print(f"\n--- Sample {index} ---")
# for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
# print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
# # ================
# 5. 转换为torch tensor(long型,符合模型输入要求)
return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
区别一:SFT用了chat_template
核心目的: 封装 tokenizer.apply_chat_template() ,将 messages 列表封装到 chat_template 中,返回标准对话的 prompt ;
因为 SFT 与 Pretrain 的对话构成不一样,即 labels 不一样。前者的真值来源于 assistant 的回复,而后者为续写下一个词,只需要对齐即可。
python
def create_chat_prompt(self, cs):
"""
将对话列表转换成符合ChatML格式的文本prompt
Args:
cs: list - 单条样本的对话列表,格式如[{"role":"user","content":"你好"}, {"role":"assistant","content":"您好!"}]
Returns:
str - 格式化后的对话文本(如"<s>user\n你好</s>assistant\n您好!</s>")
"""
# 复制对话列表,避免修改原数据
messages = cs.copy()
# 处理工具调用场景:如果第一条是system消息且包含functions(工具定义),则传入tools参数
tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
# 使用tokenizer的chat_template生成标准化的对话prompt(不分词,不添加生成提示)
return self.tokenizer.apply_chat_template(
messages,
tokenize=False, # 只返回文本,不做分词
add_generation_prompt=False, # 不添加"assistant\n"的生成提示(因为数据中已有)
tools=tools # 传入工具定义(如有)
)
返回的 Prompt:
python
<|im_start|>system
你是一个优秀的聊天机器人,总是给我正确的回应!<|im_end|>
<|im_start|>user
你来自哪里?<|im_end|>
<|im_start|>assistant
我来自地球<|im_end|>
然后对 Prompt 进行分词:
python
# 2. 对prompt分词,生成input_ids,并截断到max_length
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
# 3. 不足max_length的部分用pad_token_id补齐
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
input_ids:
python
model_inpus: {'input_ids': [1, 85, 91, 85, 86, 71, 79, 201, 358, 988, 4904, 5195, 265, 1958, 911, 1862, 570, 606, 604, 2, 201, 1, 87, 85, 71, 84, 201, 358, 3583, 1648, 336, 2, 201, 1, 67, 85, 85, 75, 85, 86, 67, 80, 86, 201, 289, 3583, 1562, 2, 201], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
区别二:labels 的构建方式不一样
python
def generate_labels(self, input_ids):
"""
生成标签:仅保留assistant回复部分的token id,其余部分设为-100(PyTorch中-100会被CrossEntropyLoss忽略)
Args:
input_ids: list - 整个对话的token id列表
Returns:
list - 与input_ids等长的labels列表,非回答部分为-100,回答部分为对应input_ids的值
"""
# 初始化labels为全-100(默认所有位置都不计算损失)
labels = [-100] * len(input_ids)
i = 0 # 遍历input_ids的指针
# 遍历整个input_ids,找到所有assistant回复的区间
while i < len(input_ids):
# 找到assistant回复的起始位置(匹配bos_id:<s>assistant\n)
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
# 回复起始位置 = bos_id结束的位置
start = i + len(self.bos_id)
end = start # 回复结束位置的指针
# 从start开始,找到eos_id(</s>\n)的位置,确定回复的结束边界
while end < len(input_ids):
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
break
end += 1
# 将assistant回复区间(start到end+eos_id长度)的labels设为对应input_ids的值
# 注意:不超过max_length,避免越界
for j in range(start, min(end + len(self.eos_id), self.max_length)):
labels[j] = input_ids[j]
# 指针跳到当前回复的结束位置,继续找下一个回复(如有)
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:
# 未找到bos_id,指针后移一位
i += 1
return labels
- labels = [-100] * len(input_ids): 构建一个长度和 input_ids 一样的列表,元素值都为 -100,[-100] 为一个单元素列表。
- labels区间: <|im_start|>assistant\n 是助手回复开始标记,<|im_end|>\n 是结尾标记;