模型训练中梯度累积步数(gradient_accumulation_steps)的作用

模型训练中梯度累积步数(gradient_accumulation_steps)的作用

flyfish

在使用训练大模型时,TrainingArguments有一个参数梯度累积步数(gradient_accumulation_steps)

py 复制代码
from transformers import TrainingArguments

梯度累积是一种在训练深度学习模型时用于处理内存限制问题的技术。在每次迭代中,模型的梯度是通过反向传播计算得到的,而梯度累积步数(gradient_accumulation_steps)指定了在执行实际的参数更新之前,要累积多少个小批次(mini - batch)的梯度。

以代码来说gradient_accumulation_steps的作用

py 复制代码
import torch
from torch import nn, optim

# 生成更合理的数据集,假设目标关系是y = 3 * x + 2 加上一些噪声
def generate_dataset(num_samples):
    inputs = torch.randn(num_samples, 10)
    # 根据线性关系生成标签,添加一些随机噪声模拟真实情况
    labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5
    return list(zip(inputs, labels))

# 生成数据集,这里生成2000个样本(可根据实际情况调整数据量)
your_dataset = generate_dataset(2000)

# 模型、损失和优化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法来初始化模型参数,有助于缓解梯度消失和爆炸问题,提升训练效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 适当调整学习率,这里改为0.1,可根据实际情况进一步微调
optimizer = optim.Adam(model.parameters(), lr=0.1)

# 配置梯度累积步数
gradient_accumulation_steps = 4
global_step = 0

# 模拟训练循环
for epoch in range(20):  # 训练20个周期
    for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播(累积梯度)
        loss.backward()
        
        # 执行梯度更新
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            print(f"更新了模型参数,当前全局步数: {global_step}, 当前损失: {loss.item()}")

解释:

  • batch_size=8:每个梯度计算时,模型会处理 8 张图像。
  • gradient_accumulation_steps=4:表示每次参数更新前累积 4 次梯度。

因此:

  • 每个 step: 处理 8 张图像。
  • 累积 4 个 step: 共处理 8 × 4 = 32 8 \times 4 = 32 8×4=32 张图像。

关键点:

  • 一个 step: 是指一次前向和后向传播(不包含参数更新)。
  • 一次参数更新: 在累积 4 个 step 后,进行一次模型参数更新。

等效有效批次:

有效批次大小 = batch_size × gradient_accumulation_steps

即: 8 × 4 = 32 8 \times 4 = 32 8×4=32。

这意味着,即使显存有限,模型仍然能以有效批次大小 32 的方式进行训练

相关推荐
KingDol_MIni12 小时前
Transformer-LSTM混合模型在时序回归中的完整流程研究
回归·lstm·transformer
机器学习之心HML15 小时前
Transformer编码器+SHAP分析,模型可解释创新表达!
人工智能·深度学习·transformer
jzwei0231 天前
为啥大模型一般将kv进行缓存,而q不需要
深度学习·ai·transformer
COOCC12 天前
PyTorch 实战:从 0 开始搭建 Transformer
人工智能·pytorch·python·深度学习·算法·机器学习·transformer
jerwey2 天前
Diffusion Transformer(DiT)
人工智能·深度学习·transformer·dit
KingDol_MIni2 天前
transformer➕lstm训练回归模型
回归·lstm·transformer
聚客AI2 天前
预训练模型实战手册:用BERT/GPT-2微调实现10倍效率提升,Hugging Face生态下的迁移学习全链路实践
人工智能·语言模型·chatgpt·transformer·ai大模型·模型微调·deepseek
Panesle3 天前
ACE-Step:扩散自编码文生音乐基座模型快速了解
大模型·transformer·音频·扩散模型·文本生成音乐
伊布拉西莫4 天前
NLP 和大模型技术路线
语言模型·自然语言处理·transformer