面试-SFT

1. 训练逻辑代码

python 复制代码
# 假设以下是外部定义的依赖(方便理解)
# args: 训练参数配置(如设备、学习率、梯度累积步数等)
# optimizer: 优化器(如AdamW)
# model: 待训练的模型
# autocast_ctx: AMP自动混合精度上下文管理器
# scaler: GradScaler梯度缩放器(配合AMP使用)
# Logger: 自定义日志打印函数
# is_main_process: 判断是否为主进程(分布式训练用)
# lm_config: 语言模型配置(如hidden_size、是否用MoE等)
# lm_checkpoint: 自定义的检查点保存函数
# get_lr: 学习率调度函数(根据步数计算当前学习率)
# DistributedDataParallel: 分布式数据并行训练类

def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
    """
    执行单个Epoch的训练过程,包含完整的前向/反向传播、梯度处理、日志记录和模型保存逻辑
    
    Args:
        epoch (int): 当前训练的Epoch编号(从0开始)
        loader (DataLoader): 训练数据加载器,输出(input_ids, labels)
        iters (int): 当前Epoch的总迭代步数
        start_step (int, optional): 起始迭代步数(用于断点续训),默认0
        wandb (optional): Weights & Biases日志对象,用于可视化训练指标,默认None
    """
    # 记录当前Epoch的训练开始时间
    start_time = time.time()
    
    # 遍历数据加载器,step从start_step+1开始计数(断点续训时保持步数连续),一般start_step几百几百的算
    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1)
        # -------------------------- 1. 数据预处理 --------------------------
        # 将输入和标签移到指定设备(GPU/CPU)
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
        
        # -------------------------- 2. 学习率调度 --------------------------
        # 计算当前全局步数(所有Epoch累计),并获取对应的学习率
        global_step = epoch * iters + step
        total_steps = args.epochs * iters  # 训练总步数
        lr = get_lr(global_step, total_steps, args.learning_rate) # 余弦退火算法,更新学习率
        # 更新优化器的学习率
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # -------------------------- 3. 前向传播(混合精度) --------------------------
        # 启用AMP自动混合精度上下文,自动切换FP16/FP32精度
        with autocast_ctx:
            # 模型前向传播,输入input_ids并传入labels计算损失
            res = model(input_ids, labels=labels)
            # 总损失 = 主损失 + 辅助损失(如MoE模型的负载均衡损失、对齐损失等)
            loss = res.loss + res.aux_loss
            # 梯度累积:将损失除以累积步数,保证最终梯度等价于大批次训练
            loss = loss / args.accumulation_steps

        # -------------------------- 4. 反向传播(梯度缩放) --------------------------
        # AMP梯度缩放:防止FP16梯度下溢,先缩放损失再反向传播
        scaler.scale(loss).backward()

        # -------------------------- 5. 梯度裁剪 & 优化器更新(梯度累积到位后) --------------------------
        # 当累积步数达到设定值时,执行参数更新
        if (step + 1) % args.accumulation_steps == 0:
            # 取消梯度缩放:将优化器的梯度值恢复到原始尺度(用于梯度裁剪)
            scaler.unscale_(optimizer)
            # 梯度裁剪:防止梯度爆炸,限制参数梯度的L2范数不超过grad_clip
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            # 执行优化器步进(参数更新),scaler会自动处理梯度缩放的恢复
            scaler.step(optimizer)
            # 更新scaler的缩放因子(AMP自适应调整)
            scaler.update()

            # 清空梯度(set_to_none=True比zero()更节省显存)
            optimizer.zero_grad(set_to_none=True)

        # -------------------------- 6. 日志记录 --------------------------
        # 达到日志打印间隔,或当前Epoch最后一步时,打印并记录训练指标
        if step % args.log_interval == 0 or step == iters - 1:
            # 计算已消耗的训练时间
            spend_time = time.time() - start_time
            # 恢复真实损失值(反向传播时除以了累积步数,这里乘回来)
            current_loss = loss.item() * args.accumulation_steps
            # 提取辅助损失值(无辅助损失则为0)
            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
            # 主损失 = 总损失 - 辅助损失
            current_logits_loss = current_loss - current_aux_loss
            # 获取当前优化器的学习率(取最后一个param_group的lr,适配多参数组场景)
            current_lr = optimizer.param_groups[-1]['lr']
            # 计算当前Epoch剩余时间(分钟)
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
            # 打印训练日志
            Logger(
                f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
                f'loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, '
                f'aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, '
                f'epoch_time: {eta_min:.1f}min'
            )
            # 如果启用wandb,将指标写入wandb可视化
            if wandb: 
                wandb.log({
                    "loss": current_loss, 
                    "logits_loss": current_logits_loss, 
                    "aux_loss": current_aux_loss, 
                    "learning_rate": current_lr, 
                    "epoch_time": eta_min
                })

        # -------------------------- 7. 模型保存 --------------------------
        # 达到保存间隔,或当前Epoch最后一步,且为主进程时(分布式训练避免重复保存)
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            # 切换模型到评估模式(防止BN/Dropout等层影响权重)
            model.eval()
            # 生成MoE模型的后缀(区分普通模型和MoE模型)
            moe_suffix = '_moe' if lm_config.use_moe else ''
            # 构建模型权重保存路径
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # 处理分布式模型:如果是DDP包装的模型,取原始模型(module)
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            # 处理可能的模型包装(如torch.compile后的_orig_mod)
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            # 获取模型状态字典
            state_dict = raw_model.state_dict()
            # 保存权重:转为FP16并移到CPU,节省磁盘空间
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            # 调用自定义检查点函数,保存更完整的训练状态(模型、优化器、epoch/step、scaler等)
            lm_checkpoint(
                lm_config, weight=args.save_weight, model=model, optimizer=optimizer, 
                epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler
            )
            # 切换回训练模式
            model.train()
            # 删除状态字典,释放显存
            del state_dict

        # -------------------------- 8. 清理显存 --------------------------
        # 删除当前迭代的张量,避免显存泄漏
        del input_ids, labels, res, loss
