面试-预训练

代码

首先是训练逻辑部分的代码,主要包含 数据加载学习率调度混合精度训练梯度累积日志监控模型保存 等关键环节:

1. 加载 input_ids 和 labels
python 复制代码
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
    start_time = time.time()  # 记录本轮训练开始时间
    # 遍历数据加载器,start参数支持断点续训(从指定步数开始计数)
    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
        # 数据迁移到指定设备(GPU/CPU),保证数据与模型同设备
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
  • 预训练的标签: 可以看出 labels 的内容和 input 的内容是通过 loader 加载出来的,那么我如果要知道 labels 和 input 的形状,肯定需要知道 loader 是什么。
2. DataLoader 类

核心目的: 依照 batch_size 维度,将 Dataset 产生的样本对 (input_ids, labels) 封装为张量批次。这种封装不仅实现了数据的并行计算。确立了训练过程中的"步进(Step)"逻辑,使训练进度与梯度更新变得直观可控。

python 复制代码
from torch.utils.data import DataLoader, DistributedSampler

# train_ds 为打包好的(input_ids, labels),batch_sampler 一种抽取数据的方式(=batchsize),num_workers 为并行抽取数量(几个人抽)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
  • train_ds:PretrainDataset 类,返回为 (input_ids, labels);
  • batch_sampler : 假设你的数据集里有 1000 条数据。BatchSampler 就像一个抽签桶,里面有 0 到 999 的编号。它会根据你的 batch_size=32,一次性抓出 32 个随机编号。
  • num_workers: 拿到这 32 个编号后,会立刻派 8 个工人(num_workers=8)去 PretrainDataset 这个大书架上取书。
    最后解包得到一个元组 (input_ids, labels)。分别为 [bsz, seq, dim]。那么 labels 的维度到底长什么样我们需要进一步看 PretrainDataset 类。
3. PretrainDataset 类

核心目的: 加载 json 数据集,确定分词细节,并确立 labels 的形状。

python 复制代码
class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # 使用 Huggingface 的 load_dataset 加载 jsonl 文件
        # 这会自动处理文件的读取,并将数据转化为一个类似列表的结构
        self.samples = load_dataset('json', data_files=data_path, split='train')

    def __len__(self):
        # 告诉 DataLoader 整个数据集有多少条数据
        return len(self.samples)

    def __getitem__(self, index):
        # 1. 根据索引获取一条原始数据,例如: {'text': '你好,世界'}
        sample = self.samples[index]
        
        # 2. 将文本转为数字 ID (Token IDs)
        # truncation=True: 超过 max_length-2 的部分会被切掉
        # max_length-2 是为了给后面的 BOS 和 EOS 留出位置
        tokens = self.tokenizer(
            str(sample['text']), 
            add_special_tokens=False, 
            max_length=self.max_length - 2, 
            truncation=True
        ).input_ids
        
        # 3. 构建完整的序列:[BOS] + 文本内容 + [EOS]
        # BOS: 开始符 (Begin of Sentence)
        # EOS: 结束符 (End of Sentence)
        tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
        
        # 4. 填充 (Padding)
        # 如果长度不足 max_length,就在后面补 PAD 符(通常是 0)
        # 确保同一个 Batch 里的所有序列长度都是一样的(都是 512)
        input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
        
        # 5. 转为 PyTorch 张量
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        
        # 6. 构建标签 (Labels)
        # 在预训练中,标签最初就是输入序列本身(因为我们要预测下一个词)
        labels = input_ids.clone()
        
        # 7. 忽略填充位的损失
        # 将 labels 中等于 PAD 的位置设为 -100
        # PyTorch 的 CrossEntropyLoss 会自动跳过标签为 -100 的位置
        # 这样模型就不会去学习如何预测"无意义"的填充符
        labels[input_ids == self.tokenizer.pad_token_id] = -100
        
        # 返回一对张量。DataLoader 会把多个这样的元组打包成 (batch_size, 512) 的张量
        return input_ids, labels
  • 分词细节: Dataset 类的核心目的就是确定分词细节,先加载数据集,分词,加特殊Token,填充Token至统一长度。
  • labels: 可以看出 labels 的形状大小和 input_ids 是一致的。

到这里,我们可以明白 train_epoch() 函数中的 input_ids 和 labels 大小是一致的,那么一个新的问题来了,预训练的 loss 是怎么算的呢?这跟模型内部处理有关!

4. 混合精度前向传播和梯度归一化

核心作用: 通过 autocast 实现混合精度,加速矩阵运算,同时归一化损失适配梯度累积。

