面试-SFT

1. 训练逻辑代码

python 复制代码
# 假设以下是外部定义的依赖(方便理解)
# args: 训练参数配置(如设备、学习率、梯度累积步数等)
# optimizer: 优化器(如AdamW)
# model: 待训练的模型
# autocast_ctx: AMP自动混合精度上下文管理器
# scaler: GradScaler梯度缩放器(配合AMP使用)
# Logger: 自定义日志打印函数
# is_main_process: 判断是否为主进程(分布式训练用)
# lm_config: 语言模型配置(如hidden_size、是否用MoE等)
# lm_checkpoint: 自定义的检查点保存函数
# get_lr: 学习率调度函数(根据步数计算当前学习率)
# DistributedDataParallel: 分布式数据并行训练类

def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
    """
    执行单个Epoch的训练过程,包含完整的前向/反向传播、梯度处理、日志记录和模型保存逻辑
    
    Args:
        epoch (int): 当前训练的Epoch编号(从0开始)
        loader (DataLoader): 训练数据加载器,输出(input_ids, labels)
        iters (int): 当前Epoch的总迭代步数
        start_step (int, optional): 起始迭代步数(用于断点续训),默认0
        wandb (optional): Weights & Biases日志对象,用于可视化训练指标,默认None
    """
    # 记录当前Epoch的训练开始时间
    start_time = time.time()
    
    # 遍历数据加载器,step从start_step+1开始计数(断点续训时保持步数连续),一般start_step几百几百的算
    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1)
        # -------------------------- 1. 数据预处理 --------------------------
        # 将输入和标签移到指定设备(GPU/CPU)
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
        
        # -------------------------- 2. 学习率调度 --------------------------
        # 计算当前全局步数(所有Epoch累计),并获取对应的学习率
        global_step = epoch * iters + step
        total_steps = args.epochs * iters  # 训练总步数
        lr = get_lr(global_step, total_steps, args.learning_rate) # 余弦退火算法,更新学习率
        # 更新优化器的学习率
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # -------------------------- 3. 前向传播(混合精度) --------------------------
        # 启用AMP自动混合精度上下文,自动切换FP16/FP32精度
        with autocast_ctx:
            # 模型前向传播,输入input_ids并传入labels计算损失
            res = model(input_ids, labels=labels)
            # 总损失 = 主损失 + 辅助损失(如MoE模型的负载均衡损失、对齐损失等)
            loss = res.loss + res.aux_loss
            # 梯度累积:将损失除以累积步数,保证最终梯度等价于大批次训练
            loss = loss / args.accumulation_steps

        # -------------------------- 4. 反向传播(梯度缩放) --------------------------
        # AMP梯度缩放:防止FP16梯度下溢,先缩放损失再反向传播
        scaler.scale(loss).backward()

        # -------------------------- 5. 梯度裁剪 & 优化器更新(梯度累积到位后) --------------------------
        # 当累积步数达到设定值时,执行参数更新
        if (step + 1) % args.accumulation_steps == 0:
            # 取消梯度缩放:将优化器的梯度值恢复到原始尺度(用于梯度裁剪)
            scaler.unscale_(optimizer)
            # 梯度裁剪:防止梯度爆炸,限制参数梯度的L2范数不超过grad_clip
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            # 执行优化器步进(参数更新),scaler会自动处理梯度缩放的恢复
            scaler.step(optimizer)
            # 更新scaler的缩放因子(AMP自适应调整)
            scaler.update()

            # 清空梯度(set_to_none=True比zero()更节省显存)
            optimizer.zero_grad(set_to_none=True)

        # -------------------------- 6. 日志记录 --------------------------
        # 达到日志打印间隔,或当前Epoch最后一步时,打印并记录训练指标
        if step % args.log_interval == 0 or step == iters - 1:
            # 计算已消耗的训练时间
            spend_time = time.time() - start_time
            # 恢复真实损失值(反向传播时除以了累积步数,这里乘回来)
            current_loss = loss.item() * args.accumulation_steps
            # 提取辅助损失值(无辅助损失则为0)
            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
            # 主损失 = 总损失 - 辅助损失
            current_logits_loss = current_loss - current_aux_loss
            # 获取当前优化器的学习率(取最后一个param_group的lr,适配多参数组场景)
            current_lr = optimizer.param_groups[-1]['lr']
            # 计算当前Epoch剩余时间(分钟)
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
            # 打印训练日志
            Logger(
                f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
                f'loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, '
                f'aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, '
                f'epoch_time: {eta_min:.1f}min'
            )
            # 如果启用wandb,将指标写入wandb可视化
            if wandb: 
                wandb.log({
                    "loss": current_loss, 
                    "logits_loss": current_logits_loss, 
                    "aux_loss": current_aux_loss, 
                    "learning_rate": current_lr, 
                    "epoch_time": eta_min
                })

        # -------------------------- 7. 模型保存 --------------------------
        # 达到保存间隔,或当前Epoch最后一步,且为主进程时(分布式训练避免重复保存)
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            # 切换模型到评估模式(防止BN/Dropout等层影响权重)
            model.eval()
            # 生成MoE模型的后缀(区分普通模型和MoE模型)
            moe_suffix = '_moe' if lm_config.use_moe else ''
            # 构建模型权重保存路径
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # 处理分布式模型:如果是DDP包装的模型,取原始模型(module)
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            # 处理可能的模型包装(如torch.compile后的_orig_mod)
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            # 获取模型状态字典
            state_dict = raw_model.state_dict()
            # 保存权重:转为FP16并移到CPU,节省磁盘空间
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            # 调用自定义检查点函数,保存更完整的训练状态(模型、优化器、epoch/step、scaler等)
            lm_checkpoint(
                lm_config, weight=args.save_weight, model=model, optimizer=optimizer, 
                epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler
            )
            # 切换回训练模式
            model.train()
            # 删除状态字典,释放显存
            del state_dict

        # -------------------------- 8. 清理显存 --------------------------
        # 删除当前迭代的张量,避免显存泄漏
        del input_ids, labels, res, loss