1.1 数据预处理:
python 复制代码
    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1)
        # -------------------------- 1. 数据预处理 --------------------------
        # 将输入和标签移到指定设备(GPU/CPU)
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
1.2 学习率调度:
python 复制代码
        # -------------------------- 2. 学习率调度 --------------------------
        # 计算当前全局步数(所有Epoch累计),并获取对应的学习率
        global_step = epoch * iters + step
        total_steps = args.epochs * iters  # 训练总步数
        lr = get_lr(global_step, total_steps, args.learning_rate) # 余弦退火算法,更新学习率
        # 更新优化器的学习率
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

其中核心就在于 get_lr() 函数,它根据 当前训练步数 以及 总步数 动态调整学习率大小,其实现过程如下:

python 复制代码
import math

def get_lr(current_step, total_steps, lr):
    """
    基于余弦退火的学习率调度函数
    :param current_step: 当前训练步数(从0开始,到total_steps结束)
    :param total_steps: 训练的总步数
    :param lr: 初始学习率(基准学习率)
    :return: 当前步数对应的动态学习率
    """
    # 核心计算逻辑:余弦退火公式 + 系数调整
    return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * current_step / total_steps)))
  • 余弦项核心 math.cos(math.pi * current_step / total_steps): 将当前步数归一化到 0 ~ 1 之间,然后映射到 0 ~ π 弧度(余弦函数在 0~π 区间是从 1 平滑下降到 -1)
    目的:通过余弦函数让学习率随步数平滑下降,防止学习率下降过快或过低。
1.3 AMP 自动混合精度前向传播

核心内容在于通过 with autocast_ctx: 实现 AMP 自动混合精度上下文,管理前向传播的精度切换:

