AdaFactor Optimizer 大模型训练优化器简介

文章目录

AdaFactor Optimizer 简介

AdaFactor 是一种用于训练深度学习模型的优化器,由谷歌在 2018 年提出,来自论文:

Noam M. Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with sublinear memory cost. ArXiv, abs/1804.04235, 2018.

它旨在解决传统自适应优化器(如 Adam)在训练大型模型时面临的一些问题,特别是内存消耗大和泛化能力可能受限的问题。

核心特点

  1. 降低内存消耗

    • 传统自适应优化器的问题:像 Adam 这样的优化器,会为每个模型参数维护两个一阶矩估计(均值)和二阶矩估计(未中心化的方差)的统计量。对于大型模型,参数数量庞大,这些统计量会占用大量内存。例如,一个拥有 10 亿参数的模型,使用 Adam 优化器时,仅存储这些统计量就需要消耗大量 GPU 内存。
    • AdaFactor 的改进 :AdaFactor 采用了一种低秩近似的方法来存储二阶矩估计。它将二阶矩估计矩阵分解为两个低秩矩阵的乘积,从而大大减少了内存占用。具体来说,它将原本需要存储的 n × m n \times m n×m 的矩阵( n n n 和 m m m 分别是参数矩阵的行数和列数)近似为两个较小的矩阵的乘积,使得内存消耗从 O ( n m ) O(nm) O(nm) 降低到 O ( n + m ) O(n + m) O(n+m)。
  2. 自适应学习率调整

    • 原理:AdaFactor 继承了自适应优化器的优点,能够根据参数的历史梯度信息自动调整学习率。它通过计算一阶矩估计和二阶矩估计来动态地缩放梯度,使得不同参数能够以合适的步长进行更新。
    • 优势:与固定学习率的优化器相比,AdaFactor 可以更快地收敛,并且在处理不同尺度的特征时更加稳定。例如,在训练神经网络时,不同层的参数可能具有不同的梯度尺度,AdaFactor 能够自动适应这些差异,提高训练效率。
  3. 避免学习率过早衰减

    • 传统优化器的不足:一些自适应优化器在训练过程中可能会出现学习率过早衰减的问题,导致模型在后期训练中收敛速度变慢,甚至无法达到更好的性能。
    • AdaFactor 的解决方案:AdaFactor 采用了一种基于相对变化的学习率调整策略。它通过比较当前梯度和历史梯度的相对变化来决定是否调整学习率,而不是简单地依赖固定的衰减策略。这样可以避免学习率过早衰减,使模型在训练后期仍然能够保持较好的学习效率。

数学原理

  • 一阶矩估计 :AdaFactor 计算梯度的一阶矩估计 m t m_t mt,类似于 Adam 中的做法,但使用了一种指数移动平均的方式进行更新:
    m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t - 1} + (1 - \beta_1)g_t mt=β1mt−1+(1−β1)gt

    其中, g t g_t gt 是当前时刻的梯度, β 1 \beta_1 β1 是一个超参数,控制着一阶矩估计的衰减速度。

  • 二阶矩估计的低秩近似 :AdaFactor 将二阶矩估计 V t V_t Vt 分解为两个低秩矩阵 R t R_t Rt 和 C t C_t Ct 的乘积,即 V t ≈ R t C t T V_t \approx R_t C_t^T Vt≈RtCtT。在更新过程中,分别对 R t R_t Rt 和 C t C_t Ct 进行更新:
    R t = β 2 R t − 1 + ( 1 − β 2 ) g t ⊙ g t ⋅ 1 max ⁡ ( 1 , row_norm ( R t − 1 ) ) R_t = \beta_2 R_{t - 1} + (1 - \beta_2)g_t \odot g_t \cdot \frac{1}{\max(1, \text{row\norm}(R{t - 1}))} Rt=β2Rt−1+(1−β2)gt⊙gt⋅max(1,row_norm(Rt−1))1
    C t = β 2 C t − 1 + ( 1 − β 2 ) g t ⊙ g t ⋅ 1 max ⁡ ( 1 , col_norm ( C t − 1 ) ) C_t = \beta_2 C_{t - 1} + (1 - \beta_2)g_t \odot g_t \cdot \frac{1}{\max(1, \text{col\norm}(C{t - 1}))} Ct=β2Ct−1+(1−β2)gt⊙gt⋅max(1,col_norm(Ct−1))1

    其中, ⊙ \odot ⊙ 表示逐元素相乘, β 2 \beta_2 β2 是另一个超参数,控制着二阶矩估计的衰减速度, row_norm \text{row\_norm} row_norm 和 col_norm \text{col\_norm} col_norm 分别表示对矩阵的行和列进行归一化操作。

  • 参数更新 :根据一阶矩估计和二阶矩估计的低秩近似,计算参数的更新量 Δ θ t \Delta \theta_t Δθt:
    Δ θ t = − m t V t + ϵ \Delta \theta_t = -\frac{m_t}{\sqrt{V_t} + \epsilon} Δθt=−Vt +ϵmt

    其中, ϵ \epsilon ϵ 是一个很小的常数,用于避免除以零。然后,使用更新量对参数进行更新:
    θ t + 1 = θ t + Δ θ t \theta_{t + 1} = \theta_t + \Delta \theta_t θt+1=θt+Δθt