1.1 数据预处理:
python 复制代码
    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1)
        # -------------------------- 1. 数据预处理 --------------------------
        # 将输入和标签移到指定设备(GPU/CPU)
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
1.2 学习率调度:
python 复制代码
        # -------------------------- 2. 学习率调度 --------------------------
        # 计算当前全局步数(所有Epoch累计),并获取对应的学习率
        global_step = epoch * iters + step
        total_steps = args.epochs * iters  # 训练总步数
        lr = get_lr(global_step, total_steps, args.learning_rate) # 余弦退火算法,更新学习率
        # 更新优化器的学习率
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

其中核心就在于 get_lr() 函数,它根据 当前训练步数 以及 总步数 动态调整学习率大小,其实现过程如下:

python 复制代码
import math

def get_lr(current_step, total_steps, lr):
    """
    基于余弦退火的学习率调度函数
    :param current_step: 当前训练步数(从0开始,到total_steps结束)
    :param total_steps: 训练的总步数
    :param lr: 初始学习率(基准学习率)
    :return: 当前步数对应的动态学习率
    """
    # 核心计算逻辑:余弦退火公式 + 系数调整
    return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * current_step / total_steps)))
  • 余弦项核心 math.cos(math.pi * current_step / total_steps): 将当前步数归一化到 0 ~ 1 之间,然后映射到 0 ~ π 弧度(余弦函数在 0~π 区间是从 1 平滑下降到 -1)
    目的:通过余弦函数让学习率随步数平滑下降,防止学习率下降过快或过低。
1.3 AMP 自动混合精度前向传播

核心内容在于通过 with autocast_ctx: 实现 AMP 自动混合精度上下文,管理前向传播的精度切换:

