突破限制:ReLoRA技术让大型神经网络训练效率飞跃

概述

ReLoRA,即"High-Rank Training Through Low-Rank Updates",该篇文章基于LoRA,提出了一种通过低秩更新训练高秩网络的方法,从而能够降低训练模型所需要的算力。自2011年以来,随着深度学习技术的飞速发展,模型规模不断扩大,从百万参数到现在的数百亿参数。然而,这种规模的增长也导致计算成本不断升高,使得训练过程变得日益复杂。

ReLoRA提出了一种新的训练技术,旨在通过低秩更新来高效地训练高秩网络,不仅能够节省资源,还能使得模型性能和常规训练相同。

核心理念

ReLoRA的核心在于利用矩阵的秩性质,通过一系列低秩的更新来聚合成一个高秩的网络。这种方法的灵感来源于LoRA(Low-Rank Adaptation of Large Language Models),但ReLoRA进一步优化了训练过程,使其更加高效。

方法细节

ReLoRA的训练过程包括以下几个关键步骤:

  1. 全秩预热:在训练开始阶段,使用全秩的方式对网络进行热启动,以确保模型有一个良好的起点。(其实就是进行一定程度的预训练。)
  2. 低秩更新:在预热之后,ReLoRA通过应用低秩更新来逐步调整模型参数,这些更新是通过特定的矩阵分解技术实现的。
  3. 参数合并与重置:在训练过程中,定期将低秩参数合并回主网络,并重置优化器状态,这一步骤有助于模型探索更广阔的参数空间。
  4. 学习率调整:ReLoRA采用了一种特殊的学习率调度策略,即锯齿形余弦调度,以适应模型在不同阶段的训练需求。

1. 全秩预热(Full-Rank Warm Start)

在训练的初始阶段,ReLoRA首先使用全秩的方式对神经网络进行预热------------网络的所有参数都以常规的方式 进行初始化和训练。简而言之就是对于整个网络进行预训练:即所有的参数都是可训练的,并且使用标准的优化器(如Adam)进行更新。

  1. 预热周期:预热所需要的epoch数目并不是固定的,而是一个超参,实际的预热周期可能会根据模型的表现和收敛速度进行调整。
  2. 预热后的状态:在全秩预热阶段结束后,模型应该已经学习到了一些有效的表示,这为后续的低秩更新奠定了基础。在这个阶段,模型可能还没有达到最优性能,但已经足够成熟。
  3. 过渡到低秩更新:一旦完成了预热步骤,模型就会过渡到ReLoRA的低秩更新阶段。在这个阶段,模型的线性层会被转换为ReLoRA的形式,即通过低秩矩阵乘积来更新权重。

2. 低秩更新(Low-Rank Updates)

通过LoRA(Low-Rank Adaptation)技术,我们可以知道,一个模型中的权重更新值 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ W \Delta W </math>ΔW,通常是一个巨大的低秩矩阵 。而这就意味着:权重更新 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ W \Delta W </math>ΔW可以被分解为两个较小秩的矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W A W_A </math>WA和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W B W_B </math>WB的乘积,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ W = A B T , A ∈ R n × d , B ∈ N d × m , d ≪ n \Delta W=A B^T, A \in \mathbb{R}^{n \times d}, B \in \mathbb{N}^{d \times m}, d \ll n </math>ΔW=ABT,A∈Rn×d,B∈Nd×m,d≪n, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ W \Delta W </math>ΔW 这个参数矩阵的秩(Rank,lora_dim)。由于下游细分任务的域非常小,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d可以取得很小,很多时候我们可以取1。因此我们可以将原本的线性层转化为参数量非常小的矩阵乘积。

为大家举个直观的例子,我们假设原来的 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ W \Delta W </math>ΔW是1001024的参数矩阵,那么参数量为102400,LoRA模型将 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ W \Delta W </math>ΔW矩阵拆成了两个矩阵相乘,如果我们设置Rank=8,那么就是100 8的B矩阵与8*1024的A矩阵做矩阵乘法,参数量为800+8192=8992,整体参数量下降了约11.39倍

