PyTorch 中的累积梯度

https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch

有一个小的计算图,两次前向梯度累积的结果,可以看到梯度是严格相等的。


代码:

复制代码
import numpy as np
import torch


class ExampleLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Initialize the weight at 1
        self.weight = torch.nn.Parameter(torch.Tensor([1]).float(),
                                         requires_grad=True)

    def forward(self, x):
        return self.weight * x


model = ExampleLinear()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def calculate_loss(x: torch.Tensor) -> torch.Tensor:
    y = 2 * x
    y_hat = model(x)
    temp1 = (y - y_hat)
    temp2 = temp1**2
    return temp2


# With mulitple batches of size 1
batches = [torch.tensor([4.0]), torch.tensor([2.0])]

optimizer.zero_grad()
for i, batch in enumerate(batches):
    # The loss needs to be scaled, because the mean should be taken across the whole
    # dataset, which requires the loss to be divided by the number of batches.
    temp2 = calculate_loss(batch)
    loss = temp2 / len(batches)
    loss.backward()
    print(f"Batch size 1 (batch {i}) - grad: {model.weight.grad}")
    print(f"Batch size 1 (batch {i}) - weight: {model.weight}")
    print("="*50)

# Updating the model only after all batches
optimizer.step()
print(f"Batch size 1 (final) - grad: {model.weight.grad}")
print(f"Batch size 1 (final) - weight: {model.weight}")

运行结果

复制代码
Batch size 1 (batch 0) - grad: tensor([-16.])
Batch size 1 (batch 0) - weight: Parameter containing:
tensor([1.], requires_grad=True)
==================================================
Batch size 1 (batch 1) - grad: tensor([-20.])
Batch size 1 (batch 1) - weight: Parameter containing:
tensor([1.], requires_grad=True)
==================================================
Batch size 1 (final) - grad: tensor([-20.])
Batch size 1 (final) - weight: Parameter containing:
tensor([1.2000], requires_grad=True)

然而,如果训练一个真实的模型,结果没有这么理想,比如训练一个bert,𝐵=8,𝑁=1:没有梯度累积(累积每一步),

𝐵=2,𝑁=4:梯度累积(每 4 步累积一次)

使用带有梯度累积的 Batch Normalization 通常效果不佳,原因很简单,因为 BatchNorm 统计数据无法累积。更好的解决方案是使用 Group Normalization 而不是 BatchNorm。

https://ai.stackexchange.com/questions/21972/what-is-the-relationship-between-gradient-accumulation-and-batch-size

相关推荐
AndrewHZ22 分钟前
【图像处理基石】什么是tone mapping?
图像处理·人工智能·算法·计算机视觉·hdr
Ai尚研修-贾莲23 分钟前
基于DeepSeek、ChatGPT支持下的地质灾害风险评估、易发性分析、信息化建库及灾后重建
人工智能·chatgpt
SelectDB技术团队1 小时前
Apache Doris 2025 Roadmap:构建 GenAI 时代实时高效统一的数据底座
大数据·数据库·数据仓库·人工智能·ai·数据分析·湖仓一体
weixin_435208161 小时前
通过 Markdown 改进 RAG 文档处理
人工智能·python·算法·自然语言处理·面试·nlp·aigc
大数据在线1 小时前
AI重塑云基础设施,亚马逊云科技打造AI定制版IaaS“样板房”
人工智能·云基础设施·ai大模型·亚马逊云科技
hello_ejb31 小时前
聊聊Spring AI的RetrievalAugmentationAdvisor
人工智能·spring·restful
东方佑1 小时前
利用Python自动化处理PPT样式与结构:从提取到生成
python·自动化·powerpoint
你觉得2051 小时前
浙江大学朱霖潮研究员:《人工智能重塑科学与工程研究》以蛋白质结构预测为例|附PPT下载方法
大数据·人工智能·机器学习·ai·云计算·aigc·powerpoint
橘猫云计算机设计1 小时前
基于springboot的考研成绩查询系统(源码+lw+部署文档+讲解),源码可白嫖!
java·spring boot·后端·python·考研·django·毕业设计
牙牙要健康1 小时前
【目标检测】【深度学习】【Pytorch版本】YOLOV3模型算法详解
pytorch·深度学习·目标检测