python 复制代码
# -------------------------- 3. 前向传播(混合精度) --------------------------
        # 启用AMP自动混合精度上下文,自动切换FP16/FP32精度
        with autocast_ctx:
            # 模型前向传播,输入input_ids并传入labels计算损失
            res = model(input_ids, labels=labels)
            # 总损失 = 主损失 + 辅助损失(如MoE模型的负载均衡损失、对齐损失等)
            loss = res.loss + res.aux_loss
            # 梯度累积:将损失除以累积步数,保证最终梯度等价于大批次训练
            loss = loss / args.accumulation_steps

AMP 会通过智能判断的方式,对 精度敏感的运算(如梯度累加、BN 层)保留 FP32(单精度),对 精度不敏感的运算(如卷积、矩阵乘法)自动转成 FP16/BF16(半精度)。

  • 框架如何 "智能判断" 精度? PyTorch 框架给所有内置算子标注 "精度敏感度",比如卷积 / 矩阵乘法属于 "精度不敏感型",BN / 损失计算属于 "精度敏感型。然后,通过上下文管理器,在模型前向传播时自动识别算子类型,按规则切换精度;
  • 目的: 在反向传播时,对梯度做特殊处理(高精度),避免半精度梯度下溢(梯度为 0),保证参数更新的准确性。

FP16 的数值范围小,梯度值太小时会变成 0(下溢),导致模型无法收敛;

1.3 反向传播

GradScaler 处理反向传播时,会自动把损失 / 梯度的计算切换回 FP32:

python 复制代码
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) # 梯度缩放器
# -------------------------- 4. 反向传播(梯度缩放) --------------------------
        # AMP梯度缩放:防止FP16梯度下溢,先缩放损失再反向传播
        scaler.scale(loss).backward()
  • scaler.scale(loss): 缩放loss避免低精度梯度下溢,这也就是scale up。防止下溢的具体实现为:在反向传播之前,它将 loss 乘以一个很大的缩放因子(例如 216=655362^{16} = 65536216=65536)。当 backward() 结束,所有的梯度已经安全地存储在每个参数的 .grad 属性里了。
  • scaler.scale(loss).backward(): 返回 None,核心是生成「放大的 FP32 梯度」存到参数 .grad 中,避免 FP16 下溢;

这些「放大的梯度」会直接写入模型参数的 .grad 属性中(比如 model.conv1.weight.grad),优化器此时还没接触到梯度,只是后续会读取这个 .grad。

AMP 完整逻辑: autocast(前向提速) + GradScaler(反向保精度),二者配合才是完整的自动混合精度,而非 autocast 单独搞定。

1.4 梯度更新

python 复制代码
# 梯度累积步数达标后,执行参数更新
if (step + 1) % args.accumulation_steps == 0:
    scaler.unscale_(optimizer)  # 取消梯度缩放,把它们的 .grad管理的参数全部除以 65536
    # 梯度裁剪:限制梯度L2范数,防止梯度爆炸(大模型必做)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

    scaler.step(optimizer)  # 更新模型参数
    scaler.update()        # 更新混合精度缩放器状态

    optimizer.zero_grad(set_to_none=True)  # 清空梯度(set_to_none更省显存)
  • scaler.unscale_(optimizer) : 自动扫描优化器管理的所有参数(本质还是模型的参数),把它们的 .grad 全部除以 65536。

关键:操作的依然是「模型参数的 .grad」,优化器只是 "桥梁"------ 告诉 scaler 要处理哪些参数的 .grad。梯度是 "挂" 在模型参数上的,优化器只是 "工具人"------ 帮你找到这些参数,让 scaler 能对参数的梯度做缩放 / 还原操作。

例子:autocast 实现前向传播混合精度计算 + gradscaler 实现 scale up

python 复制代码
import torch
import torch.nn as nn
# 适配新版本:从 torch.amp 导入(替代 torch.cuda.amp)
from torch.amp import GradScaler, autocast

# 初始化模型、优化器、缩放器
model = nn.Linear(10, 1).cuda() # 模型是 nn.Linear(10, 1),那么model.weight 的形状是 (1, 10)(输出维度 1,输入维度 10),model.weight.grad 的形状和参数完全一致,也是 (1, 10)(1 行、10 列);
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 新版本 GradScaler:指定 device_type(必填)
scaler = GradScaler()