实际应用

  • 大型语言模型训练:在训练像 BERT、GPT 这样的大型语言模型时,AdaFactor 可以显著减少内存消耗,使得在有限的硬件资源下能够训练更大的模型。例如,在训练一个拥有数十亿参数的语言模型时,使用 AdaFactor 优化器可以将内存占用降低数倍,从而允许在单个 GPU 或较少的 GPU 集群上完成训练。
  • 计算机视觉任务:在图像分类、目标检测等计算机视觉任务中,AdaFactor 也能够提高训练效率和模型性能。特别是在处理高分辨率图像和复杂模型结构时,其降低内存消耗的优势更加明显。

代码示例(PyTorch)

python 复制代码
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from transformers import AdaFactor, AdaFactorOptimizer  # 假设使用 transformers 库中的 AdaFactor

# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        return self.fc(x)

# 创建模型实例
model = SimpleModel()

# 定义输入数据和标签
inputs = torch.randn(32, 100)
labels = torch.randint(0, 10, (32,))

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 使用 AdaFactor 优化器
optimizer = AdaFactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False)

# 训练循环
for epoch in range(10):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f'Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}')

总结

AdaFactor 优化器通过低秩近似的方法降低了内存消耗,同时保留了自适应优化器的优点,能够自适应地调整学习率,避免学习率过早衰减。在训练大型深度学习模型时,AdaFactor 是一种非常有效的优化器选择,特别是在内存资源有限的情况下。它在大型语言模型和计算机视觉等领域都有广泛的应用前景。

相关推荐
晓枫-迷麟1 分钟前
【使用conda】安装pytorch
人工智能·pytorch·conda
爱补鱼的猫猫26 分钟前
Pytorch知识点2
人工智能·pytorch·python
deephub26 分钟前
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
人工智能·pytorch·python·深度学习·机器学习·正则化
小于不是小鱼呀40 分钟前
手撕 K-Means
人工智能·算法·机器学习
lilye661 小时前
精益数据分析(95/126):Socialight的定价转型启示——B2B商业模式的价格策略与利润优化
人工智能·数据挖掘·数据分析
Hero_HL1 小时前
Towards Open World Object Detection概述(论文)
人工智能·目标检测·计算机视觉
远方16091 小时前
10-Oracle 23 ai Vector Search 概述和参数
人工智能·oracle
fydw_7151 小时前
Celery 核心概念详解及示例
人工智能·机器学习
audyxiao0011 小时前
计算机视觉顶刊《International Journal of Computer Vision》2025年5月前沿热点可视化分析
图像处理·人工智能·opencv·目标检测·计算机视觉·大模型·视觉检测