python 复制代码
# 混合精度上下文:GPU用autocast加速,CPU用空上下文兼容
with autocast_ctx:
    # 模型前向传播:输入token id,输出包含主loss和辅助loss的结果
    res = model(input_ids, labels=labels)
    loss = res.loss + res.aux_loss  # 总损失=主任务损失+辅助损失(如MoE负载均衡损失)
    loss = loss / args.accumulation_steps  # 梯度累积归一化(避免累积后梯度幅度过大)
  • 混合精度(autocast_ctx): 将模型计算转为bfloat16/float16,显存占用减半,训练速度提升 30%+;
  • 梯度累积归一化: 可以理解为"逻辑损失",损失除以累积步数,保证最终累积梯度与 "真实大批次" 一致(如 8 步累积时,8次 batch 的 loss和 / 8,避免参数更新幅度过大)。起到了平滑的作用。

那么有个新的问题,这个 res.loss 损失是怎么来的呢?如何计算的?我们需要看一下 res = model(input_ids, labels=labels) 的实现逻辑。

5. model(input_ids, labels) 前向传播

我们定位到模型文件的最外壳------CausalLM(因果语言模型)。它的核心逻辑就是:"把模型提取的抽象特征(hidden_states)转换成人类能看懂的词表概率,并计算和标准答案之间的差距。"

核心目的: 将模型提取的 hidden_states 转换成人类能看懂的词表概率,并计算和标准答案之间的差距。

python 复制代码
def forward(self,
                input_ids: Optional[torch.Tensor] = None, # 形状: [bsz, seq_len],输入的 Token ID
                attention_mask: Optional[torch.Tensor] = None, # 用于处理 Padding,屏蔽无效位
                labels: Optional[torch.Tensor] = None, # 形状: [bsz, seq_len],训练时的参考答案
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, # 推理加速用的 KV Cache
                use_cache: bool = False, # 是否开启 KV Cache
                logits_to_keep: Union[int, torch.Tensor] = 0, # 为了节省显存,只计算最后几个 Token 的 Logits
                **args):
        # 1. 核心模型前向传播:
        # hidden_states: 形状 [bsz, seq_len, hidden_size],模型对每个位置提取的特征向量
        # aux_loss: MoE 架构特有的辅助损失,用于平衡各个专家的负载
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **args
        )

        # 2. 确定计算 Logits 的范围:
        # slice_indices 决定了我们对序列中的哪些位置进行"解码"(转成词概率)
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        
        # 3. 输出层映射 (LM Head):
        # 将 [hidden_size] 的向量映射到词表大小 [vocab_size]
        # 得到 logits 形状: [bsz, seq_len, vocab_size]
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        # 4. 训练模式:计算 Loss
        if labels is not None:
            # 【核心对齐逻辑】:预测下一个词
            # shift_logits: 去掉最后一个预测(因为它后面没答案了)
            # shift_labels: 去掉第一个标签(因为它前面没预测)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            # 5. 计算交叉熵损失 (Cross Entropy):
            # view(-1, ...): 将多维张量展平为一维,符合计算要求
            # ignore_index=-100: 告诉模型碰到标签为 -100 的 Padding 位直接忽略,不计入 Loss
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), 
                shift_labels.view(-1), 
                ignore_index=-100
            )

        # 6. 封装输出对象:
        # CausalLMOutputWithPast 是 Huggingface 风格的输出结构,方便后续调用
        output = CausalLMOutputWithPast(
            loss=loss, 
            logits=logits, 
            past_key_values=past_key_values, 
            hidden_states=hidden_states
        )
        # 将 MoE 的辅助损失也塞进去,训练时要加到总 Loss 里
        output.aux_loss = aux_loss
        return output
  • shift 错位对齐: 这是实现 pretrain loss 的关键。它是将 "当前位置的预测"(shift_logits) 去对比 "下一位置的真实单词"(shift_labels)。
python 复制代码
# shift_logits: 去掉最后一个预测(因为它后面没答案了)
# shift_labels: 去掉第一个标签(因为它前面没预测)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

目前未知,我们终于知道预训练的 loss 是怎么来的了。

6. 梯度反向传播与参数更新

核心作用: 算梯度的,进行参数更新。

python 复制代码
# 混合精度反向传播:缩放loss避免低精度梯度下溢
scaler.scale(loss).backward()

# 梯度累积步数达标后,执行参数更新
if (step + 1) % args.accumulation_steps == 0:
    scaler.unscale_(optimizer)  # 取消梯度缩放,把它们的 .grad管理的参数全部除以 65536
    # 梯度裁剪:限制梯度L2范数,防止梯度爆炸(大模型必做)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

    scaler.step(optimizer)  # 更新模型参数
    scaler.update()        # 更新混合精度缩放器状态

    optimizer.zero_grad(set_to_none=True)  # 清空梯度(set_to_none更省显存)

这里有个关键细节,就是 scale up,目的是解决低精度训练的梯度下溢问题(即梯度在 float16/bf16 下溢为 0 的情况)。实现代码如下:

python 复制代码
# 原理:在反向传播之前,它将 loss 乘以一个很大的缩放因子(例如 $2^{16} = 65536$)。当 backward() 结束,所有的梯度已经安全地存储在每个参数的 .grad 属性里了。
scaler.scale(loss).backward()