# 模拟前向传播(约束输入/损失范围,避免梯度inf/nan)
x = torch.randn(2, 10).cuda() * 0.1  # 缩小输入范围
with autocast(device_type="cuda", dtype=torch.float16):  # 显式指定设备和精度
    output = model(x)
    loss = nn.MSELoss()(output, torch.zeros_like(output))  # 用MSELoss约束loss

# ========== 关键修正:先执行反向传播,再打印梯度 ==========
# 1. 缩放loss并反向传播(此时才会计算梯度,写入model.weight.grad)
scaler.scale(loss).backward()

# 查看模型参数的梯度(此时grad已存在,可切片)
print("缩放后模型参数的grad:", model.weight.grad)
print("缩放后模型参数的grad shape:", model.weight.grad.shape)

# 查看优化器是否有grad(优化器无grad属性)
print("优化器的grad:", hasattr(optimizer, 'grad'))  # 输出 False

# 2. 手动unscale_(还原梯度)
scaler.unscale_(optimizer)
# 查看还原后的真实梯度
print("还原后模型参数的grad:", model.weight.grad[:2, :2])
print("还原后模型参数的grad shape:", model.weight.grad.shape)

# 补充:完整流程需执行step+update+清空梯度
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)

输出内容:

python 复制代码
缩放后模型参数的grad: tensor([[1415.0000,  961.5000,  352.0000, 2096.0000, -848.0000, 1648.0000,
         -569.5000, 1281.0000,  378.7500, 1107.0000]], device='cuda:0')
缩放后模型参数的grad shape: torch.Size([1, 10])
优化器的grad: False
还原后模型参数的grad: tensor([[0.0216, 0.0147]], device='cuda:0')
还原后模型参数的grad shape: torch.Size([1, 10])
  • model = nn.Linear(10, 1).cuda(): 模型是 nn.Linear(10, 1),那么model.weight 的形状是 (1, 10)(输出维度 1,输入维度 10),model.weight.grad 的形状和参数完全一致,也是 (1, 10)(1 行、10 列);

print("还原后模型参数的grad shape:", model.weight.grad.shape),所以为 torch.Size([1, 10])

  • 切片 [:2, :2] 的执行逻辑: PyTorch 张量切片遵循「取到维度的最后一个元素为止,不会越界报错」的规则。

2. SFT 和 PreTrain 在 DataSet 的区别

python 复制代码
import torch
from torch.utils.data import Dataset
from datasets import load_dataset  # 注意:需要安装datasets库

