模型训练中梯度累积步数(gradient_accumulation_steps)的作用

模型训练中梯度累积步数(gradient_accumulation_steps)的作用

flyfish

在使用训练大模型时,TrainingArguments有一个参数梯度累积步数(gradient_accumulation_steps)

py 复制代码
from transformers import TrainingArguments

梯度累积是一种在训练深度学习模型时用于处理内存限制问题的技术。在每次迭代中,模型的梯度是通过反向传播计算得到的,而梯度累积步数(gradient_accumulation_steps)指定了在执行实际的参数更新之前,要累积多少个小批次(mini - batch)的梯度。

以代码来说gradient_accumulation_steps的作用

py 复制代码
import torch
from torch import nn, optim

# 生成更合理的数据集,假设目标关系是y = 3 * x + 2 加上一些噪声
def generate_dataset(num_samples):
    inputs = torch.randn(num_samples, 10)
    # 根据线性关系生成标签,添加一些随机噪声模拟真实情况
    labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5
    return list(zip(inputs, labels))

# 生成数据集,这里生成2000个样本(可根据实际情况调整数据量)
your_dataset = generate_dataset(2000)

# 模型、损失和优化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法来初始化模型参数,有助于缓解梯度消失和爆炸问题,提升训练效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 适当调整学习率,这里改为0.1,可根据实际情况进一步微调
optimizer = optim.Adam(model.parameters(), lr=0.1)

# 配置梯度累积步数
gradient_accumulation_steps = 4
global_step = 0

# 模拟训练循环
for epoch in range(20):  # 训练20个周期
    for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播(累积梯度)
        loss.backward()
        
        # 执行梯度更新
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            print(f"更新了模型参数,当前全局步数: {global_step}, 当前损失: {loss.item()}")

解释:

  • batch_size=8:每个梯度计算时,模型会处理 8 张图像。
  • gradient_accumulation_steps=4:表示每次参数更新前累积 4 次梯度。

因此:

  • 每个 step: 处理 8 张图像。
  • 累积 4 个 step: 共处理 8 × 4 = 32 8 \times 4 = 32 8×4=32 张图像。

关键点:

  • 一个 step: 是指一次前向和后向传播(不包含参数更新)。
  • 一次参数更新: 在累积 4 个 step 后,进行一次模型参数更新。

等效有效批次:

有效批次大小 = batch_size × gradient_accumulation_steps

即: 8 × 4 = 32 8 \times 4 = 32 8×4=32。

这意味着,即使显存有限,模型仍然能以有效批次大小 32 的方式进行训练

相关推荐
周杰伦_Jay2 小时前
Ollama能本地部署Llama 3等大模型的原因解析(ollama核心架构、技术特性、实际应用)
数据结构·人工智能·深度学习·架构·transformer·llama
好评笔记18 小时前
AIGC视频生成模型:ByteDance的PixelDance模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·aigc·transformer
珊珊而川20 小时前
BERT和Transformer模型有什么区别
人工智能·bert·transformer
feifeikon1 天前
深度学习 DAY2:Transformer(一部分)
人工智能·深度学习·transformer
RockWang.2 天前
【llama_factory】qwen2_vl训练与批量推理
llama·qwen2-vl
无意21212 天前
【自动驾驶BEV感知之Transformer】
人工智能·自动驾驶·transformer
bug404_2 天前
Restormer: Efficient Transformer for High-Resolution Image Restoration解读
人工智能·深度学习·transformer
GISer Liu2 天前
Transformer详解:Attention机制原理
人工智能·python·gpt·深度学习·机器学习·语言模型·transformer
王了了哇3 天前
精度论文:【Focaler-IoU: More Focused Intersection over Union Loss】
人工智能·pytorch·深度学习·计算机视觉·transformer
机器学习之心3 天前
强推未发表!3D图!Transformer-LSTM+NSGAII工艺参数优化、工程设计优化!
lstm·transformer·nsgaii工艺参数优化·工程设计优化