模型训练中梯度累积步数(gradient_accumulation_steps)的作用

模型训练中梯度累积步数(gradient_accumulation_steps)的作用

flyfish

在使用训练大模型时,TrainingArguments有一个参数梯度累积步数(gradient_accumulation_steps)

py 复制代码
from transformers import TrainingArguments

梯度累积是一种在训练深度学习模型时用于处理内存限制问题的技术。在每次迭代中,模型的梯度是通过反向传播计算得到的,而梯度累积步数(gradient_accumulation_steps)指定了在执行实际的参数更新之前,要累积多少个小批次(mini - batch)的梯度。

以代码来说gradient_accumulation_steps的作用

py 复制代码
import torch
from torch import nn, optim

# 生成更合理的数据集,假设目标关系是y = 3 * x + 2 加上一些噪声
def generate_dataset(num_samples):
    inputs = torch.randn(num_samples, 10)
    # 根据线性关系生成标签,添加一些随机噪声模拟真实情况
    labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5
    return list(zip(inputs, labels))

# 生成数据集,这里生成2000个样本(可根据实际情况调整数据量)
your_dataset = generate_dataset(2000)

# 模型、损失和优化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法来初始化模型参数,有助于缓解梯度消失和爆炸问题,提升训练效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 适当调整学习率,这里改为0.1,可根据实际情况进一步微调
optimizer = optim.Adam(model.parameters(), lr=0.1)

# 配置梯度累积步数
gradient_accumulation_steps = 4
global_step = 0

# 模拟训练循环
for epoch in range(20):  # 训练20个周期
    for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播(累积梯度)
        loss.backward()
        
        # 执行梯度更新
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            print(f"更新了模型参数,当前全局步数: {global_step}, 当前损失: {loss.item()}")

解释:

  • batch_size=8:每个梯度计算时,模型会处理 8 张图像。
  • gradient_accumulation_steps=4:表示每次参数更新前累积 4 次梯度。

因此:

  • 每个 step: 处理 8 张图像。
  • 累积 4 个 step: 共处理 8 × 4 = 32 8 \times 4 = 32 8×4=32 张图像。

关键点:

  • 一个 step: 是指一次前向和后向传播(不包含参数更新)。
  • 一次参数更新: 在累积 4 个 step 后,进行一次模型参数更新。

等效有效批次:

有效批次大小 = batch_size × gradient_accumulation_steps

即: 8 × 4 = 32 8 \times 4 = 32 8×4=32。

这意味着,即使显存有限,模型仍然能以有效批次大小 32 的方式进行训练

相关推荐
机器学习之心2 天前
SSA-Transformer-LSTM麻雀搜索算法优化组合模型分类预测结合SHAP分析!优化深度组合模型可解释分析,Matlab代码
分类·lstm·transformer·麻雀搜索算法优化·ssa-transformer
Rock_yzh3 天前
AI学习日记——Transformer的架构:编码器与解码器
人工智能·深度学习·神经网络·学习·transformer
yuluo_YX3 天前
语义模型 - 从 Transformer 到 Qwen
人工智能·深度学习·transformer
大千AI助手3 天前
Megatron-LM张量并行详解:原理、实现与应用
人工智能·大模型·llm·transformer·模型训练·megatron-lm张量并行·大千ai助手
Cathy Bryant3 天前
智能模型对齐(一致性)alignment
笔记·神经网络·机器学习·数学建模·transformer
知识搬运工人4 天前
传统卷积神经网络中的核心运算是卷积或者矩阵乘,请问transformer模型架构主要的计算
矩阵·cnn·transformer
跳跳糖炒酸奶4 天前
第九章、GPT1:Improving Language Understanding by Generative Pre-Training(理论部分)
transformer·解码器·gpt1
Yeats_Liao4 天前
华为开源自研AI框架昇思MindSpore应用案例:跑通Vision Transformer图像分类
人工智能·华为·transformer
AndrewHZ5 天前
【图像处理基石】图像Inpainting入门详解
图像处理·人工智能·深度学习·opencv·transformer·图像修复·inpainting
迪三达5 天前
GPT-0: Attention+Transformer+可视化
gpt·深度学习·transformer