class SFTDataset(Dataset):
    """
    用于大模型SFT(监督微调)的自定义Dataset类
    功能:加载JSONL格式的对话数据,生成符合ChatML格式的input_ids,并为labels做掩码(仅回答部分保留真实值,其余为-100)
    """
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        """
        初始化数据集
        Args:
            jsonl_path: str - JSONL格式的数据文件路径(每行是一个对话样本)
            tokenizer: PreTrainedTokenizer - 模型对应的tokenizer(如LlamaTokenizer)
            max_length: int - 输入序列的最大长度,超出截断,不足补pad
        """
        super().__init__()
        self.tokenizer = tokenizer  # 保存tokenizer
        self.max_length = max_length  # 保存最大序列长度
        
        # 加载JSONL数据集(使用huggingface的datasets库,加载train拆分)
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        
        # 预计算assistant回复的起始标记id(bos_id)和结束标记id(eos_id)
        # 例如:tokenizer.bos_token是<s>,则bos_id对应<s>assistant\n的token id列表
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        # eos_id对应</s>\n的token id列表
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids

    def __len__(self):
        """返回数据集总样本数(必须实现的Dataset方法)"""
        return len(self.samples)

    def create_chat_prompt(self, cs):
        """
        将对话列表转换成符合ChatML格式的文本prompt
        Args:
            cs: list - 单条样本的对话列表,格式如[{"role":"user","content":"你好"}, {"role":"assistant","content":"您好!"}]
        Returns:
            str - 格式化后的对话文本(如"<s>user\n你好</s>assistant\n您好!</s>")
        """
        # 复制对话列表,避免修改原数据
        messages = cs.copy()
        
        # 处理工具调用场景:如果第一条是system消息且包含functions(工具定义),则传入tools参数
        tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
        
        # 使用tokenizer的chat_template生成标准化的对话prompt(不分词,不添加生成提示)
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # 只返回文本,不做分词
            add_generation_prompt=False,  # 不添加"assistant\n"的生成提示(因为数据中已有)
            tools=tools  # 传入工具定义(如有)
        )

    def generate_labels(self, input_ids):
        """
        生成标签:仅保留assistant回复部分的token id,其余部分设为-100(PyTorch中-100会被CrossEntropyLoss忽略)
        Args:
            input_ids: list - 整个对话的token id列表
        Returns:
            list - 与input_ids等长的labels列表,非回答部分为-100,回答部分为对应input_ids的值
        """
        # 初始化labels为全-100(默认所有位置都不计算损失)
        labels = [-100] * len(input_ids)
        i = 0  # 遍历input_ids的指针
        
        # 遍历整个input_ids,找到所有assistant回复的区间
        while i < len(input_ids):
            # 找到assistant回复的起始位置(匹配bos_id:<s>assistant\n)
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                # 回复起始位置 = bos_id结束的位置
                start = i + len(self.bos_id)
                end = start  # 回复结束位置的指针
                
                # 从start开始,找到eos_id(</s>\n)的位置,确定回复的结束边界
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                
                # 将assistant回复区间(start到end+eos_id长度)的labels设为对应input_ids的值
                # 注意:不超过max_length,避免越界
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    labels[j] = input_ids[j]
                
                # 指针跳到当前回复的结束位置,继续找下一个回复(如有)
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                # 未找到bos_id,指针后移一位
                i += 1
        return labels

    def __getitem__(self, index):
        """
        加载单条样本(必须实现的Dataset方法)
        Args:
            index: int - 样本索引
        Returns:
            tuple - (input_ids_tensor, labels_tensor),均为long型tensor,长度=max_length
        """
        # 获取单条样本(对话数据)
        sample = self.samples[index]
        
        # 1. 将对话列表转换成ChatML格式的文本prompt
        prompt = self.create_chat_prompt(sample['conversations'])
        
        # 2. 对prompt分词,生成input_ids,并截断到max_length
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        
        # 3. 不足max_length的部分用pad_token_id补齐
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        
        # 4. 生成labels:仅保留assistant回复部分,其余为-100
        labels = self.generate_labels(input_ids)
        
        # # === 调试打印(可选)===
        # print(f"\n--- Sample {index} ---")
        # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
        # # ================
        
        # 5. 转换为torch tensor(long型,符合模型输入要求)
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
区别一:SFT用了chat_template

核心目的: 封装 tokenizer.apply_chat_template() ,将 messages 列表封装到 chat_template 中,返回标准对话的 prompt ;

因为 SFT 与 Pretrain 的对话构成不一样,即 labels 不一样。前者的真值来源于 assistant 的回复,而后者为续写下一个词,只需要对齐即可。

python 复制代码
    def create_chat_prompt(self, cs):
        """
        将对话列表转换成符合ChatML格式的文本prompt
        Args:
            cs: list - 单条样本的对话列表,格式如[{"role":"user","content":"你好"}, {"role":"assistant","content":"您好!"}]
        Returns:
            str - 格式化后的对话文本(如"<s>user\n你好</s>assistant\n您好!</s>")
        """
        # 复制对话列表,避免修改原数据
        messages = cs.copy()
        
        # 处理工具调用场景:如果第一条是system消息且包含functions(工具定义),则传入tools参数
        tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
        
        # 使用tokenizer的chat_template生成标准化的对话prompt(不分词,不添加生成提示)
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # 只返回文本,不做分词
            add_generation_prompt=False,  # 不添加"assistant\n"的生成提示(因为数据中已有)
            tools=tools  # 传入工具定义(如有)
        )

