deepspeed zero3 结合 llamafactory 微调 ,save_only_model: true 导致保存时候出错

deepspeed zero3 结合 llamafactory 微调 ,save_only_model: true 导致保存时候出错

bash 复制代码
# ============================================
# 保存 checkpoint 时不保存 global_step
# 防止内存溢出 oom
# ============================================
save_only_model: true  这样才是能少保存啊,还有前面说的让cpu来保存   是什么意思  
bash 复制代码
rank1]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)                                                08:17:57 [76/1838]
[rank1]:   File "/root/LLaMA-Factory/src/llamafactory/train/sft/trainer.py", line 103, in compute_loss                                                       
[rank1]:     return super().compute_loss(model, inputs, *args, **kwargs)                                                                                     
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3801, in compute_loss                                                
[rank1]:     outputs = model(**inputs)                                                                                                                       
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl                                       
[rank1]:     return self._call_impl(*args, **kwargs)                                                                                                         
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl                                               
[rank1]:     return forward_call(*args, **kwargs)                                                                                                            
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn                                                    
[rank1]:     ret_val = func(*args, **kwargs)                                                                                                                 
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1855, in forward                                                 
[rank1]:     loss = self.module(*inputs, **kwargs)                            
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl                                       
[rank1]:     return self._call_impl(*args, **kwargs)                          
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl                                               
[rank1]:     result = forward_call(*args, **kwargs)                           
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 1757, in forward                                                          
[rank1]:     return self.base_model(                                          
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl                                       
[rank1]:     return self._call_impl(*args, **kwargs)                          
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl                                               
[rank1]:     result = forward_call(*args, **kwargs)                           
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 193, in forward                                                  
[rank1]:     return self.model.forward(*args, **kwargs)                       
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper                                                
[rank1]:     output = func(self, *args, **kwargs)                                                                                                            
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func                                       
[rank1]:     return func(*args, **kwargs)                                                                                                                    
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 823, in forward                                  
[rank1]:     outputs: BaseModelOutputWithPast = self.model(                                                                                                  
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl                                       
[rank1]:     return self._call_impl(*args, **kwargs)                                                                                                         
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl                                               
[rank1]:     result = forward_call(*args, **kwargs)                                                                                                          
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper                                                
[rank1]:     output = func(self, *args, **kwargs)                                                                                                            [rank1]:   File "/opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 537, in forward                                  [rank1]:     layer_outputs = self._gradient_checkpointing_func(                                                                                              
[rank1]:   File "/root/LLaMA-Factory/src/llamafactory/model/model_utils/checkpointing.py", line 97, in custom_gradient_checkpointing_func                    
[sft-0] 0:[tmux]*                                                                                                                  "worker-0" 10:37 16-Jun-26W0616 08:18:02.558000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 584 closing signal SIGTERM                       
W0616 08:18:02.559000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 586 closing signal SIGTERM                       
W0616 08:18:02.560000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 587 closing signal SIGTERM                       
E0616 08:18:02.724000 140079632246592 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 1 (pid: 585) of binary: /opt/con
da/bin/python3.10                                                                                                                                            
Traceback (most recent call last):                                                                                                                           
  File "/opt/conda/bin/torchrun", line 8, in <module>                                                                                                        
    sys.exit(main())                   
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper                          
    return f(*args, **kwargs)          
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 901, in main                                                                 
    run(args)                          
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run                                                                  
    elastic_launch(                    
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__                                                    
    return launch_agent(self._config, self._entrypoint, list(args))           
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent                                                
    raise ChildFailedError(            
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:            
============================================================                  
/root/LLaMA-Factory/src/llamafactory/launcher.py FAILED                       
------------------------------------------------------------                  
Failures:                              
  <NO_OTHER_FAILURES>                  
------------------------------------------------------------                  
Root Cause (first observed failure):                                          
[0]:                                   
  time      : 2026-06-16_08:18:02      
  host      : worker-0                 
  rank      : 1 (local_rank: 1)        
  exitcode  : 1 (pid: 585)             
  error_file: <N/A>                    
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html                                                                   
[sft-0] 0:[tmux]*

这是一个非常好的问题!很多刚接触 DeepSpeed ZeRO-3 的同学都会在这里踩坑。

你之前的直觉是对的:在普通单卡训练或 ZeRO-2 中,save_only_model: true 确实能少保存东西,防止内存溢出;但是,在 ZeRO-3 中,这个选项是致命的!
我来详细解释一下为什么,以及"用 CPU 保存"到底是什么意思。

一、为什么 ZeRO-3 绝不能设 save_only_model: true

核心原因:在 ZeRO-3 中,优化器状态里藏着模型参数的"原件",你把优化器丢了,模型就残缺了。

  1. ZeRO-3 的切片机制
    在 ZeRO-3 下,模型的参数被切成了 4 份,分布在 4 张卡上。更关键的是,为了省显存,模型参数的 fp32 主副本(即使你用 bf16 训练,优化器里也会存一份 fp32 用于更新)是存在优化器状态里的!
  2. 保存 Checkpoint 时发生了什么
    • 如果 save_only_model: false(正常保存):DeepSpeed 会把模型分片和优化器分片一起保存。下次恢复训练时,DeepSpeed 能从优化器里把 fp32 主副本找回来,重新拼好。
    • 如果 save_only_model: true(强制只保存模型):HuggingFace Trainer 会跳过优化器的保存。这就导致 DeepSpeed 只存了残缺的 bf16 模型分片,丢掉了优化器里的 fp32 主副本和分片映射关系。
  3. 后果
    保存时,DeepSpeed 发现需要的数据对不上,就会发生死锁 或者通信超时------这就是为什么你的训练在 100 步保存时卡死然后报错退出!

总结 :在 ZeRO-3 下,必须老老实实存完整 checkpoint(包含优化器)。虽然占硬盘,但这是 ZeRO-3 续训的唯一方式。如果你觉得存得太占硬盘,可以把 save_steps 调大一点(比如 500 步存一次),或者 save_total_limit: 1(只保留最新 1 个 checkpoint)。


二、"让 CPU 来保存/导出"是什么意思?

这涉及到 ZeRO-3 如何把切碎的模型"拼"回来的问题。

1. 为什么不能在 GPU 上拼?

假设你的 Qwen2 模型总大小是 14GB。在 ZeRO-3 下,每张 16G 的显卡只存了 3.5GB 的碎片。

如果你想保存一个完整的模型文件(比如 pytorch_model.bin),系统需要把这 4 份 3.5GB 的碎片收集到一张卡上拼接。

问题来了 :拼接需要 14GB 显存,而你的显卡在训练时已经占了 13GB,根本放不下这 14GB!一旦强行在 GPU 上拼,瞬间 OOM(显存溢出)崩溃

2. CPU 导出的原理

既然 GPU 显存(16GB)太小,拼不下完整的模型,我们就借用系统内存(通常有 128GB 甚至 256GB)来拼。

  • Step 1:把 4 张卡上的模型碎片,依次复制到电脑的 CPU 内存中。
  • Step 2:在 CPU 内存中,把这 4 个碎片拼成一个完整的 14GB 模型。
  • Step 3 :把拼好的完整模型保存到硬盘上。
    这就是 --export_device cpu 的意思:在 CPU 内存中合并模型,避免 GPU 显存爆炸

三、你现在应该怎么改?

结合你的情况(4张 16G 卡,100步保存崩溃),你需要做以下两步:

1. 修改 YAML 配置(解决保存死锁和超时)
yaml 复制代码
# 必须改成 false!ZeRO-3 需要保存优化器才能维持模型完整
save_only_model: false
# 极大增加超时时间,防止硬盘写入慢导致多卡失联
ddp_timeout: 18000000
# 只保留最新的 1-2 个 checkpoint,防止硬盘被撑爆
save_total_limit: 2
2. 训练正常结束后,如何提取最终模型?

因为 save_only_model: false 加上 ZeRO-3,你保存在文件夹里的文件是碎的(有很多 zero_pp_rank_0... 这样名字的文件),不能直接用 transformers 加载。

你需要用 LLaMA-Factory 的 导出功能,利用 CPU 把碎片拼起来,并和你的 LoRA 合并:

bash 复制代码
llamafactory-cli export \
    --model_name_or_path /public/home/a15657581978/merged_correct_2epoch \
    --adapter_name_or_path /public/home/a15657581978/output/binding_part_000_512 \
    --template alpaca \
    --finetuning_type lora \
    --export_dir /public/home/a15657581978/exported_final_model \
    --export_device cpu \   # 关键!用 CPU 内存拼模型,防止 GPU OOM
    --export_legacy_format false

执行完这个命令,exported_final_model 文件夹里就会出现完整的、可以直接用的 HuggingFace 格式模型了。

总结一下 :训练时,让 ZeRO-3 存完整的碎片(save_only_model: false);训练后,用 CPU 把碎片拼成完整的模型(export_device cpu)。这是 4x16G 显存跑大模型的标准流程。

相关推荐
swordbob2 小时前
NIO的channel中什么是 fd(File Descriptor,文件描述符)
java·开发语言·nio
源分享2 小时前
Java线程同步的多种实现方法(非常详细)
java·开发语言·jvm
renhongxia12 小时前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
Luminous.2 小时前
C语言--day30
c语言·开发语言
计算机科研狗@OUC2 小时前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
码云骑士2 小时前
32-慢查询排查全流程(下)-索引优化实战与最左前缀原则
python
何以解忧,唯有..3 小时前
Go语言循环语句详解:for、range与循环控制
开发语言·算法·golang
謓泽3 小时前
C语言不是语法,是通往机器的地图。
c语言·开发语言
云水一下3 小时前
从零开始学 PHP 系列(一):PHP 的前世今生与开发环境搭建
开发语言·php