模型训练中梯度累积步数(gradient_accumulation_steps)的作用

模型训练中梯度累积步数(gradient_accumulation_steps)的作用

flyfish

在使用训练大模型时,TrainingArguments有一个参数梯度累积步数(gradient_accumulation_steps)

py 复制代码
from transformers import TrainingArguments

梯度累积是一种在训练深度学习模型时用于处理内存限制问题的技术。在每次迭代中,模型的梯度是通过反向传播计算得到的,而梯度累积步数(gradient_accumulation_steps)指定了在执行实际的参数更新之前,要累积多少个小批次(mini - batch)的梯度。

以代码来说gradient_accumulation_steps的作用

py 复制代码
import torch
from torch import nn, optim

# 生成更合理的数据集,假设目标关系是y = 3 * x + 2 加上一些噪声
def generate_dataset(num_samples):
    inputs = torch.randn(num_samples, 10)
    # 根据线性关系生成标签,添加一些随机噪声模拟真实情况
    labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5
    return list(zip(inputs, labels))

# 生成数据集,这里生成2000个样本(可根据实际情况调整数据量)
your_dataset = generate_dataset(2000)

# 模型、损失和优化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法来初始化模型参数,有助于缓解梯度消失和爆炸问题,提升训练效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 适当调整学习率,这里改为0.1,可根据实际情况进一步微调
optimizer = optim.Adam(model.parameters(), lr=0.1)

# 配置梯度累积步数
gradient_accumulation_steps = 4
global_step = 0

# 模拟训练循环
for epoch in range(20):  # 训练20个周期
    for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):
        
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播(累积梯度)
        loss.backward()
        
        # 执行梯度更新
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            print(f"更新了模型参数,当前全局步数: {global_step}, 当前损失: {loss.item()}")

解释:

  • batch_size=8:每个梯度计算时,模型会处理 8 张图像。
  • gradient_accumulation_steps=4:表示每次参数更新前累积 4 次梯度。

因此:

  • 每个 step: 处理 8 张图像。
  • 累积 4 个 step: 共处理 8 × 4 = 32 8 \times 4 = 32 8×4=32 张图像。

关键点:

  • 一个 step: 是指一次前向和后向传播(不包含参数更新)。
  • 一次参数更新: 在累积 4 个 step 后,进行一次模型参数更新。

等效有效批次:

有效批次大小 = batch_size × gradient_accumulation_steps

即: 8 × 4 = 32 8 \times 4 = 32 8×4=32。

这意味着,即使显存有限,模型仍然能以有效批次大小 32 的方式进行训练

相关推荐
禾风wyh1 天前
【深度学习】深刻理解Swin Transformer
人工智能·深度学习·transformer
宝贝儿好1 天前
【NLP】第五章:注意力机制Attention
人工智能·python·深度学习·神经网络·自然语言处理·transformer
通信仿真实验室2 天前
Google BERT入门(5)Transformer通过位置编码学习位置
人工智能·深度学习·神经网络·自然语言处理·nlp·bert·transformer
知来者逆2 天前
Layer-Condensed KV——利用跨层注意(CLA)减少 KV 缓存中的内存保持 Transformer 1B 和 3B 参数模型的准确性
人工智能·深度学习·机器学习·transformer
Eshin_Ye2 天前
transformer学习笔记-自注意力机制(1)
笔记·学习·transformer·attention·注意力机制
耐心的等待52833 天前
【Transformer序列预测】Pytorch中构建Transformer对序列进行预测源代码
pytorch·深度学习·transformer
四代机您发多少3 天前
入门pytorch-Transformer
人工智能·pytorch·transformer
池央3 天前
GPT (Generative Pre-trained Transformer):开启自然语言处理新时代
gpt·自然语言处理·transformer
无水先生3 天前
机器学习中的 Transformer 简介(第 1 部分)
人工智能·机器学习·transformer