Training - PyTorch Lightning 分布式训练的 global_step 参数 (accumulate_grad_batches)

欢迎关注我的CSDN:https://spike.blog.csdn.net/

本文地址:https://blog.csdn.net/caroline_wendy/article/details/137640653

在 PyTorch Lightning 中,pl.Traineraccumulate_grad_batches 参数允许在执行反向传播和优化器步骤之前,累积多个批次的梯度。这样,可以增加有效的批次大小,而不会增加内存开销。例如,如果设置 accumulate_grad_batches=8,则会在执行优化器的 .step() 方法之前,累积 8 个批次的梯度。

accumulate_grad_batchesglobal_step 的关系:

  1. global_step 会在每次调用优化器的 .step() 方法后递增。
  2. 使用梯度累积,global_step 增长小于 批次(batch) 的数量
  3. 多个批次贡献到 1 个 global_step 的更新中。

例如,如果 accumulate_grad_batches=8,那么每 8 个批次,只会增加 1 次 global_step,如果多卡,则 global_step 表示单卡的次数。日志,如下:

bash 复制代码
[INFO] [CL] global_step: 0, iter_step: 8
[INFO] [CL] global_step: 1, iter_step: 16

其中 pl.Trainer 的源码:

bash 复制代码
    trainer = pl.Trainer(
        accelerator="gpu",
        # ...
        accumulate_grad_batches=args.accumulate_grad,
        strategy=strategy,  # 多机多卡配置
        num_nodes=args.num_nodes,  # 节点数
        devices=1,  # 每个节点 GPU 卡数
    )

输出日志:

bash 复制代码
log = {'epoch': self.trainer.current_epoch, 'step': self.trainer.global_step}
wandb.log(log)
相关推荐
武子康1 分钟前
AI研究-117 特斯拉 FSD 视觉解析:多摄像头 - 3D占用网络 - 车机渲染,盲区与低速复杂路况安全指南
人工智能·科技·计算机视觉·3d·视觉检测·特斯拉·model y
Geoking.8 分钟前
PyTorch torch.unique() 基础与实战
人工智能·pytorch·python
Fr2ed0m13 分钟前
卡尔曼滤波算法原理详解:核心公式、C 语言代码实现及电机控制 / 目标追踪应用
c语言·人工智能·算法
熊猫_豆豆27 分钟前
神经网络的科普,功能用途,包含的数学知识
人工智能·深度学习·神经网络
笨蛋不要掉眼泪38 分钟前
deepseek封装结合websocket实现与ai对话
人工智能·websocket·网络协议
hesorchen1 小时前
算力与数据驱动的 AI 技术演进全景(1999-2024):模型范式、Infra 数据、语言模型与多模态的关键突破
人工智能·语言模型·自然语言处理
你也渴望鸡哥的力量么1 小时前
基于边缘信息提取的遥感图像开放集飞机检测方法
人工智能·计算机视觉
xian_wwq1 小时前
【学习笔记】深度学习中梯度消失和爆炸问题及其解决方案研究
人工智能·深度学习·梯度
StarRocks_labs1 小时前
StarRocks 4.0:Real-Time Intelligence on Lakehouse
starrocks·人工智能·json·数据湖·存算分离
Tracy9732 小时前
DNR6521x_VC1:革新音频体验的AI降噪处理器
人工智能·音视频·xmos模组固件