浅谈梯度累积(Gradient Accumulation)和梯度检查点(Gradient Checkpointing)

目录

  • [1. 回顾](#1. 回顾)
  • [2. 梯度累积(Gradient Accumulation)](#2. 梯度累积(Gradient Accumulation))
    • [2.1 工作原理](#2.1 工作原理)
    • [2.2 代码示例](#2.2 代码示例)
    • [2.3 Q&A](#2.3 Q&A)
  • [3. 梯度检查点(Gradient Checkpointing)](#3. 梯度检查点(Gradient Checkpointing))
  • Ref

1. 回顾

在讨论梯度累积技术之前,让我们先回顾一些 PyTorch 的基础知识,特别是关于模型的梯度计算和参数更新。以下是一个简单的线性模型示例:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)

inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
targets = torch.tensor([[1.0], [2.0]])

criterion = nn.MSELoss()

# 前向传播:计算模型输出和损失
outputs = model(inputs)
loss = criterion(outputs, targets)

# 反向传播:计算梯度
loss.backward()

for name, param in model.named_parameters():
    print(f"参数: {name}")
    print(f"参数值: {param}")
    print(f"梯度: {param.grad}\n")

在这个示例中,我们首先定义了一个简单的线性模型 nn.Linear(2, 1),它接受一个二维输入并输出一个标量。接着,我们创建了输入数据 inputs 和对应的目标标签 targets。损失函数选择了均方误差损失 nn.MSELoss()

通过前向传播计算模型的输出 outputs,并基于输出和目标计算损失 loss。然后,通过调用 loss.backward() 进行反向传播,计算模型参数的梯度。

值得注意的是,gradtensor 的形状是相同的。在 PyTorch 中,一个 tensor 通常包含两个主要属性:data(数据本身)和 grad(对应的梯度)。这两个属性都会占用显存,尤其是在处理大型模型或大批量数据时,显存的占用会显著增加。

2. 梯度累积(Gradient Accumulation)

在深度学习的训练过程中,批量大小(batch size) 对模型的训练效果和稳定性有着重要的影响。较大的批量大小通常可以提供更稳定的梯度估计,但也需要更多的显存资源。然而,受限于硬件条件,我们往往无法直接使用大批量进行训练。这时,梯度累积技术就派上了用场。

2.1 工作原理

梯度累积 的核心思想是:在不增加实际批量大小的情况下,通过累积多个小批量的梯度,模拟大批量的效果。具体步骤如下:

  1. 前向传播:对于每个小批量,计算模型的输出和损失。
  2. 反向传播:对每个小批量执行反向传播,计算梯度,但不立即更新模型参数,而是将梯度累加起来。
  3. 参数更新:当累积的梯度达到预设的累积步数(accumulation steps)后,执行一次参数更新,并将累积的梯度清零。

假设我们希望使用批量大小为 64 的数据进行训练,但由于显存限制,我们只能处理批量大小为 16 的数据。通过梯度累积,我们可以将 4 个小批量(每个大小为 16)的梯度累积起来,在第 4 次小批量之后执行一次参数更新。这样,我们在不增加显存占用的情况下,获得了与使用批量大小为 64 进行训练相同的效果。

2.2 代码示例

下面是一个使用 PyTorch 实现梯度累积的示例代码:

python 复制代码
import torch
from torch import nn, optim

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

batch_size = 16
accumulation_steps = 4  # 梯度累积步数,有效的 batch_size 为 16 * 4 = 64

data = torch.randn(64, 10)
target = torch.randn(64, 1)

mini_batches = torch.split(data, batch_size)
mini_targets = torch.split(target, batch_size)

optimizer.zero_grad()
for i, (inputs, labels) in enumerate(zip(mini_batches, mini_targets)):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    
    loss = loss / accumulation_steps  # 下文会解释为什么要除以这个数
    
    # 反向传播,累积梯度
    loss.backward()
    
    # 每 accumulation_steps 步进行一次参数更新
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# 如果最后的 batch 未达到 accumulation_steps,则仍需更新参数
if (i + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()

解释:

  1. accumulation_steps:在这里我们设置为 4,表示每 4 个 mini-batch 后,累积一次梯度并进行一次参数更新。
  2. loss.backward():每次都进行反向传播,累积梯度。
  3. optimizer.step():每 4 次 mini-batch 后,执行一次权重更新。
  4. optimizer.zero_grad():每次更新后需要将累积的梯度清零。

2.3 Q&A

Q1: 即使使用梯度累积,累积完后的有效批量大小与原始的批量大小相同,为什么还能节省显存?

A1 : 在训练过程中,显存的占用主要来自于模型参数、梯度、中间激活值和优化器状态。其中,中间激活值的显存占用与批量大小直接相关。使用梯度累积时,我们在每次前向和反向传播中仅处理较小的批量,因此中间激活值的显存占用较小。

📝 在神经网络中,中间激活值是指每一层的输出(激活函数后的值),这些值通常要被保存在显存中,以便在反向传播时计算梯度。

Q2: 使用梯度累积的效果与直接使用较大批量进行训练的效果是否等价?

A2: 完全等价。我们可以从数学的角度进行推导。

假设模型的输出形状为 (batch_size, seq_len, vocab_size),其中 batch_size 是批量大小,seq_len 是序列长度,vocab_size 是词汇表大小。在语言模型的训练过程中,我们通常对每个 token 计算损失,然后需要将这些损失合成为一个标量用于反向传播。

设 L i j L_{ij} Lij 表示第 i i i 个样本的第 j j j 个 token 的损失,对于一个批量大小为 N N N 的 mini-batch,序列长度为 S S S,损失矩阵 L L L 的形状为 ( N , S ) (N,S) (N,S)。我们需要将损失矩阵合成为一个标量 ℓ \ell ℓ。

通常的做法是对损失矩阵进行平均:

ℓ = 1 N × S ∑ i = 1 N ∑ j = 1 S L i j \ell = \frac{1}{N \times S} \sum_{i=1}^{N} \sum_{j=1}^{S} L_{ij} ℓ=N×S1i=1∑Nj=1∑SLij

令 L i = 1 S ∑ j = 1 S L i j \mathcal{L}i = \frac{1}{S} \sum{j=1}^{S} L_{ij} Li=S1∑j=1SLij,即每个句子的损失,那么总的损失就是所有句子损失的平均:

ℓ = 1 N ∑ i = 1 N L i \ell = \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}_i ℓ=N1i=1∑NLi

在不进行梯度累积的情况下,梯度为:

∇ ℓ = 1 N ∑ i = 1 N ∇ L i \nabla \ell = \frac{1}{N} \sum_{i=1}^{N} \nabla \mathcal{L}_i ∇ℓ=N1i=1∑N∇Li

假设 accumulation steps 为 T T T,这说明每次梯度累积的 mini batch size 为 N / T ≜ k N / T \triangleq k N/T≜k。为方便起见,我们假设这里能够正好整除。假设当前累积到了第 t t t 步( t t t 从 1 开始),这一步的损失为:

ℓ t = 1 k ∑ i = ( t − 1 ) k t k − 1 L i \ell_t = \frac{1}{k} \sum_{i=(t-1)k}^{tk-1} \mathcal{L}_i ℓt=k1i=(t−1)k∑tk−1Li

那么 T T T 步过后累积的梯度为:

∑ t = 1 T ∇ ℓ t = 1 k ∑ t = 1 T ∑ i = ( t − 1 ) k t k − 1 ∇ L i = 1 k ∑ i = 1 N ∇ L i = T ∇ ℓ \sum_{t=1}^{T} \nabla \ell_t = \frac{1}{k} \sum_{t=1}^{T} \sum_{i=(t-1)k}^{tk-1} \nabla \mathcal{L}i = \frac{1}{k} \sum{i=1}^{N} \nabla \mathcal{L}_i = T \nabla \ell t=1∑T∇ℓt=k1t=1∑Ti=(t−1)k∑tk−1∇Li=k1i=1∑N∇Li=T∇ℓ

移项得

∇ ( ℓ ) = ∇ ( ∑ t = 1 T 1 T ℓ t ) \nabla (\ell )=\nabla \left( \sum_{t=1}^{T}\frac1T \ell_t\right) ∇(ℓ)=∇(t=1∑TT1ℓt)

从表达式可以看出,在每次累积梯度的时候,只需要将计算出的损失除以一个 T T T,就能保证最后的梯度一致。

3. 梯度检查点(Gradient Checkpointing)

梯度检查点(Gradient Checkpointing)是一种内存优化策略,特别适用于训练深度神经网络。其基本思想是通过在反向传播时重新计算部分中间结果(激活值),从而减少前向传播时这些中间结果在内存中的存储,降低显存的使用量。这样,我们就可以在有限的显存条件下训练更大的模型或使用更大的批次大小。

  1. 模型训练中的内存消耗:在神经网络训练中,前向传播过程中会计算大量的中间结果(激活值),这些结果在反向传播时需要用到,因此需要暂时存储在内存中。这些中间结果会占用大量内存,尤其是在模型规模很大时,内存消耗将更加显著。

  2. 内存和计算的权衡:梯度检查点通过在前向传播时不存储所有中间结果,而是在需要时重新计算。这意味着我们通过增加一些计算量(重新计算中间结果),来节省存储这些结果所需的内存。

  3. 如何做到的? :以一个非常深的神经网络模型为例,我们可以将其分段处理。在每一段中,只存储一些关键节点的激活值,然后在反向传播时重新计算未存储的那些激活值。这就是用时间换取空间的策略。

PyTorch 提供了简单的方法来实现梯度检查点。我们可以使用 torch.utils.checkpoint.checkpoint 来应用该技术。

py 复制代码
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(100, 200)
        self.layer2 = nn.Linear(200, 200)
        self.layer3 = nn.Linear(200, 100)

    def forward(self, x):
        x = checkpoint.checkpoint(self.layer1, x)  # 在第一层使用梯度检查点
        x = checkpoint.checkpoint(self.layer2, x)  # 在第二层使用梯度检查点
        x = self.layer3(x)  # 第三层不使用梯度检查点
        return x


model = SimpleModel()
input_data = torch.randn(10, 100)
output = model(input_data)

print("Output shape:", output.shape)

解释:

  1. 我们定义了一个包含三层的简单全连接神经网络。
  2. 对前两层使用了 checkpoint.checkpoint(),这意味着在前向传播时,这些层的激活值不会直接存储,而是在反向传播时重新计算。
  3. 第三层未使用梯度检查点,因此其激活值会被直接存储。

Ref

[1] https://zhuanlan.zhihu.com/p/448395808

[2] https://www.bilibili.com/video/BV1nJ4m1M7Qw

相关推荐
linly121928 分钟前
在python中安装HDDM
开发语言·python
洋葱蚯蚓29 分钟前
构建自己的文生图工具:Python + Stable Diffusion + CUDA
开发语言·python·stable diffusion
ling1s32 分钟前
C#基础(12)递归函数
开发语言·算法·c#
Antonio91536 分钟前
【高级数据结构】树状数组
数据结构·c++·算法
大晴的上分之旅44 分钟前
树和二叉树基本术语、性质
数据结构·算法·二叉树
全智能时代1 小时前
宝塔部署python项目
python
kuiini1 小时前
python学习-09【文件和目录操作】
python·学习
炸膛坦客1 小时前
深度学习:(四)python中的广播
人工智能·python·深度学习
Chase-Hart1 小时前
【每日一题】LeetCode 815.公交路线(广度优先搜索、数组、哈希表)
数据结构·算法·leetcode·散列表·宽度优先
Am心若依旧4091 小时前
[C++进阶[六]]list的相关接口模拟实现
开发语言·数据结构·c++·算法·list