python 复制代码
# -------------------------- 3. 前向传播(混合精度) --------------------------
        # 启用AMP自动混合精度上下文,自动切换FP16/FP32精度
        with autocast_ctx:
            # 模型前向传播,输入input_ids并传入labels计算损失
            res = model(input_ids, labels=labels)
            # 总损失 = 主损失 + 辅助损失(如MoE模型的负载均衡损失、对齐损失等)
            loss = res.loss + res.aux_loss
            # 梯度累积:将损失除以累积步数,保证最终梯度等价于大批次训练
            loss = loss / args.accumulation_steps

AMP 会通过智能判断的方式,对 精度敏感的运算(如梯度累加、BN 层)保留 FP32(单精度),对 精度不敏感的运算(如卷积、矩阵乘法)自动转成 FP16/BF16(半精度)。

  • 框架如何 "智能判断" 精度? PyTorch 框架给所有内置算子标注 "精度敏感度",比如卷积 / 矩阵乘法属于 "精度不敏感型",BN / 损失计算属于 "精度敏感型。然后,通过上下文管理器,在模型前向传播时自动识别算子类型,按规则切换精度;
  • 目的: 在反向传播时,对梯度做特殊处理(高精度),避免半精度梯度下溢(梯度为 0),保证参数更新的准确性。

FP16 的数值范围小,梯度值太小时会变成 0(下溢),导致模型无法收敛;

1.3 反向传播

GradScaler 处理反向传播时,会自动把损失 / 梯度的计算切换回 FP32:

python 复制代码
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) # 梯度缩放器
# -------------------------- 4. 反向传播(梯度缩放) --------------------------
        # AMP梯度缩放:防止FP16梯度下溢,先缩放损失再反向传播
        scaler.scale(loss).backward()
  • scaler.scale(loss): 缩放loss避免低精度梯度下溢,这也就是scale up。防止下溢的具体实现为:在反向传播之前,它将 loss 乘以一个很大的缩放因子(例如 216=655362^{16} = 65536216=65536)。当 backward() 结束,所有的梯度已经安全地存储在每个参数的 .grad 属性里了。
  • scaler.scale(loss).backward(): 返回 None,核心是生成「放大的 FP32 梯度」存到参数 .grad 中,避免 FP16 下溢;

这些「放大的梯度」会直接写入模型参数的 .grad 属性中(比如 model.conv1.weight.grad),优化器此时还没接触到梯度,只是后续会读取这个 .grad。

AMP 完整逻辑: autocast(前向提速) + GradScaler(反向保精度),二者配合才是完整的自动混合精度,而非 autocast 单独搞定。

1.4 梯度更新

python 复制代码
# 梯度累积步数达标后,执行参数更新
if (step + 1) % args.accumulation_steps == 0:
    scaler.unscale_(optimizer)  # 取消梯度缩放,把它们的 .grad管理的参数全部除以 65536
    # 梯度裁剪:限制梯度L2范数,防止梯度爆炸(大模型必做)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

    scaler.step(optimizer)  # 更新模型参数
    scaler.update()        # 更新混合精度缩放器状态

    optimizer.zero_grad(set_to_none=True)  # 清空梯度(set_to_none更省显存)
  • scaler.unscale_(optimizer) : 自动扫描优化器管理的所有参数(本质还是模型的参数),把它们的 .grad 全部除以 65536。

关键:操作的依然是「模型参数的 .grad」,优化器只是 "桥梁"------ 告诉 scaler 要处理哪些参数的 .grad。梯度是 "挂" 在模型参数上的,优化器只是 "工具人"------ 帮你找到这些参数,让 scaler 能对参数的梯度做缩放 / 还原操作。

例子:autocast 实现前向传播混合精度计算 + gradscaler 实现 scale up

python 复制代码
import torch
import torch.nn as nn
# 适配新版本:从 torch.amp 导入(替代 torch.cuda.amp)
from torch.amp import GradScaler, autocast

# 初始化模型、优化器、缩放器
model = nn.Linear(10, 1).cuda() # 模型是 nn.Linear(10, 1),那么model.weight 的形状是 (1, 10)(输出维度 1,输入维度 10),model.weight.grad 的形状和参数完全一致,也是 (1, 10)(1 行、10 列);
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 新版本 GradScaler:指定 device_type(必填)
scaler = GradScaler()