在低秩更新阶段, <math xmlns="http://www.w3.org/1998/Math/MathML"> W A W_A </math>WA和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W B W_B </math>WB是新增的可训练参数,它们在训练后可以合并回原始权重矩阵中。

3. 参数合并与重置(Parameter Merging and Reinitialization)

在训练过程中,ReLoRA会定期 将低秩更新参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> W A W_A </math>WA和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W B W_B </math>WB合并回网络的主参数中。这一步骤通过简单的矩阵加法来完成,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i = ( W i + W A i W B i ) W_i = (W_i + W_A^iW_B^i) </math>Wi=(Wi+WAiWBi),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i W_i </math>Wi是原始权重, <math xmlns="http://www.w3.org/1998/Math/MathML"> W A i W_A^i </math>WAi和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W B i W_B^i </math>WBi是低秩更新参数。合并后, <math xmlns="http://www.w3.org/1998/Math/MathML"> W A W_A </math>WA和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W B W_B </math>WB会被重新初始化, <math xmlns="http://www.w3.org/1998/Math/MathML"> W A W_A </math>WA使用Kaiming初始化,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> W B W_B </math>WB则初始化为零。这一过程有助于模型在保持之前学习到的知识的同时,继续探索新的参数空间。

简而言之,相对于LoRA,ReLoRA其实就是在训练过程中,周期性的采用LoRA的方式进行训练,将模型中的线性层分解为两个低秩矩阵,而后直接更新低秩矩阵,从而减少需要计算的参数量。而后周期性的合并低秩矩阵到线性层中,实现权重的彻底更新。

但是,我们在训练的时候,是需要记录优化器状态和调整学习率的,那么当我们重新初始化低秩矩阵后,我们该如何处理已计算的优化器状态和自动计算的学习率呢?

4. 优化器状态重置(Optimizer State Reset)

很多优化器,如Adam,其在训练时会保留关于模型梯度的累计信息。在合并参数后,为了确保模型不会受到之前训练步骤中累积的梯度信息的影响,ReLoRA 执行了优化器状态的部分重置。

优化器(如 Adam)维护了两组矩阵来跟踪梯度的一阶矩(mean)和二阶矩(unbiased variance)。在 ReLoRA 中,这两组矩阵通常被称为 M(一阶矩)和 V(二阶矩)。为了重置优化器,ReLoRA 执行了以下操作:

  • **基于阈值的剪枝:基于M 和 V 矩阵中元素的绝对值大小,保留较大值的梯度信息,丢弃较小的梯度信息,从而将 M 和 V 矩阵中的大部分元素设置为零。**这个过程可以通过以下伪代码表示:
scss 复制代码
FOR each matrix M and V in the optimizer state
    PRUNE(M, threshold) // 保留幅度大于阈值的元素,其余设置为零
    PRUNE(V, threshold)
  • 由于 M 和 V 矩阵的大部分元素被剪枝,优化器的状态实际上是被部分重置了。这样做有助于模型在新的参数空间中探索,而不是被旧的梯度信息所束缚。

5. 学习率调整(Learning Rate Scheduling)

在优化器状态重置之后,ReLoRA 会调整学习率,以帮助模型在新的参数配置下稳定训练。

ReLoRA采用了一种锯齿形余弦调度(jagged cosine scheduler)来调整学习率。这种调度策略在每次优化器重置后将学习率设为零,然后通过一个预热阶段(例如,50-100步)逐渐恢复到之前的学习率。而后通过余弦退火方法计算新的学习率。

如下图所示:

每一次将学习率设置为0之后,会快速把学习率拉回设置之前的值,而后再改用余弦退火方法。

