模型训练中梯度累积步数(gradient_accumulation_steps)的作用

模型训练中梯度累积步数(gradient_accumulation_steps)的作用

flyfish

在使用训练大模型时,TrainingArguments有一个参数梯度累积步数(gradient_accumulation_steps)

py 复制代码
from transformers import TrainingArguments

梯度累积是一种在训练深度学习模型时用于处理内存限制问题的技术。在每次迭代中,模型的梯度是通过反向传播计算得到的,而梯度累积步数(gradient_accumulation_steps)指定了在执行实际的参数更新之前,要累积多少个小批次(mini - batch)的梯度。

以代码来说gradient_accumulation_steps的作用

py 复制代码
import torch
from torch import nn, optim

# 生成更合理的数据集,假设目标关系是y = 3 * x + 2 加上一些噪声
def generate_dataset(num_samples):
    inputs = torch.randn(num_samples, 10)
    # 根据线性关系生成标签,添加一些随机噪声模拟真实情况
    labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5
    return list(zip(inputs, labels))

# 生成数据集,这里生成2000个样本(可根据实际情况调整数据量)
your_dataset = generate_dataset(2000)

# 模型、损失和优化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法来初始化模型参数,有助于缓解梯度消失和爆炸问题,提升训练效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 适当调整学习率,这里改为0.1,可根据实际情况进一步微调
optimizer = optim.Adam(model.parameters(), lr=0.1)

# 配置梯度累积步数
gradient_accumulation_steps = 4
global_step = 0

# 模拟训练循环
for epoch in range(20):  # 训练20个周期
    for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播(累积梯度)
        loss.backward()
        
        # 执行梯度更新
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            print(f"更新了模型参数,当前全局步数: {global_step}, 当前损失: {loss.item()}")

解释:

  • batch_size=8:每个梯度计算时,模型会处理 8 张图像。
  • gradient_accumulation_steps=4:表示每次参数更新前累积 4 次梯度。

因此:

  • 每个 step: 处理 8 张图像。
  • 累积 4 个 step: 共处理 8 × 4 = 32 8 \times 4 = 32 8×4=32 张图像。

关键点:

  • 一个 step: 是指一次前向和后向传播(不包含参数更新)。
  • 一次参数更新: 在累积 4 个 step 后,进行一次模型参数更新。

等效有效批次:

有效批次大小 = batch_size × gradient_accumulation_steps

即: 8 × 4 = 32 8 \times 4 = 32 8×4=32。

这意味着,即使显存有限,模型仍然能以有效批次大小 32 的方式进行训练

相关推荐
大知闲闲哟6 小时前
深度学习TR3周:Pytorch复现Transformer
pytorch·深度学习·transformer
学Linux的语莫16 小时前
transformer与神经网络
深度学习·神经网络·transformer
这张生成的图像能检测吗1 天前
(论文速读)RMT:Retentive+ViT的视觉新骨干
人工智能·深度学习·计算机视觉·transformer·注意力机制
老鱼说AI1 天前
Vision Transformer(ViT)模型实例化PyTorch逐行实现
pytorch·深度学习·transformer
老鱼说AI1 天前
Vision Transformer (ViT) 详解:当Transformer“看见”世界,计算机视觉的范式革命
人工智能·深度学习·transformer
战争热诚2 天前
基于transformer的目标检测——匈牙利匹配算法
算法·目标检测·transformer
蹦蹦跳跳真可爱5892 天前
Python----大模型(大模型微调--BitFit、Prompt Tuning、P-tuning、Prefix-tuning、LORA)
人工智能·python·深度学习·自然语言处理·transformer
盼小辉丶2 天前
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
pytorch·深度学习·transformer
陈敬雷-充电了么-CEO兼CTO2 天前
从游戏NPC到手术助手:Agent AI重构多模态交互,具身智能打开AGI新大门
人工智能·深度学习·算法·chatgpt·重构·transformer·agi
deephub2 天前
NSA稀疏注意力深度解析:DeepSeek如何将Transformer复杂度从O(N²)降至线性,实现9倍训练加速
人工智能·深度学习·transformer·deepseek·稀疏注意力