Training - PyTorch Lightning 分布式训练的 global_step 参数 (accumulate_grad_batches)

欢迎关注我的CSDN:https://spike.blog.csdn.net/

本文地址:https://blog.csdn.net/caroline_wendy/article/details/137640653

在 PyTorch Lightning 中,pl.Traineraccumulate_grad_batches 参数允许在执行反向传播和优化器步骤之前,累积多个批次的梯度。这样,可以增加有效的批次大小,而不会增加内存开销。例如,如果设置 accumulate_grad_batches=8,则会在执行优化器的 .step() 方法之前,累积 8 个批次的梯度。

accumulate_grad_batchesglobal_step 的关系:

  1. global_step 会在每次调用优化器的 .step() 方法后递增。
  2. 使用梯度累积,global_step 增长小于 批次(batch) 的数量
  3. 多个批次贡献到 1 个 global_step 的更新中。

例如,如果 accumulate_grad_batches=8,那么每 8 个批次,只会增加 1 次 global_step,如果多卡,则 global_step 表示单卡的次数。日志,如下:

bash 复制代码
[INFO] [CL] global_step: 0, iter_step: 8
[INFO] [CL] global_step: 1, iter_step: 16

其中 pl.Trainer 的源码:

bash 复制代码
    trainer = pl.Trainer(
        accelerator="gpu",
        # ...
        accumulate_grad_batches=args.accumulate_grad,
        strategy=strategy,  # 多机多卡配置
        num_nodes=args.num_nodes,  # 节点数
        devices=1,  # 每个节点 GPU 卡数
    )

输出日志:

bash 复制代码
log = {'epoch': self.trainer.current_epoch, 'step': self.trainer.global_step}
wandb.log(log)
相关推荐
迪娜学姐4 分钟前
GenSpark vs Manus实测对比:文献综述与学术PPT,哪家强?
论文阅读·人工智能·prompt·powerpoint·论文笔记
TDengine (老段)6 分钟前
TDengine 在电力行业如何使用 AI ?
大数据·数据库·人工智能·时序数据库·tdengine·涛思数据
猎板PCB厚铜专家大族8 分钟前
高频 PCB 技术发展趋势与应用解析
人工智能·算法·设计规范
l0sgAi12 分钟前
SpringBoot整合LangChain4j实现RAG (检索增强生成)
人工智能
祐言QAQ13 分钟前
浅谈边缘计算
人工智能·边缘计算
lboyj13 分钟前
高频通信与航天电子的材料革命:猎板PCB高端压合基材技术解析
人工智能
奔跑吧邓邓子31 分钟前
DeepSeek 赋能智能教育知识图谱:从构建到应用的革命性突破
人工智能·知识图谱·应用·deepseek·智能教育
Mantanmu34 分钟前
Python训练day40
人工智能·python·机器学习
ss.li40 分钟前
TripGenie:畅游济南旅行规划助手:个人工作纪实(二十二)
javascript·人工智能·python
小天才才1 小时前
前沿论文汇总(机器学习/深度学习/大模型/搜广推/自然语言处理)
人工智能·深度学习·机器学习·自然语言处理