deepspeed zero3 + llamafactory 保存checkpoint后第一step 就 OOM

deepspeed zero3 + llamafactory 保存checkpoint后第一step 就 OOM

4张16g显卡 训练14b模型

bash 复制代码
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 1e8,
    "stage3_prefetch_bucket_size": 1e8,
    "stage3_param_persistence_threshold": 1e6,
    "stage3_max_live_parameters": 3e8,
    "stage3_max_reuse_distance": 3e8,
    "stage3_gather_16bit_weights_on_model_save": false
  }
}
bash 复制代码
# ============================================
# 模型与路径配置
# ============================================
model_name_or_path: /public/home/a15657581978/merged_correct_2epoch

# 对话模板格式:alpaca格式
# 其他可选:llama3, chatml, qwen, vicuna等
template: alpaca

# ============================================
# 数据集配置
# ============================================
dataset_dir: /public/home/a15657581978/data-new/binding_sft_real_paired_512/short_tokens

dataset: binding_part_000

cutoff_len: 1024

# ================= 数据加载优化 =================
streaming: false

preprocessing_num_workers: 4

overwrite_cache: false

# ================= 内存防爆核心 =================
dataloader_num_workers: 2

dataloader_pin_memory: false

dataloader_prefetch_factor: 2

# ============================================
# 保存 checkpoint 时不保存 global_step
# 防止内存溢出 oom
# ============================================
save_only_model: true

# ============================================
# 训练阶段与微调方法
# ============================================
stage: sft

do_train: true

finetuning_type: lora

lora_rank: 64

lora_alpha: 128

lora_dropout: 0.05

lora_target: q_proj,v_proj,k_proj,o_proj,up_proj,gate_proj,down_proj

# ============================================
# 海光 DCU 核心防崩溃配置(关键!)
# ============================================
# additional_target: embed_tokens,lm_head

flash_attn: disabled
# attn_implementation: sdpa
ddp_backend: nccl

pure_bf16: false
bf16: true
fp16: false

low_cpu_mem_usage: true

gradient_checkpointing: true

use_cache: false

ddp_find_unused_parameters: false

# ============================================
# 批次与训练参数
# ============================================
per_device_train_batch_size: 1

per_device_eval_batch_size: 1

gradient_accumulation_steps: 16

learning_rate: 0.00005

num_train_epochs: -1
max_steps: 18750

warmup_ratio: 0.1

lr_scheduler_type: cosine

# ============================================
# 输出与日志
# ============================================
output_dir: /public/home/a15657581978/output/binding_part_000_512

logging_steps: 10

save_steps: 100

save_total_limit: 2

max_samples: 1400000

# ============================================
# 分布式与监控
# ============================================
ddp_timeout: 180000

report_to: ["tensorboard"]

logging_dir: /public/home/a15657581978/tblogs

deepspeed: /public/home/a15657581978/training_configs_512/deepspeed_z3_offload_4x16g.json

关键发现:OOM 不是在"保存中",而是在"保存后继续训练的第 101 步"

看你之前的日志时间线:

makefile 复制代码
08:17:53  ✅ Checkpoint global_step100 is ready now!  (保存成功完成)
08:17:57  💥 Rank 1 OOM in compute_loss → forward    (4秒后,第101步前向时崩)

保存本身是成功的! OOM 发生在保存完成之后、训练恢复的那个 forward 里。

那为什么第 1~100 步都不 OOM,偏偏第 101 步 OOM?

因为保存过程"搞脏了"GPU 显存。具体来说有 3 个机制:

机制 1:ZeRO-3 的参数缓存被清空了(最关键)

ZeRO-3 的参数不是常驻 GPU 的,而是用"缓存池"管理:

arduino 复制代码
正常训练(step 1~100的稳态):
┌─────────────────────────────┐
│ GPU 显存                     │
│ ┌─────────┐ ┌─────────┐    │
│ │当前层参数│ │下一层参数│ ← 缓存预热好了 │
│ └─────────┘ └─────────┘    │
│ 剩余 ~1.5GiB 给临时张量      │  ← softmax 36MiB 够用
└─────────────────────────────┘
保存 checkpoint 时:
  DeepSpeed 需要把参数从 GPU 拷到 CPU 再写磁盘
  → 参数缓存池被腾空/打乱
保存完后恢复训练(step 101):
┌─────────────────────────────┐
│ GPU 显存                     │
│ ┌───────┐ ┌───────┐ ┌────┐ │
│ │碎片参数│ │碎片参数│ │碎片│ │  ← 缓存是冷的 │
│ └───────┘ └───────┘ └────┘ │
│ 14.81GiB 已分配, 571MiB 碎片  │  ← softmax 36MiB 分配不出来!
│ 0 bytes 真正空闲              │
└─────────────────────────────┘

你的报错信息也印证了这一点:

14.81 GiB is allocated by PyTorch, and 571.62 MiB is reserved by PyTorch but unallocated 571MiB 的显存被 PyTorch 占着但没法用,就是碎片。

机制 2:保存时的序列化过程留下临时张量

DeepSpeed 保存时,会做这些事:

  1. 把 GPU 上的参数张量 .cpu() 拷到内存
  2. torch.save() 序列化写磁盘
  3. 释放临时张量 但在步骤 1→2→3 之间,GPU 和 CPU 之间有大量数据搬运,PyTorch 的内存分配器会留下碎片和未释放的临时缓冲区。在 NVIDIA CUDA 上这通常能自动回收,但在海光 DCU(HIP)上,内存回收机制可能没那么及时。