# 模拟前向传播(约束输入/损失范围,避免梯度inf/nan)
x = torch.randn(2, 10).cuda() * 0.1  # 缩小输入范围
with autocast(device_type="cuda", dtype=torch.float16):  # 显式指定设备和精度
    output = model(x)
    loss = nn.MSELoss()(output, torch.zeros_like(output))  # 用MSELoss约束loss

# ========== 关键修正:先执行反向传播,再打印梯度 ==========
# 1. 缩放loss并反向传播(此时才会计算梯度,写入model.weight.grad)
scaler.scale(loss).backward()

# 查看模型参数的梯度(此时grad已存在,可切片)
print("缩放后模型参数的grad:", model.weight.grad)
print("缩放后模型参数的grad shape:", model.weight.grad.shape)

# 查看优化器是否有grad(优化器无grad属性)
print("优化器的grad:", hasattr(optimizer, 'grad'))  # 输出 False

# 2. 手动unscale_(还原梯度)
scaler.unscale_(optimizer)
# 查看还原后的真实梯度
print("还原后模型参数的grad:", model.weight.grad[:2, :2])
print("还原后模型参数的grad shape:", model.weight.grad.shape)

# 补充:完整流程需执行step+update+清空梯度
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)

输出内容:

python 复制代码
缩放后模型参数的grad: tensor([[1415.0000,  961.5000,  352.0000, 2096.0000, -848.0000, 1648.0000,
         -569.5000, 1281.0000,  378.7500, 1107.0000]], device='cuda:0')
缩放后模型参数的grad shape: torch.Size([1, 10])
优化器的grad: False
还原后模型参数的grad: tensor([[0.0216, 0.0147]], device='cuda:0')
还原后模型参数的grad shape: torch.Size([1, 10])
  • model = nn.Linear(10, 1).cuda(): 模型是 nn.Linear(10, 1),那么model.weight 的形状是 (1, 10)(输出维度 1,输入维度 10),model.weight.grad 的形状和参数完全一致,也是 (1, 10)(1 行、10 列);

print("还原后模型参数的grad shape:", model.weight.grad.shape),所以为 torch.Size([1, 10])

  • 切片 [:2, :2] 的执行逻辑: PyTorch 张量切片遵循「取到维度的最后一个元素为止,不会越界报错」的规则。

2. SFT 和 PreTrain 在 DataSet 的区别

python 复制代码
import torch
from torch.utils.data import Dataset
from datasets import load_dataset  # 注意:需要安装datasets库