余弦退火是一种常见的学习率调度方法,它模拟了余弦函数的周期性变化。在这种方法中,学习率从一个较高的初始值开始,随着训练的进行逐渐减小。余弦退火的基本公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> η t = 1 2 η 0 ( 1 + cos ⁡ ( T 2 π t ) ) \eta_t = \frac{1}{2} \eta_0 (1 + \cos(\frac{T}{2} \pi t)) </math>ηt=21η0(1+cos(2Tπt))

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> η t \eta_t </math>ηt 是在时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 的学习率, <math xmlns="http://www.w3.org/1998/Math/MathML"> η 0 \eta_0 </math>η0 是初始学习率, <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 是总的训练步数。

6. 局部低秩训练(Locally Low-Rank Training)

ReLoRA的实验分析表明,尽管预训练的神经网络在长期轨迹上表现出高秩更新,但对于足够小的轨迹,训练可以通过低秩更新有效近似。这意味着网络训练在局部上是低秩的,这一观察结果直接激励了ReLoRA的设计。

实验结果

实验设置

实验中,ReLoRA 被应用于不同规模的 Transformer 语言模型,包括 60M、130M、250M、350M 和 1.3B 参数的模型。所有实验都在没有数据重复的情况下进行单次训练(single epoch),并且在计算最优的数据量上进行训练,这是根据 Scaling Laws 估计的。

训练性能


实验结果显示,ReLoRA 在所有测试的模型规模上都显著优于 LoRA 训练。具体来说,ReLoRA 在 250M 模型上的性能甚至接近了全秩训练的性能。这表明 ReLoRA 能够有效地通过低秩更新来近似高秩训练的效果。

内存和速度提升


ReLoRA 在内存和训练速度方面都取得了显著的提升。例如,在 1.3B 参数的模型上,ReLoRA 节省了高达 5.5Gb 的 GPU RAM,并且根据模型大小和硬件设置,训练速度提高了 9-40%。这说明 ReLoRA 是一种高效的训练方法,可以在保持性能的同时减少资源消耗。

高秩训练的证据


通过对学习到的权重更新的奇异值谱进行分析,实验结果支持了 ReLoRA 能够执行高秩更新的观点。ReLoRA 的奇异值谱与全秩训练的更为相似,而与仅使用 LoRA 的有显著差异。这表明 ReLoRA 通过多次低秩更新,能够有效地实现高秩训练的效果。

结论

综合实验结果,ReLoRA 证明了其在训练大型神经网络时的有效性。它不仅能够节省大量的内存资源,还能在不同的硬件配置上提高训练速度。更重要的是,ReLoRA 能够在不牺牲模型性能的前提下,实现与全秩训练相似的结果。这些发现表明 ReLoRA 是一种有前景的高效训练方法,特别适合于资源受限或需要快速迭代的场景。

相关推荐
DuoRuaiMiFa31 分钟前
ChatGPT全新功能Canvas上线:开启智能编程与写作新篇章
人工智能·chatgpt
DisonTangor34 分钟前
Windows 11将新增基于AI的搜索、生成式填充和其它AI功能
人工智能
soso196836 分钟前
【AI自然语言处理应用】通过API调用通义晓蜜CCAI-对话分析AIO应用
人工智能·自然语言·ccai
网安-搬运工39 分钟前
RAG再总结之如何使大模型更好使用外部数据:四个不同层级及查询-文档对齐策略
人工智能·自然语言处理·大模型·llm·大语言模型·ai大模型·rag
大模型八哥40 分钟前
大模型扫盲系列——大模型实用技术介绍(上)
人工智能·程序人生·ai·大模型·llm·llama·ai大模型
被制作时长两年半的个人练习生1 小时前
【pytorch】权重为0的情况
人工智能·pytorch·深度学习
Elastic 中国社区官方博客1 小时前
使用 Vertex AI Gemini 模型和 Elasticsearch Playground 快速创建 RAG 应用程序
大数据·人工智能·elasticsearch·搜索引擎·全文检索
说私域2 小时前
地理定位营销与开源AI智能名片O2O商城小程序的融合与发展
人工智能·小程序
Q_w77422 小时前
计算机视觉小目标检测模型
人工智能·目标检测·计算机视觉