机制 3:保存后 ZeRO-3 AllGather 的"冷启动峰值"

第 101 步的 forward 需要通过 AllGather 从其他卡收集当前层的完整参数。在稳态训练时,这个 AllGather 是和上一层的计算重叠的,峰值显存可控。但保存后第一步是"冷启动":

arduino 复制代码
稳态训练时(step 99→100):
  层N计算的同时,层N+1的参数已经在后台AllGather好了
  → 峰值显存 = 计算层N + 缓存层N+1 ≈ 可控
冷启动时(step 100→101,保存后):
  缓存是空的,层1需要现场AllGather
  → 峰值显存 = AllGather临时缓冲区 + 层1参数 + 优化器碎片
  → 超了!

根本原因总结

复制代码
你的 4×16G 卡跑 Qwen2 + ZeRO-3 + LoRA:
  正常训练时显存占用 ≈ 14.5 GiB / 16 GiB(只剩 1.5GiB 余量)
  → 1.5GiB 刚好够 softmax 等临时操作
  
保存 checkpoint 后:
  内存碎片 + 缓存清空 + 冷启动峰值
  → 实际可用连续空间 < 36 MiB
  → softmax 分配 36 MiB 失败 → OOM!

不是保存本身吃显存,而是保存"搅乱"了显存布局,导致恢复训练时那 1.5GiB 的救命余量没了。

怎么解决?

第 1 步:开 expandable_segments(解决碎片问题,最可能一招见效)

bash 复制代码
export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128

这会让 PyTorch 的显存分配器支持动态扩展内存段,571MiB 的碎片就能被重新利用,36MiB 的 softmax 就能分配成功。

第 2 步:减小 DeepSpeed 通信桶(降低峰值显存)

json 复制代码
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": false,
    "contiguous_gradients": true,
    "sub_group_size": 1000000000,
    "reduce_bucket_size": 200000000,
    "stage3_prefetch_bucket_size": 20000000,
    "stage3_param_persistence_threshold": 1000000,
    "stage3_max_live_parameters": 500000000,
    "stage3_max_reuse_distance": 500000000,
    "stage3_gather_16bit_weights_on_model_save": false
  }
}

第 3 步:降低序列长度(如果业务允许)

yaml 复制代码
cutoff_len: 768   # 从 1024 降到 768,直接砍掉 ~25% 的激活显存

第 4 步:拉大保存间隔(减少被"搅乱"的频率)

yaml 复制代码
save_steps: 500   # 从 100 改成 500,减少保存次数

第 5 步(终极方案):加 CPU Offload

如果上面都不够,让优化器状态住 CPU,彻底释放 GPU 显存:

json 复制代码
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": false,
    "contiguous_gradients": true,
    "sub_group_size": 1000000000,
    "reduce_bucket_size": 200000000,
    "stage3_prefetch_bucket_size": 20000000,
    "stage3_param_persistence_threshold": 1000000,
    "stage3_max_live_parameters": 500000000,
    "stage3_max_reuse_distance": 500000000,
    "stage3_gather_16bit_weights_on_model_save": false
  }
}

CPU Offload 会让训练慢 20%40%,但能省出 35GiB GPU 显存,OOM 问题基本就不会再出现了。

完整启动脚本

bash 复制代码
#!/bin/bash
pkill -9 python; pkill -9 torchrun; sleep 3
# ====== 显存救命环境变量 ======
export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
# ====== 指定4张卡 ======
export ROCR_VISIBLE_DEVICES=0,1,2,3
# ====== 超时时间拉长 ======
export NCCL_TIMEOUT=1800
# ====== 启动训练 ======
FORCE_TORCHRUN=1 llamafactory-cli train \
  "/public/home/a15657581978/training_configs_512/binding_paired_000.yaml"

先只加 PYTORCH_HIP_ALLOC_CONF 这一个变量试试,大概率就能从"每次保存后 OOM"变成"能稳定跑完"。 如果还不行,再加上第 2~5 步的修改。

相关推荐
长大19881 小时前
ggplot2 高阶美化:SCI 期刊级论文图表从零绘制全流程
后端
墩墩大魔王丶2 小时前
macOS Rust 安装教程:自定义 CARGO_HOME 和 RUSTUP_HOME
后端
进阶的小名3 小时前
Spring Boot SSE + Nginx 配置:解决 EventSource 不实时返回、连接超时、流式响应被缓冲问题
spring boot·后端·nginx
PinkSun4 小时前
Spring AI RAG踩坑:我骂了半年的FilterExpression,其实是背锅侠
后端·ai编程
我登哥MVP4 小时前
SpringCloud Alibaba 核心组件解析:服务链路追踪
java·spring boot·后端·spring·spring cloud·java-ee·maven
by————组态4 小时前
Ricon组态系统 - 新一代Web可视化组态平台
前端·后端·物联网·架构·组态·组态软件
云技纵横4 小时前
ThreadLocal 内存泄漏:你的应用正在悄悄 OOM
后端
小撒的私房菜4 小时前
Multi-Agent 里谁来指挥?我用一个调度员,让多个 Agent 开始协作
人工智能·后端·agent
范什么特西4 小时前
Spring boot细节
java·spring boot·后端