class SFTDataset(Dataset):
    """
    用于大模型SFT(监督微调)的自定义Dataset类
    功能:加载JSONL格式的对话数据,生成符合ChatML格式的input_ids,并为labels做掩码(仅回答部分保留真实值,其余为-100)
    """
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        """
        初始化数据集
        Args:
            jsonl_path: str - JSONL格式的数据文件路径(每行是一个对话样本)
            tokenizer: PreTrainedTokenizer - 模型对应的tokenizer(如LlamaTokenizer)
            max_length: int - 输入序列的最大长度,超出截断,不足补pad
        """
        super().__init__()
        self.tokenizer = tokenizer  # 保存tokenizer
        self.max_length = max_length  # 保存最大序列长度
        
        # 加载JSONL数据集(使用huggingface的datasets库,加载train拆分)
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        
        # 预计算assistant回复的起始标记id(bos_id)和结束标记id(eos_id)
        # 例如:tokenizer.bos_token是<s>,则bos_id对应<s>assistant\n的token id列表
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        # eos_id对应</s>\n的token id列表
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids

    def __len__(self):
        """返回数据集总样本数(必须实现的Dataset方法)"""
        return len(self.samples)

    def create_chat_prompt(self, cs):
        """
        将对话列表转换成符合ChatML格式的文本prompt
        Args:
            cs: list - 单条样本的对话列表,格式如[{"role":"user","content":"你好"}, {"role":"assistant","content":"您好!"}]
        Returns:
            str - 格式化后的对话文本(如"<s>user\n你好</s>assistant\n您好!</s>")
        """
        # 复制对话列表,避免修改原数据
        messages = cs.copy()
        
        # 处理工具调用场景:如果第一条是system消息且包含functions(工具定义),则传入tools参数
        tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
        
        # 使用tokenizer的chat_template生成标准化的对话prompt(不分词,不添加生成提示)
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # 只返回文本,不做分词
            add_generation_prompt=False,  # 不添加"assistant\n"的生成提示(因为数据中已有)
            tools=tools  # 传入工具定义(如有)
        )

    def generate_labels(self, input_ids):
        """
        生成标签:仅保留assistant回复部分的token id,其余部分设为-100(PyTorch中-100会被CrossEntropyLoss忽略)
        Args:
            input_ids: list - 整个对话的token id列表
        Returns:
            list - 与input_ids等长的labels列表,非回答部分为-100,回答部分为对应input_ids的值
        """
        # 初始化labels为全-100(默认所有位置都不计算损失)
        labels = [-100] * len(input_ids)
        i = 0  # 遍历input_ids的指针
        
        # 遍历整个input_ids,找到所有assistant回复的区间
        while i < len(input_ids):
            # 找到assistant回复的起始位置(匹配bos_id:<s>assistant\n)
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                # 回复起始位置 = bos_id结束的位置
                start = i + len(self.bos_id)
                end = start  # 回复结束位置的指针
                
                # 从start开始,找到eos_id(</s>\n)的位置,确定回复的结束边界
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                
                # 将assistant回复区间(start到end+eos_id长度)的labels设为对应input_ids的值
                # 注意:不超过max_length,避免越界
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    labels[j] = input_ids[j]
                
                # 指针跳到当前回复的结束位置,继续找下一个回复(如有)
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                # 未找到bos_id,指针后移一位
                i += 1
        return labels

    def __getitem__(self, index):
        """
        加载单条样本(必须实现的Dataset方法)
        Args:
            index: int - 样本索引
        Returns:
            tuple - (input_ids_tensor, labels_tensor),均为long型tensor,长度=max_length
        """
        # 获取单条样本(对话数据)
        sample = self.samples[index]
        
        # 1. 将对话列表转换成ChatML格式的文本prompt
        prompt = self.create_chat_prompt(sample['conversations'])
        
        # 2. 对prompt分词,生成input_ids,并截断到max_length
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        
        # 3. 不足max_length的部分用pad_token_id补齐
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        
        # 4. 生成labels:仅保留assistant回复部分,其余为-100
        labels = self.generate_labels(input_ids)
        
        # # === 调试打印(可选)===
        # print(f"\n--- Sample {index} ---")
        # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
        # # ================
        
        # 5. 转换为torch tensor(long型,符合模型输入要求)
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
区别一:SFT用了chat_template

核心目的: 封装 tokenizer.apply_chat_template() ,将 messages 列表封装到 chat_template 中,返回标准对话的 prompt ;

因为 SFT 与 Pretrain 的对话构成不一样,即 labels 不一样。前者的真值来源于 assistant 的回复,而后者为续写下一个词,只需要对齐即可。

python 复制代码
    def create_chat_prompt(self, cs):
        """
        将对话列表转换成符合ChatML格式的文本prompt
        Args:
            cs: list - 单条样本的对话列表,格式如[{"role":"user","content":"你好"}, {"role":"assistant","content":"您好!"}]
        Returns:
            str - 格式化后的对话文本(如"<s>user\n你好</s>assistant\n您好!</s>")
        """
        # 复制对话列表,避免修改原数据
        messages = cs.copy()
        
        # 处理工具调用场景:如果第一条是system消息且包含functions(工具定义),则传入tools参数
        tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
        
        # 使用tokenizer的chat_template生成标准化的对话prompt(不分词,不添加生成提示)
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # 只返回文本,不做分词
            add_generation_prompt=False,  # 不添加"assistant\n"的生成提示(因为数据中已有)
            tools=tools  # 传入工具定义(如有)
        )

