构建中文版的 nanoGPT - 断点续训(resume from checkpoint)

构建中文版的 nanoGPT - 断点续训(resume from checkpoint)

flyfish

参考网址

bash 复制代码
https://github.com/shaoshengsong/nanoGPT-cn

一、断点续训(resume from checkpoint)

断点续训(resume from checkpoint)是指在训练过程中中断后,能够从上次中断的位置继续训练,而不是从头开始。这对于大规模数据集训练尤为重要,可以:

避免重复训练,节省时间

应对意外中断(如断电、程序崩溃)

支持训练过程中的调整和优化

二、需要保存的数据和参数

2.1 Checkpoint 结构

train.py 中,每次保存 checkpoint 时会保存以下数据:

python 复制代码
checkpoint = {
    'model': raw_model.state_dict(),        # 模型权重参数
    'optimizer': optimizer.state_dict(),    # 优化器状态
    'model_args': model_args,               # 模型配置参数
    'iter_num': iter_num,                   # 当前迭代次数
    'epoch_num': epoch_num,                 # 当前epoch数
    'best_val_loss': best_val_loss,         # 最佳验证损失
    'config': config,                       # 完整训练配置
}

2.2 参数

参数 类型 作用 为什么需要保存
model state_dict() 模型的所有权重参数 恢复模型训练到当前状态
optimizer state_dict() 优化器的状态(动量、学习率等) 保持优化器的历史状态,避免从头开始优化
model_args dict 模型结构参数(层数、头数、维度等) 确保恢复时模型结构与训练时一致
iter_num int 当前迭代次数 确定从哪个迭代继续,用于学习率调度
epoch_num int 当前epoch数 显示训练进度,可通过 iter_num 计算得出
best_val_loss float 历史最佳验证损失 判断是否保存新的最佳模型
config dict 完整的训练配置 记录训练参数,便于复现和追踪

2.3 数据解析

1. 模型权重 (model.state_dict())
python 复制代码
# 保存
checkpoint['model'] = raw_model.state_dict()

# 恢复
state_dict = checkpoint['model']
model.load_state_dict(state_dict)

作用:包含模型所有层的权重和偏置参数,是模型学到的知识的核心。

2. 优化器状态 (optimizer.state_dict())
python 复制代码
# 保存
checkpoint['optimizer'] = optimizer.state_dict()

# 恢复
optimizer.load_state_dict(checkpoint['optimizer'])

作用 :包含优化器的内部状态,如:

AdamW 的动量参数 (exp_avg, exp_avg_sq)

当前学习率

权重衰减状态

如果不保存优化器状态,恢复后优化器会从头开始,之前的动量信息会丢失,可能导致训练不稳定。

3. 模型配置 (model_args)
python 复制代码
model_args = dict(
    n_layer=n_layer,    # Transformer层数
    n_head=n_head,      # 注意力头数
    n_embd=n_embd,      # 嵌入维度
    block_size=block_size,
    bias=bias,
    vocab_size=vocab_size,
    dropout=dropout
)

作用:确保恢复时创建的模型结构与训练时完全一致。如果模型结构不一致,无法加载权重。

4. 迭代次数 (iter_num)
python 复制代码
# 保存
checkpoint['iter_num'] = iter_num

# 恢复
iter_num = checkpoint['iter_num']

作用

确定学习率(学习率调度依赖迭代次数)

