构建中文版的 nanoGPT - 断点续训(resume from checkpoint)
flyfish
参考网址
bash
https://github.com/shaoshengsong/nanoGPT-cn
一、断点续训(resume from checkpoint)
断点续训(resume from checkpoint)是指在训练过程中中断后,能够从上次中断的位置继续训练,而不是从头开始。这对于大规模数据集训练尤为重要,可以:
避免重复训练,节省时间
应对意外中断(如断电、程序崩溃)
支持训练过程中的调整和优化
二、需要保存的数据和参数
2.1 Checkpoint 结构
在 train.py 中,每次保存 checkpoint 时会保存以下数据:
python
checkpoint = {
'model': raw_model.state_dict(), # 模型权重参数
'optimizer': optimizer.state_dict(), # 优化器状态
'model_args': model_args, # 模型配置参数
'iter_num': iter_num, # 当前迭代次数
'epoch_num': epoch_num, # 当前epoch数
'best_val_loss': best_val_loss, # 最佳验证损失
'config': config, # 完整训练配置
}
2.2 参数
| 参数 | 类型 | 作用 | 为什么需要保存 |
|---|---|---|---|
model |
state_dict() |
模型的所有权重参数 | 恢复模型训练到当前状态 |
optimizer |
state_dict() |
优化器的状态(动量、学习率等) | 保持优化器的历史状态,避免从头开始优化 |
model_args |
dict |
模型结构参数(层数、头数、维度等) | 确保恢复时模型结构与训练时一致 |
iter_num |
int |
当前迭代次数 | 确定从哪个迭代继续,用于学习率调度 |
epoch_num |
int |
当前epoch数 | 显示训练进度,可通过 iter_num 计算得出 |
best_val_loss |
float |
历史最佳验证损失 | 判断是否保存新的最佳模型 |
config |
dict |
完整的训练配置 | 记录训练参数,便于复现和追踪 |
2.3 数据解析
1. 模型权重 (model.state_dict())
python
# 保存
checkpoint['model'] = raw_model.state_dict()
# 恢复
state_dict = checkpoint['model']
model.load_state_dict(state_dict)
作用:包含模型所有层的权重和偏置参数,是模型学到的知识的核心。
2. 优化器状态 (optimizer.state_dict())
python
# 保存
checkpoint['optimizer'] = optimizer.state_dict()
# 恢复
optimizer.load_state_dict(checkpoint['optimizer'])
作用 :包含优化器的内部状态,如:
AdamW 的动量参数 (exp_avg, exp_avg_sq)
当前学习率
权重衰减状态
如果不保存优化器状态,恢复后优化器会从头开始,之前的动量信息会丢失,可能导致训练不稳定。
3. 模型配置 (model_args)
python
model_args = dict(
n_layer=n_layer, # Transformer层数
n_head=n_head, # 注意力头数
n_embd=n_embd, # 嵌入维度
block_size=block_size,
bias=bias,
vocab_size=vocab_size,
dropout=dropout
)
作用:确保恢复时创建的模型结构与训练时完全一致。如果模型结构不一致,无法加载权重。
4. 迭代次数 (iter_num)
python
# 保存
checkpoint['iter_num'] = iter_num
# 恢复
iter_num = checkpoint['iter_num']
作用 :
确定学习率(学习率调度依赖迭代次数)
确定评估时机(iter_num % eval_interval == 0)
确定训练终止条件(iter_num > max_iters)
5. Epoch 数 (epoch_num)
python
# 保存
checkpoint['epoch_num'] = epoch_num
# 恢复
epoch_num = checkpoint.get('epoch_num', 0)
# 实时计算(双重保障)
epoch_num = iter_num // iters_per_epoch
作用 :
显示训练进度(epoch X | iter Y)
用于保存 epoch checkpoint(如 ckpt_epoch0_end.pt)
双重保障 :epoch_num 既被显式保存,也可以通过 iter_num // iters_per_epoch 重新计算得出。
三、断点续训的实现流程
3.1 保存 Checkpoint
在 train.py 中,checkpoint 在以下时机保存:
- 每次评估时:
python
if iter_num % eval_interval == 0 and master_process:
# 保存最新checkpoint
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
# 保存带编号的checkpoint
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_epoch{epoch_num}_iter{iter_num}.pt'))
- 每个 epoch 结束时:
python
if save_every_epoch and master_process:
next_epoch_num = (iter_num + 1) // iters_per_epoch
if next_epoch_num > epoch_num:
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_epoch{epoch_num}_end.pt'))
3.2 恢复训练
在 train.py 中,恢复训练通过设置 init_from='resume' 触发:
python
if init_from == 'resume':
print(f"从 {out_dir} 恢复训练")
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
# 恢复模型配置
checkpoint_model_args = checkpoint['model_args']
for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
model_args[k] = checkpoint_model_args[k]
# 创建模型并加载权重
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']
model.load_state_dict(state_dict)
# 恢复训练状态
iter_num = checkpoint['iter_num']
best_val_loss = checkpoint['best_val_loss']
epoch_num = checkpoint.get('epoch_num', 0)
# 恢复优化器状态
optimizer.load_state_dict(checkpoint['optimizer'])
四、实际操作步骤
4.1 开始训练
bash
# 使用配置文件训练
torchrun --standalone --nproc_per_node=2 train.py config/train_wudaocorpus.py
4.2 中断训练
训练过程中,按 Ctrl+C 中断或等待程序自动保存 checkpoint。
4.3 修改配置文件
修改 config/train_wudaocorpus.py:
python
# 将 init_from 改为 'resume'
init_from = 'resume'
# 如果需要继续训练更多epoch,可以增加epochs
epochs = 5 # 从原来的3个epoch增加到5个
4.4 恢复训练
bash
# 重新运行训练命令
torchrun --standalone --nproc_per_node=2 train.py config/train_wudaocorpus.py
五、Checkpoint 文件说明
5.1 文件类型
| 文件名 | 说明 | 使用场景 |
|---|---|---|
ckpt.pt |
最新的checkpoint | 恢复训练(自动选择) |
ckpt_epoch{num}_iter{num}.pt |
每次评估时的checkpoint | 推理时选择特定迭代的模型 |
ckpt_epoch{num}_end.pt |
每个epoch结束时的checkpoint | 推理时选择特定epoch的模型 |
5.2 查看可用的 Checkpoint
bash
ls out-wudaocorpus/*.pt
输出示例:
ckpt.pt ckpt_epoch0_iter250.pt ckpt_epoch0_iter500.pt
ckpt_epoch0_end.pt ckpt_epoch1_iter14250.pt ckpt_epoch1_end.pt
5.3 使用特定 Checkpoint 进行推理
bash
python sample.py --out_dir=out-wudaocorpus --dataset=wudaocorpus --checkpoint_path=ckpt_epoch1_end.pt --start "深度学习"
参数分解:
| 参数 | 值 | 作用 |
|---|---|---|
--out_dir |
out-wudaocorpus |
指定模型目录 |
--dataset |
wudaocorpus |
指定数据集(用于加载正确的 tokenizer) |
--checkpoint_path |
ckpt_epoch1_end.pt |
指定要使用的 checkpoint 文件 |
--start |
"深度学习" |
起始文本 |
dataset 参数用于加载正确的数据集元数据,主要包括:
- Tokenizer(分词器):决定如何将文本转换为 token
- 词汇表映射:stoi(字符到索引)、itos(索引到字符)
- Tokenizer 类型:区分 ChatGLM SPTokenizer 和字符级编码
查找 meta.pkl 文件
python
# sample.py 第 85-90 行
candidate_path = os.path.join('data', dataset_name, 'meta.pkl')
if os.path.exists(candidate_path):
meta_path = candidate_path
load_meta = True
查找逻辑:
dataset='wudaocorpus'
↓
candidate_path = 'data/wudaocorpus/meta.pkl'
↓
如果存在 → 加载 meta.pkl
如果不存在 → 使用默认 GPT-2 编码
六、注意的事情
6.1 模型结构一致性
恢复训练时,以下模型参数必须与训练时一致:
n_layer - Transformer层数
n_head - 注意力头数
n_embd - 嵌入维度
block_size - 上下文窗口大小
vocab_size - 词汇表大小
bias - 是否使用偏置
这些参数会自动从 checkpoint 的 model_args 中恢复。
6.2 学习率调度
学习率调度依赖于 iter_num,恢复训练后学习率会从正确的位置继续衰减:
python
def get_lr(it):
# 预热阶段
if it < warmup_iters:
return learning_rate * (it + 1) / (warmup_iters + 1)
# 衰减阶段
if it > lr_decay_iters:
return min_lr
# 余弦衰减
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (learning_rate - min_lr)
6.3 数据一致性
恢复训练时,数据加载是随机的(每次从数据中随机采样),这是正常的行为:
python
def get_batch(split):
data = np.memmap(os.path.join(data_dir, f'{split}.bin'), dtype=np.uint16, mode='r')
ix = torch.randint(len(data) - block_size, (batch_size,)) # 随机采样
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
return x, y
6.4 多 GPU 训练注意事项
使用 DDP 训练时,checkpoint 只在主进程(ddp_rank == 0)保存:
python
if iter_num % eval_interval == 0 and master_process: # master_process = (ddp_rank == 0)
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))