返回的 Prompt:

python 复制代码
<|im_start|>system
你是一个优秀的聊天机器人,总是给我正确的回应!<|im_end|>
<|im_start|>user
你来自哪里?<|im_end|>
<|im_start|>assistant
我来自地球<|im_end|>

然后对 Prompt 进行分词:

python 复制代码
# 2. 对prompt分词,生成input_ids,并截断到max_length
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]

# 3. 不足max_length的部分用pad_token_id补齐
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

input_ids:

python 复制代码
model_inpus: {'input_ids': [1, 85, 91, 85, 86, 71, 79, 201, 358, 988, 4904, 5195, 265, 1958, 911, 1862, 570, 606, 604, 2, 201, 1, 87, 85, 71, 84, 201, 358, 3583, 1648, 336, 2, 201, 1, 67, 85, 85, 75, 85, 86, 67, 80, 86, 201, 289, 3583, 1562, 2, 201], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
区别二:labels 的构建方式不一样
python 复制代码
    def generate_labels(self, input_ids):
        """
        生成标签:仅保留assistant回复部分的token id,其余部分设为-100(PyTorch中-100会被CrossEntropyLoss忽略)
        Args:
            input_ids: list - 整个对话的token id列表
        Returns:
            list - 与input_ids等长的labels列表,非回答部分为-100,回答部分为对应input_ids的值
        """
        # 初始化labels为全-100(默认所有位置都不计算损失)
        labels = [-100] * len(input_ids)
        i = 0  # 遍历input_ids的指针
        
        # 遍历整个input_ids,找到所有assistant回复的区间
        while i < len(input_ids):
            # 找到assistant回复的起始位置(匹配bos_id:<s>assistant\n)
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                # 回复起始位置 = bos_id结束的位置
                start = i + len(self.bos_id)
                end = start  # 回复结束位置的指针
                
                # 从start开始,找到eos_id(</s>\n)的位置,确定回复的结束边界
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                
                # 将assistant回复区间(start到end+eos_id长度)的labels设为对应input_ids的值
                # 注意:不超过max_length,避免越界
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    labels[j] = input_ids[j]
                
                # 指针跳到当前回复的结束位置,继续找下一个回复(如有)
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                # 未找到bos_id,指针后移一位
                i += 1
        return labels
  • labels = [-100] * len(input_ids): 构建一个长度和 input_ids 一样的列表,元素值都为 -100,[-100] 为一个单元素列表。
  • labels区间: <|im_start|>assistant\n 是助手回复开始标记,<|im_end|>\n 是结尾标记;
相关推荐
摆烂工程师18 小时前
GPT-5.4 发布!再看 OpenClaw:AI 真正危险的,不是更会聊天,而是开始自己“干活”
人工智能·openai·ai编程
飞哥数智坊1 天前
分享被迫变直播:AI·Spring养虾记就这样上线了
人工智能
Mr_Lucifer1 天前
「一句话」生成”小红书“式金句海报(CodeFlicker + quote-poster-generator)
人工智能·aigc·visual studio code
冬奇Lab1 天前
OpenClaw 深度解析(五):模型与提供商系统
人工智能·开源·源码阅读
冬奇Lab1 天前
一天一个开源项目(第42篇):OpenFang - 用 Rust 构建的 Agent 操作系统,16 层安全与 7 个自主 Hands
人工智能·rust·开源
IT_陈寒1 天前
SpringBoot性能飙升200%?这5个隐藏配置你必须知道!
前端·人工智能·后端
yiyu07161 天前
3分钟搞懂深度学习AI:反向传播:链式法则的归责游戏
人工智能·深度学习
机器之心1 天前
OpenClaw绝配!GPT-5.4问世,AI能力开始大一统,就是太贵
人工智能·openai
机器之心1 天前
海外华人15人团队打造,统一理解与生成的图像模型,超越Nano banana登顶图像编辑
人工智能·openai
用户552796026051 天前
在老版本 HPC 系统上运行 Antigravity(反重力)
人工智能