返回的 Prompt:

python 复制代码
<|im_start|>system
你是一个优秀的聊天机器人,总是给我正确的回应!<|im_end|>
<|im_start|>user
你来自哪里?<|im_end|>
<|im_start|>assistant
我来自地球<|im_end|>

然后对 Prompt 进行分词:

python 复制代码
# 2. 对prompt分词,生成input_ids,并截断到max_length
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]

# 3. 不足max_length的部分用pad_token_id补齐
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

input_ids:

python 复制代码
model_inpus: {'input_ids': [1, 85, 91, 85, 86, 71, 79, 201, 358, 988, 4904, 5195, 265, 1958, 911, 1862, 570, 606, 604, 2, 201, 1, 87, 85, 71, 84, 201, 358, 3583, 1648, 336, 2, 201, 1, 67, 85, 85, 75, 85, 86, 67, 80, 86, 201, 289, 3583, 1562, 2, 201], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
区别二:labels 的构建方式不一样
python 复制代码
    def generate_labels(self, input_ids):
        """
        生成标签:仅保留assistant回复部分的token id,其余部分设为-100(PyTorch中-100会被CrossEntropyLoss忽略)
        Args:
            input_ids: list - 整个对话的token id列表
        Returns:
            list - 与input_ids等长的labels列表,非回答部分为-100,回答部分为对应input_ids的值
        """
        # 初始化labels为全-100(默认所有位置都不计算损失)
        labels = [-100] * len(input_ids)
        i = 0  # 遍历input_ids的指针
        
        # 遍历整个input_ids,找到所有assistant回复的区间
        while i < len(input_ids):
            # 找到assistant回复的起始位置(匹配bos_id:<s>assistant\n)
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                # 回复起始位置 = bos_id结束的位置
                start = i + len(self.bos_id)
                end = start  # 回复结束位置的指针
                
                # 从start开始,找到eos_id(</s>\n)的位置,确定回复的结束边界
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                
                # 将assistant回复区间(start到end+eos_id长度)的labels设为对应input_ids的值
                # 注意:不超过max_length,避免越界
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    labels[j] = input_ids[j]
                
                # 指针跳到当前回复的结束位置,继续找下一个回复(如有)
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                # 未找到bos_id,指针后移一位
                i += 1
        return labels
  • labels = [-100] * len(input_ids): 构建一个长度和 input_ids 一样的列表,元素值都为 -100,[-100] 为一个单元素列表。
  • labels区间: <|im_start|>assistant\n 是助手回复开始标记,<|im_end|>\n 是结尾标记;
相关推荐
Alice_whj2 小时前
AI云原生笔记
人工智能·笔记·云原生
Lyan-X2 小时前
鲁鹏教授《计算机视觉与深度学习》课程笔记与思考 ——13. 生成模型 VAE:从无监督学习到显式密度估计的建模与实现
人工智能·笔记·深度学习·计算机视觉
AI_Auto2 小时前
智能制造-MES与AI结合的核心价值与逻辑
大数据·人工智能·制造
聊聊科技2 小时前
5款AI编曲软件荣登2026年度榜单,逐项对比适合原创音乐人参考
人工智能
董厂长2 小时前
RAG 中的分块策略(Chunking Strategy)
人工智能·llm·rag·分块策略
皮卡丘不断更2 小时前
让数据“开口说话”!SwiftBoot AI 智能看板 v0.1.8 震撼来袭
人工智能·系统架构·ai编程
向哆哆2 小时前
七种常见虫子的图像识别数据集分享(适用于目标检测任务)
人工智能·目标检测·计算机视觉
AI浩2 小时前
面向对象保真度的遥感图像生成扩散模型
人工智能·目标检测
CareyWYR2 小时前
每周AI论文速递(260209-260213)
人工智能