代码
首先是训练逻辑部分的代码,主要包含 数据加载 、学习率调度 、混合精度训练 、梯度累积 、日志监控 、模型保存 等关键环节:
1. 加载 input_ids 和 labels
python
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time() # 记录本轮训练开始时间
# 遍历数据加载器,start参数支持断点续训(从指定步数开始计数)
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
# 数据迁移到指定设备(GPU/CPU),保证数据与模型同设备
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
- 预训练的标签: 可以看出 labels 的内容和 input 的内容是通过 loader 加载出来的,那么我如果要知道 labels 和 input 的形状,肯定需要知道 loader 是什么。
2. DataLoader 类
核心目的: 依照 batch_size 维度,将 Dataset 产生的样本对 (input_ids, labels) 封装为张量批次。这种封装不仅实现了数据的并行计算。确立了训练过程中的"步进(Step)"逻辑,使训练进度与梯度更新变得直观可控。
python
from torch.utils.data import DataLoader, DistributedSampler
# train_ds 为打包好的(input_ids, labels),batch_sampler 一种抽取数据的方式(=batchsize),num_workers 为并行抽取数量(几个人抽)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
- train_ds: 为 PretrainDataset 类,返回为 (input_ids, labels);
- batch_sampler : 假设你的数据集里有 1000 条数据。BatchSampler 就像一个抽签桶,里面有 0 到 999 的编号。它会根据你的 batch_size=32,一次性抓出 32 个随机编号。
- num_workers: 拿到这 32 个编号后,会立刻派 8 个工人(num_workers=8)去 PretrainDataset 这个大书架上取书。
最后解包得到一个元组 (input_ids, labels)。分别为 [bsz, seq, dim]。那么 labels 的维度到底长什么样我们需要进一步看 PretrainDataset 类。
3. PretrainDataset 类
核心目的: 加载 json 数据集,确定分词细节,并确立 labels 的形状。
python
class PretrainDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
# 使用 Huggingface 的 load_dataset 加载 jsonl 文件
# 这会自动处理文件的读取,并将数据转化为一个类似列表的结构
self.samples = load_dataset('json', data_files=data_path, split='train')
def __len__(self):
# 告诉 DataLoader 整个数据集有多少条数据
return len(self.samples)
def __getitem__(self, index):
# 1. 根据索引获取一条原始数据,例如: {'text': '你好,世界'}
sample = self.samples[index]
# 2. 将文本转为数字 ID (Token IDs)
# truncation=True: 超过 max_length-2 的部分会被切掉
# max_length-2 是为了给后面的 BOS 和 EOS 留出位置
tokens = self.tokenizer(
str(sample['text']),
add_special_tokens=False,
max_length=self.max_length - 2,
truncation=True
).input_ids
# 3. 构建完整的序列:[BOS] + 文本内容 + [EOS]
# BOS: 开始符 (Begin of Sentence)
# EOS: 结束符 (End of Sentence)
tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
# 4. 填充 (Padding)
# 如果长度不足 max_length,就在后面补 PAD 符(通常是 0)
# 确保同一个 Batch 里的所有序列长度都是一样的(都是 512)
input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
# 5. 转为 PyTorch 张量
input_ids = torch.tensor(input_ids, dtype=torch.long)
# 6. 构建标签 (Labels)
# 在预训练中,标签最初就是输入序列本身(因为我们要预测下一个词)
labels = input_ids.clone()
# 7. 忽略填充位的损失
# 将 labels 中等于 PAD 的位置设为 -100
# PyTorch 的 CrossEntropyLoss 会自动跳过标签为 -100 的位置
# 这样模型就不会去学习如何预测"无意义"的填充符
labels[input_ids == self.tokenizer.pad_token_id] = -100
# 返回一对张量。DataLoader 会把多个这样的元组打包成 (batch_size, 512) 的张量
return input_ids, labels
- 分词细节: Dataset 类的核心目的就是确定分词细节,先加载数据集,分词,加特殊Token,填充Token至统一长度。
- labels: 可以看出 labels 的形状大小和 input_ids 是一致的。
到这里,我们可以明白 train_epoch() 函数中的 input_ids 和 labels 大小是一致的,那么一个新的问题来了,预训练的 loss 是怎么算的呢?这跟模型内部处理有关!
4. 混合精度前向传播和梯度归一化
核心作用: 通过 autocast 实现混合精度,加速矩阵运算,同时归一化损失适配梯度累积。
python
# 混合精度上下文:GPU用autocast加速,CPU用空上下文兼容
with autocast_ctx:
# 模型前向传播:输入token id,输出包含主loss和辅助loss的结果
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss # 总损失=主任务损失+辅助损失(如MoE负载均衡损失)
loss = loss / args.accumulation_steps # 梯度累积归一化(避免累积后梯度幅度过大)
- 混合精度(autocast_ctx): 将模型计算转为bfloat16/float16,显存占用减半,训练速度提升 30%+;
- 梯度累积归一化: 可以理解为"逻辑损失",损失除以累积步数,保证最终累积梯度与 "真实大批次" 一致(如 8 步累积时,8次 batch 的 loss和 / 8,避免参数更新幅度过大)。起到了平滑的作用。
那么有个新的问题,这个 res.loss 损失是怎么来的呢?如何计算的?我们需要看一下 res = model(input_ids, labels=labels) 的实现逻辑。
5. model(input_ids, labels) 前向传播
我们定位到模型文件的最外壳------CausalLM(因果语言模型)。它的核心逻辑就是:"把模型提取的抽象特征(hidden_states)转换成人类能看懂的词表概率,并计算和标准答案之间的差距。"
核心目的: 将模型提取的 hidden_states 转换成人类能看懂的词表概率,并计算和标准答案之间的差距。
python
def forward(self,
input_ids: Optional[torch.Tensor] = None, # 形状: [bsz, seq_len],输入的 Token ID
attention_mask: Optional[torch.Tensor] = None, # 用于处理 Padding,屏蔽无效位
labels: Optional[torch.Tensor] = None, # 形状: [bsz, seq_len],训练时的参考答案
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, # 推理加速用的 KV Cache
use_cache: bool = False, # 是否开启 KV Cache
logits_to_keep: Union[int, torch.Tensor] = 0, # 为了节省显存,只计算最后几个 Token 的 Logits
**args):
# 1. 核心模型前向传播:
# hidden_states: 形状 [bsz, seq_len, hidden_size],模型对每个位置提取的特征向量
# aux_loss: MoE 架构特有的辅助损失,用于平衡各个专家的负载
hidden_states, past_key_values, aux_loss = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
**args
)
# 2. 确定计算 Logits 的范围:
# slice_indices 决定了我们对序列中的哪些位置进行"解码"(转成词概率)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
# 3. 输出层映射 (LM Head):
# 将 [hidden_size] 的向量映射到词表大小 [vocab_size]
# 得到 logits 形状: [bsz, seq_len, vocab_size]
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
# 4. 训练模式:计算 Loss
if labels is not None:
# 【核心对齐逻辑】:预测下一个词
# shift_logits: 去掉最后一个预测(因为它后面没答案了)
# shift_labels: 去掉第一个标签(因为它前面没预测)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 5. 计算交叉熵损失 (Cross Entropy):
# view(-1, ...): 将多维张量展平为一维,符合计算要求
# ignore_index=-100: 告诉模型碰到标签为 -100 的 Padding 位直接忽略,不计入 Loss
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
# 6. 封装输出对象:
# CausalLMOutputWithPast 是 Huggingface 风格的输出结构,方便后续调用
output = CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
hidden_states=hidden_states
)
# 将 MoE 的辅助损失也塞进去,训练时要加到总 Loss 里
output.aux_loss = aux_loss
return output
- shift 错位对齐: 这是实现 pretrain loss 的关键。它是将 "当前位置的预测"(shift_logits) 去对比 "下一位置的真实单词"(shift_labels)。
python
# shift_logits: 去掉最后一个预测(因为它后面没答案了)
# shift_labels: 去掉第一个标签(因为它前面没预测)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
目前未知,我们终于知道预训练的 loss 是怎么来的了。
6. 梯度反向传播与参数更新
核心作用: 算梯度的,进行参数更新。
python
# 混合精度反向传播:缩放loss避免低精度梯度下溢
scaler.scale(loss).backward()
# 梯度累积步数达标后,执行参数更新
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer) # 取消梯度缩放,把它们的 .grad管理的参数全部除以 65536
# 梯度裁剪:限制梯度L2范数,防止梯度爆炸(大模型必做)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer) # 更新模型参数
scaler.update() # 更新混合精度缩放器状态
optimizer.zero_grad(set_to_none=True) # 清空梯度(set_to_none更省显存)
这里有个关键细节,就是 scale up,目的是解决低精度训练的梯度下溢问题(即梯度在 float16/bf16 下溢为 0 的情况)。实现代码如下:
python
# 原理:在反向传播之前,它将 loss 乘以一个很大的缩放因子(例如 $2^{16} = 65536$)。当 backward() 结束,所有的梯度已经安全地存储在每个参数的 .grad 属性里了。
scaler.scale(loss).backward()
链式法则推导: 根据微积分的链式法则,如果 Loss 扩大了 65536 倍,那么求导算出的所有梯度(Gradients)也会同步扩大 65536 倍。结果:那些原本在 float16 范围之外、即将变成 000 的微小梯度,被强行"拉"回了 float16 的可表示范围内,从而活了下来。
然后,scaler.unscale_(optimizer) 之后,它会自动扫描优化器管理的所有参数,把它们的 .grad 全部除以 65536。
| 代码步骤 | 动作 | 核心目的 |
|---|---|---|
scale(loss).backward() |
放大 | 保护微小梯度不被 FP16 精度抹杀(保命)。 |
unscale_(optimizer) |
还原 | 恢复真实数值,为裁剪和更新做准备(归位)。 |
clip_grad_norm_(...) |
限流 | 确保步子迈得稳,不让大梯度冲毁模型。 |
scaler.step(...) |
更新 | 检查无误后,真正修改模型权重 WWW。 |
7. 训练日志监控
核心作用: 实时监控训练状态;
python
# 按间隔/最后一步打印日志(监控训练状态)
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time # 已用时间
current_loss = loss.item() * args.accumulation_steps # 还原真实损失值(反归一化)
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss # 分离主损失/辅助损失
current_lr = optimizer.param_groups[-1]['lr'] # 获取当前学习率
# 计算本轮剩余时间(ETA)
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
# 打印核心指标:轮数、步数、损失、学习率、剩余时间
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
关键细节: loss.item() * args.accumulation_steps 还原真实损失(抵消梯度累积的归一化),保证日志显示的损失值与实际一致。
8. 模型保存(断点续训)
核心作用: 保存模型权重 + 训练状态,支持断点续训;
python
# 按间隔/最后一步保存模型(仅主进程执行,避免多卡重复保存)
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval() # 切换评估模式(避免BatchNorm/Dropout影响权重)
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 解包模型(兼容DDP/compile包装)
# # 如果模型是 DistributedDataParallel 类型的,就取它的 .module 属性(那才是里面的模型)
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
# # 看看模型里有没有 _orig_mod 这个属性,有的话就取出来,没有就保持原样
raw_model = getattr(raw_model, '_orig_mod', raw_model) # 兼容torch.compile
# 提取模型权重字典
state_dict = raw_model.state_dict()
# 保存权重(半精度+CPU,大幅减小文件体积)
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 保存完整检查点(含优化器/缩放器状态,支持断点续训)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train() # 切回训练模式
del state_dict # 释放显存
- ckp: 动态生成的文件名。
- 最外层 (DDP 壳): 开启多卡训练后,PyTorch 又包了一层 module。层名变成了 module._orig_mod.layers.0.attention...。若没有多卡则为 model。
- 第二层 (Compile 壳): 开启加速后,PyTorch 会在外面包一层 _orig_mod。它负责优化算子,但会把里面的层改名为 _orig_mod.layers.0.attention...。
- 最内核 (The Core): 比如 MiniMindForCausalLM 模型。这是真正存权重的地方(比如层名叫 layers.0.attention...)。