链式法则推导: 根据微积分的链式法则,如果 Loss 扩大了 65536 倍,那么求导算出的所有梯度(Gradients)也会同步扩大 65536 倍。结果:那些原本在 float16 范围之外、即将变成 000 的微小梯度,被强行"拉"回了 float16 的可表示范围内,从而活了下来。

然后,scaler.unscale_(optimizer) 之后,它会自动扫描优化器管理的所有参数,把它们的 .grad 全部除以 65536。

代码步骤 动作 核心目的
scale(loss).backward() 放大 保护微小梯度不被 FP16 精度抹杀(保命)。
unscale_(optimizer) 还原 恢复真实数值,为裁剪和更新做准备(归位)。
clip_grad_norm_(...) 限流 确保步子迈得稳,不让大梯度冲毁模型。
scaler.step(...) 更新 检查无误后,真正修改模型权重 WWW。
7. 训练日志监控

核心作用: 实时监控训练状态;

python 复制代码
# 按间隔/最后一步打印日志(监控训练状态)
if step % args.log_interval == 0 or step == iters - 1:
    spend_time = time.time() - start_time  # 已用时间
    current_loss = loss.item() * args.accumulation_steps  # 还原真实损失值(反归一化)
    current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
    current_logits_loss = current_loss - current_aux_loss  # 分离主损失/辅助损失
    current_lr = optimizer.param_groups[-1]['lr']  # 获取当前学习率
    # 计算本轮剩余时间(ETA)
    eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
    # 打印核心指标:轮数、步数、损失、学习率、剩余时间
    Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
    if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

关键细节: loss.item() * args.accumulation_steps 还原真实损失(抵消梯度累积的归一化),保证日志显示的损失值与实际一致。

8. 模型保存(断点续训)

核心作用: 保存模型权重 + 训练状态,支持断点续训;

python 复制代码
# 按间隔/最后一步保存模型(仅主进程执行,避免多卡重复保存)
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
    model.eval()  # 切换评估模式(避免BatchNorm/Dropout影响权重)
    moe_suffix = '_moe' if lm_config.use_moe else ''
    ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
    # 解包模型(兼容DDP/compile包装)
    # # 如果模型是 DistributedDataParallel 类型的,就取它的 .module 属性(那才是里面的模型)
    raw_model = model.module if isinstance(model, DistributedDataParallel) else model
    # # 看看模型里有没有 _orig_mod 这个属性,有的话就取出来,没有就保持原样
    raw_model = getattr(raw_model, '_orig_mod', raw_model)  # 兼容torch.compile
    # 提取模型权重字典
    state_dict = raw_model.state_dict()
    # 保存权重(半精度+CPU,大幅减小文件体积)
    torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
    # 保存完整检查点(含优化器/缩放器状态,支持断点续训)
    lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
    model.train()  # 切回训练模式
    del state_dict  # 释放显存
  • ckp: 动态生成的文件名。
  • 最外层 (DDP 壳): 开启多卡训练后,PyTorch 又包了一层 module。层名变成了 module._orig_mod.layers.0.attention...。若没有多卡则为 model。
  • 第二层 (Compile 壳): 开启加速后,PyTorch 会在外面包一层 _orig_mod。它负责优化算子,但会把里面的层改名为 _orig_mod.layers.0.attention...。
  • 最内核 (The Core): 比如 MiniMindForCausalLM 模型。这是真正存权重的地方(比如层名叫 layers.0.attention...)。
相关推荐
九.九9 小时前
ops-transformer:AI 处理器上的高性能 Transformer 算子库
人工智能·深度学习·transformer
春日见9 小时前
拉取与合并:如何让个人分支既包含你昨天的修改,也包含 develop 最新更新
大数据·人工智能·深度学习·elasticsearch·搜索引擎
恋猫de小郭9 小时前
AI 在提高你工作效率的同时,也一直在增加你的疲惫和焦虑
前端·人工智能·ai编程
deephub9 小时前
Agent Lightning:微软开源的框架无关 Agent 训练方案,LangChain/AutoGen 都能用
人工智能·microsoft·langchain·大语言模型·agent·强化学习
偷吃的耗子9 小时前
【CNN算法理解】:三、AlexNet 训练模块(附代码)
深度学习·算法·cnn
大模型RAG和Agent技术实践9 小时前
从零构建本地AI合同审查系统:架构设计与流式交互实战(完整源代码)
人工智能·交互·智能合同审核
老邋遢9 小时前
第三章-AI知识扫盲看这一篇就够了
人工智能
互联网江湖9 小时前
Seedance2.0炸场:长短视频们“修坝”十年,不如AI放水一天?
人工智能
PythonPioneer10 小时前
在AI技术迅猛发展的今天,传统职业该如何“踏浪前行”?
人工智能
冬奇Lab10 小时前
一天一个开源项目(第20篇):NanoBot - 轻量级AI Agent框架,极简高效的智能体构建工具
人工智能·开源·agent