deepspeed zero3 结合 llamafactory 微调 ,save_only_model: true 导致保存时候出错
bash
# ============================================
# 保存 checkpoint 时不保存 global_step
# 防止内存溢出 oom
# ============================================
save_only_model: true 这样才是能少保存啊,还有前面说的让cpu来保存 是什么意思
bash
rank1]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) 08:17:57 [76/1838]
[rank1]: File "/root/LLaMA-Factory/src/llamafactory/train/sft/trainer.py", line 103, in compute_loss
[rank1]: return super().compute_loss(model, inputs, *args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3801, in compute_loss
[rank1]: outputs = model(**inputs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]: ret_val = func(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1855, in forward
[rank1]: loss = self.module(*inputs, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 1757, in forward
[rank1]: return self.base_model(
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank1]: return self.model.forward(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank1]: output = func(self, *args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank1]: return func(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 823, in forward
[rank1]: outputs: BaseModelOutputWithPast = self.model(
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank1]: output = func(self, *args, **kwargs) [rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 537, in forward [rank1]: layer_outputs = self._gradient_checkpointing_func(
[rank1]: File "/root/LLaMA-Factory/src/llamafactory/model/model_utils/checkpointing.py", line 97, in custom_gradient_checkpointing_func
[sft-0] 0:[tmux]* "worker-0" 10:37 16-Jun-26W0616 08:18:02.558000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 584 closing signal SIGTERM
W0616 08:18:02.559000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 586 closing signal SIGTERM
W0616 08:18:02.560000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 587 closing signal SIGTERM
E0616 08:18:02.724000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 1 (pid: 585) of binary: /opt/con
da/bin/python3.10
Traceback (most recent call last):
File "/opt/conda/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
return f(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 901, in main
run(args)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
elastic_launch(
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
/root/LLaMA-Factory/src/llamafactory/launcher.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2026-06-16_08:18:02
host : worker-0
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 585)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[sft-0] 0:[tmux]*
这是一个非常好的问题!很多刚接触 DeepSpeed ZeRO-3 的同学都会在这里踩坑。
你之前的直觉是对的:在普通单卡训练或 ZeRO-2 中,save_only_model: true 确实能少保存东西,防止内存溢出;但是,在 ZeRO-3 中,这个选项是致命的!
我来详细解释一下为什么,以及"用 CPU 保存"到底是什么意思。
一、为什么 ZeRO-3 绝不能设 save_only_model: true?
核心原因:在 ZeRO-3 中,优化器状态里藏着模型参数的"原件",你把优化器丢了,模型就残缺了。
- ZeRO-3 的切片机制 :
在 ZeRO-3 下,模型的参数被切成了 4 份,分布在 4 张卡上。更关键的是,为了省显存,模型参数的 fp32 主副本(即使你用 bf16 训练,优化器里也会存一份 fp32 用于更新)是存在优化器状态里的! - 保存 Checkpoint 时发生了什么 :
- 如果
save_only_model: false(正常保存):DeepSpeed 会把模型分片和优化器分片一起保存。下次恢复训练时,DeepSpeed 能从优化器里把 fp32 主副本找回来,重新拼好。 - 如果
save_only_model: true(强制只保存模型):HuggingFace Trainer 会跳过优化器的保存。这就导致 DeepSpeed 只存了残缺的 bf16 模型分片,丢掉了优化器里的 fp32 主副本和分片映射关系。
- 如果
- 后果 :
保存时,DeepSpeed 发现需要的数据对不上,就会发生死锁 或者通信超时------这就是为什么你的训练在 100 步保存时卡死然后报错退出!
总结 :在 ZeRO-3 下,必须老老实实存完整 checkpoint(包含优化器)。虽然占硬盘,但这是 ZeRO-3 续训的唯一方式。如果你觉得存得太占硬盘,可以把
save_steps调大一点(比如 500 步存一次),或者save_total_limit: 1(只保留最新 1 个 checkpoint)。
二、"让 CPU 来保存/导出"是什么意思?
这涉及到 ZeRO-3 如何把切碎的模型"拼"回来的问题。
1. 为什么不能在 GPU 上拼?
假设你的 Qwen2 模型总大小是 14GB。在 ZeRO-3 下,每张 16G 的显卡只存了 3.5GB 的碎片。
如果你想保存一个完整的模型文件(比如 pytorch_model.bin),系统需要把这 4 份 3.5GB 的碎片收集到一张卡上拼接。
问题来了 :拼接需要 14GB 显存,而你的显卡在训练时已经占了 13GB,根本放不下这 14GB!一旦强行在 GPU 上拼,瞬间 OOM(显存溢出)崩溃 。
2. CPU 导出的原理 :
既然 GPU 显存(16GB)太小,拼不下完整的模型,我们就借用系统内存(通常有 128GB 甚至 256GB)来拼。
- Step 1:把 4 张卡上的模型碎片,依次复制到电脑的 CPU 内存中。
- Step 2:在 CPU 内存中,把这 4 个碎片拼成一个完整的 14GB 模型。
- Step 3 :把拼好的完整模型保存到硬盘上。
这就是--export_device cpu的意思:在 CPU 内存中合并模型,避免 GPU 显存爆炸。
三、你现在应该怎么改?
结合你的情况(4张 16G 卡,100步保存崩溃),你需要做以下两步:
1. 修改 YAML 配置(解决保存死锁和超时)
yaml
# 必须改成 false!ZeRO-3 需要保存优化器才能维持模型完整
save_only_model: false
# 极大增加超时时间,防止硬盘写入慢导致多卡失联
ddp_timeout: 18000000
# 只保留最新的 1-2 个 checkpoint,防止硬盘被撑爆
save_total_limit: 2
2. 训练正常结束后,如何提取最终模型?
因为 save_only_model: false 加上 ZeRO-3,你保存在文件夹里的文件是碎的(有很多 zero_pp_rank_0... 这样名字的文件),不能直接用 transformers 加载。
你需要用 LLaMA-Factory 的 导出功能,利用 CPU 把碎片拼起来,并和你的 LoRA 合并:
bash
llamafactory-cli export \
--model_name_or_path /public/home/a15657581978/merged_correct_2epoch \
--adapter_name_or_path /public/home/a15657581978/output/binding_part_000_512 \
--template alpaca \
--finetuning_type lora \
--export_dir /public/home/a15657581978/exported_final_model \
--export_device cpu \ # 关键!用 CPU 内存拼模型,防止 GPU OOM
--export_legacy_format false
执行完这个命令,exported_final_model 文件夹里就会出现完整的、可以直接用的 HuggingFace 格式模型了。
总结一下 :训练时,让 ZeRO-3 存完整的碎片(save_only_model: false);训练后,用 CPU 把碎片拼成完整的模型(export_device cpu)。这是 4x16G 显存跑大模型的标准流程。