Training - PyTorch Lightning 分布式训练的 global_step 参数 (accumulate_grad_batches)

欢迎关注我的CSDN:https://spike.blog.csdn.net/

本文地址:https://blog.csdn.net/caroline_wendy/article/details/137640653

在 PyTorch Lightning 中,pl.Traineraccumulate_grad_batches 参数允许在执行反向传播和优化器步骤之前,累积多个批次的梯度。这样,可以增加有效的批次大小,而不会增加内存开销。例如,如果设置 accumulate_grad_batches=8,则会在执行优化器的 .step() 方法之前,累积 8 个批次的梯度。

accumulate_grad_batchesglobal_step 的关系:

  1. global_step 会在每次调用优化器的 .step() 方法后递增。
  2. 使用梯度累积,global_step 增长小于 批次(batch) 的数量
  3. 多个批次贡献到 1 个 global_step 的更新中。

例如,如果 accumulate_grad_batches=8,那么每 8 个批次,只会增加 1 次 global_step,如果多卡,则 global_step 表示单卡的次数。日志,如下:

bash 复制代码
[INFO] [CL] global_step: 0, iter_step: 8
[INFO] [CL] global_step: 1, iter_step: 16

其中 pl.Trainer 的源码:

bash 复制代码
    trainer = pl.Trainer(
        accelerator="gpu",
        # ...
        accumulate_grad_batches=args.accumulate_grad,
        strategy=strategy,  # 多机多卡配置
        num_nodes=args.num_nodes,  # 节点数
        devices=1,  # 每个节点 GPU 卡数
    )

输出日志:

bash 复制代码
log = {'epoch': self.trainer.current_epoch, 'step': self.trainer.global_step}
wandb.log(log)
相关推荐
YSGZJJ1 小时前
股指期货的套保策略如何精准选择和规避风险?
人工智能·区块链
无脑敲代码,bug漫天飞1 小时前
COR 损失函数
人工智能·机器学习
HPC_fac130520678162 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
小陈phd4 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao5 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
ZHOU_WUYI9 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1239 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界10 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK2215110 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot25110 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台