确定评估时机(iter_num % eval_interval == 0

确定训练终止条件(iter_num > max_iters

5. Epoch 数 (epoch_num)
python 复制代码
# 保存
checkpoint['epoch_num'] = epoch_num

# 恢复
epoch_num = checkpoint.get('epoch_num', 0)

# 实时计算(双重保障)
epoch_num = iter_num // iters_per_epoch

作用

显示训练进度(epoch X | iter Y

用于保存 epoch checkpoint(如 ckpt_epoch0_end.pt

双重保障 :epoch_num 既被显式保存,也可以通过 iter_num // iters_per_epoch 重新计算得出。

三、断点续训的实现流程

3.1 保存 Checkpoint

train.py 中,checkpoint 在以下时机保存:

  1. 每次评估时
python 复制代码
if iter_num % eval_interval == 0 and master_process:
    # 保存最新checkpoint
    torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    # 保存带编号的checkpoint
    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_epoch{epoch_num}_iter{iter_num}.pt'))
  1. 每个 epoch 结束时
python 复制代码
if save_every_epoch and master_process:
    next_epoch_num = (iter_num + 1) // iters_per_epoch
    if next_epoch_num > epoch_num:
        torch.save(checkpoint, os.path.join(out_dir, f'ckpt_epoch{epoch_num}_end.pt'))

3.2 恢复训练

train.py 中,恢复训练通过设置 init_from='resume' 触发:

python 复制代码
if init_from == 'resume':
    print(f"从 {out_dir} 恢复训练")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    
    # 恢复模型配置
    checkpoint_model_args = checkpoint['model_args']
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    
    # 创建模型并加载权重
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    model.load_state_dict(state_dict)
    
    # 恢复训练状态
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
    epoch_num = checkpoint.get('epoch_num', 0)
    
    # 恢复优化器状态
    optimizer.load_state_dict(checkpoint['optimizer'])

四、实际操作步骤

4.1 开始训练

bash 复制代码
# 使用配置文件训练
torchrun --standalone --nproc_per_node=2 train.py config/train_wudaocorpus.py

4.2 中断训练

训练过程中,按 Ctrl+C 中断或等待程序自动保存 checkpoint。

4.3 修改配置文件

修改 config/train_wudaocorpus.py

python 复制代码
# 将 init_from 改为 'resume'
init_from = 'resume'

# 如果需要继续训练更多epoch,可以增加epochs
epochs = 5  # 从原来的3个epoch增加到5个

4.4 恢复训练

bash 复制代码
# 重新运行训练命令
torchrun --standalone --nproc_per_node=2 train.py config/train_wudaocorpus.py

五、Checkpoint 文件说明

5.1 文件类型

文件名 说明 使用场景
ckpt.pt 最新的checkpoint 恢复训练(自动选择)
ckpt_epoch{num}_iter{num}.pt 每次评估时的checkpoint 推理时选择特定迭代的模型
ckpt_epoch{num}_end.pt 每个epoch结束时的checkpoint 推理时选择特定epoch的模型

5.2 查看可用的 Checkpoint

bash 复制代码
ls out-wudaocorpus/*.pt

输出示例:

复制代码
ckpt.pt                  ckpt_epoch0_iter250.pt   ckpt_epoch0_iter500.pt
ckpt_epoch0_end.pt       ckpt_epoch1_iter14250.pt ckpt_epoch1_end.pt

5.3 使用特定 Checkpoint 进行推理

bash 复制代码
python sample.py --out_dir=out-wudaocorpus --dataset=wudaocorpus --checkpoint_path=ckpt_epoch1_end.pt --start "深度学习"

参数分解

参数 作用
--out_dir out-wudaocorpus 指定模型目录
--dataset wudaocorpus 指定数据集(用于加载正确的 tokenizer)
--checkpoint_path ckpt_epoch1_end.pt 指定要使用的 checkpoint 文件
--start "深度学习" 起始文本

dataset 参数用于加载正确的数据集元数据,主要包括:

  1. Tokenizer(分词器):决定如何将文本转换为 token
  2. 词汇表映射:stoi(字符到索引)、itos(索引到字符)
  3. Tokenizer 类型:区分 ChatGLM SPTokenizer 和字符级编码

查找 meta.pkl 文件

python 复制代码
# sample.py 第 85-90 行
candidate_path = os.path.join('data', dataset_name, 'meta.pkl')
if os.path.exists(candidate_path):
    meta_path = candidate_path
    load_meta = True

查找逻辑

复制代码
dataset='wudaocorpus'
    ↓
candidate_path = 'data/wudaocorpus/meta.pkl'
    ↓
如果存在 → 加载 meta.pkl
如果不存在 → 使用默认 GPT-2 编码

六、注意的事情

6.1 模型结构一致性

恢复训练时,以下模型参数必须与训练时一致:

n_layer - Transformer层数

n_head - 注意力头数

n_embd - 嵌入维度

block_size - 上下文窗口大小

vocab_size - 词汇表大小

bias - 是否使用偏置

这些参数会自动从 checkpoint 的 model_args 中恢复。

6.2 学习率调度

学习率调度依赖于 iter_num,恢复训练后学习率会从正确的位置继续衰减:

python 复制代码
def get_lr(it):
    # 预热阶段
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 衰减阶段
    if it > lr_decay_iters:
        return min_lr
    # 余弦衰减
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

6.3 数据一致性

恢复训练时,数据加载是随机的(每次从数据中随机采样),这是正常的行为:

python 复制代码
def get_batch(split):
    data = np.memmap(os.path.join(data_dir, f'{split}.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))  # 随机采样
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    return x, y

6.4 多 GPU 训练注意事项

使用 DDP 训练时,checkpoint 只在主进程(ddp_rank == 0)保存:

python 复制代码
if iter_num % eval_interval == 0 and master_process:  # master_process = (ddp_rank == 0)
    torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
相关推荐
高洁0113 小时前
智能体如何改变工作流一、工作流的“痛点
人工智能·python·数据挖掘·transformer·知识图谱
weixin_4684668513 小时前
Mamba 架构新手入门与实战指南
人工智能·架构·transformer·ssm·注意力机制·mamba·状态空间方程
西西弗Sisyphus13 小时前
构建中文版的 nanoGPT - 中文版 nanoGPT 的分词(tokenization)
transformer·attention·注意力·self-attention·nanogpt
z小猫不吃鱼14 小时前
05 Transformer Decoder 详解:GPT 为什么使用 Decoder?
gpt·深度学习·transformer
z小猫不吃鱼15 小时前
06 Tokenizer 详解:BPE、WordPiece、SentencePiece 有什么区别?
人工智能·语言模型·自然语言处理·transformer
weixin_4684668515 小时前
Transformer 模型新手入门与实战指南
人工智能·python·深度学习·机器学习·transformer·热力图·注意力机制
君为先-bey1 天前
CogVideoX——Transformer从文本到视频的扩散模型
深度学习·音视频·transformer·扩散模型
这是谁的博客?1 天前
Mamba 状态空间模型深度解析:挑战 Transformer 的新一代架构
深度学习·ai·架构·transformer·ssm·mamba·状态空间模型
AndrewHZ1 天前
【大模型技术博客】什么是大语言模型(LLM)?从零认识AI新范式
人工智能·深度学习·ai·语言模型·大模型·llm·transformer