Branch-Train-MiX: 可以大幅度降低训练成本的新型混合专家模型架构

在人工智能的广阔领域中,大型语言模型(LLMs)的研究始终是技术前沿的热门话题。随着技术的不断进步,我们对LLMs的期望也在不断提高,不仅希望它们能够处理日常的语言任务,还期待它们在专业领域如编程、数学推理和世界知识等方面展现出更高的能力。在这样的背景下,Meta的研究团队提出了一种名为Branch-Train-MiX(BTX)的创新方法,旨在高效地提升LLMs在多个专业领域的能力。

论文标题:Branch-Train-MiX: Mixing Expert LLMs into a Mixture-of-Experts LLM

机构:FAIR at Meta

论文链接:arxiv.org/pdf/2403.07...

1. Motivation

传统的分布式LLMs训练方法成本高昂、通信成本大、同步训练脆弱等问题。这些问题限制了模型的训练效率,阻碍了模型在特定领域的应用。本文所介绍的Branch-Train-MiX就是一个针对于垂直领域的LLMs的低代价训练方法。

2. 创新的方法

BTX方法的核心在于它的两阶段训练策略:

  • 首先是分支训练(Branch-Train),从一个预训练的基础模型出发,训练多个专家模型,这些专家模型是在不同的数据集上相互隔离的进行并行训练的;
  • 其次是混合专家(MiX),将这些专家模型的前馈参数混合到Mixture-of-Experts(MoE)层中,并通过MoE微调阶段学习token级别的路由。这种方法不仅提高了训练效率,还保持了模型的统一性,使得最终的模型可以像任何其他标准LLM一样进行微调或使用。

下面我将详细解释这些阶段:

2.1 Branch & Train: 并行异步专家训练

Branch(分支):

BTX方法从一个预训练的种子模型(seed model) 开始,更简单的说,我们首先要有一个预训练完的LLM。而后我们将这个种子模型复制多份,每个复制的模型都会训练为一个独立的专家模型(expert LLM)。这个过程称为"分支",因为它就像代码的分支一样,从一个相同的基础代码中,改动不同的部分,从而形成不同的版本。

Train(训练):

接下来,每个专家模型在各自的特定领域数据集上进行训练。这些数据集分别对应于不同的知识领域,比如我们将复制三个种子模型,然后分别在数学数据集,代码数据集,维基百科数据集上进行训练,每个数据集对应不同的领域,每个专家模型通过这种方式学习并专精于其对应领域的特定知识。

关键点是,这些专家模型是并行且异步 地训练的,即它们独立地在不同的计算资源上运行,相互之间无需任何同步。这种训练方式大幅度压低了通信成本,并提高了训练的吞吐量。

2.2 MiX: 结合独立专家成为混合专家模型

在所有专家模型异步训练完成后,BTX方法进入混合阶段。在这个阶段,专家模型的前馈子层(feedforward sublayers)被合并到一个统一的Mixture-of-Experts(MoE)模块中。这个MoE模块在模型的每一层都包含所有专家的前馈子层,并通过一个路由器网络(router network)动态选择在给定输入令牌时应使用哪个专家的前馈子层。

首先说明一下概念:

前馈子层(Feedforward Sublayers)

在深度学习的Transformer架构中,前馈子层(通常表示为FF层)是网络中的一个基本组件,它对输入数据进行非线性变换。具体来说,每个Transformer层包含一个自注意力子层(Attention Sublayer)和一个前馈子层。前馈子层通常由两个线性变换组成,中间夹着一个非线性激活函数(如ReLU或GELU)。在BTX方法中,每个专家模型的前馈子层被设计为处理特定领域的数据,这意味着它们学习到了与各自领域相关的特征表示。

Mixture-of-Experts(MoE)模块

MoE模块是一种特殊的网络结构,它允许模型在处理不同输入时动态地选择和激活最合适的专家子网络。MoE的核心思想是将大型网络分解为多个较小的专家网络,每个专家网络负责处理特定类型的输入数据。这种方法可以显著提高模型的计算效率,因为在任何给定时间,只有一小部分专家网络被激活,从而减少了无效计算。

路由器网络(Router Network)

路由器网络是MoE结构中的关键组件,负责根据输入数据的特征动态选择要激活的专家网络。在BTX方法中,路由器网络接收输入令牌(token),并输出一个路由概率分布,指示每个专家网络处理该令牌的相对重要性。路由器的输出通常通过softmax函数转换为概率分布,以确保选择的专家网络总数与MoE模块中的专家数量相匹配。

MiX(混合)过程:

在BTX方法的MiX阶段,研究者首先将所有专家模型的前馈子层逐层的合并到一个MoE模块中。这个MoE模块在每个Transformer层中包含所有专家的前馈子层,并通过路由器网络来决定对于每个输入令牌应该激活哪些专家的前馈子层。这种路由机制允许模型在处理特定任务时,利用在训练阶段学习到的领域特定知识。

具体来说,假设我们输入了一个token <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x,每个专家的前馈子层都会计算FFl i(x),但是只有那些被路由器网络选中的专家子层才会被实际用于计算最终输出。这通过引入一个路由函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g和一个线性变换 <math xmlns="http://www.w3.org/1998/Math/MathML"> W l W_l </math>Wl来实现,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> F F M o E l ( x ) = ∑ i = 1 N g i ( W l x ) F F i l ( x ) \mathrm{FF}{\mathrm{MoE}}^l(x)=\sum{i=1}^N g_i\left(W_l x\right) \mathrm{FF}_i^l(x) </math>FFMoEl(x)=∑i=1Ngi(Wlx)FFil(x)

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> F F i l ( x ) FF_i^{l}(x) </math>FFil(x) 表示第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个领域专家模型在第 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l 层的前馈子层输出。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> W l W_l </math>Wl是第 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l 层的线性变换矩阵。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> g i g_i </math>gi是第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个专家的路由函数。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 是输入向量。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 是专家的总数。

这里的 <math xmlns="http://www.w3.org/1998/Math/MathML"> g i ( W l x ) g_i\left(W_l x\right) </math>gi(Wlx)是一个稀疏的路由函数,它决定了每个token应该由哪些专家模型处理。通常使用Top-k路由,其中k=2,意味着对于每个token,只会选择两个最相关的专家子层进行计算。

除了前馈子层外,模型的其他部分(如自注意力层和嵌入层)通过简单地平均各个专家模型的权重来合并。这个过程不引入新的参数,确保了模型的统一性和进一步微调的能力。

简而言之,所谓的混合过程就是:将自注意力层和嵌入层通过平均值进行合并,而前馈层则在每一层都添加一个路由函数,模型在中间计算时在每一层都可能会选择不同的前馈子层进行计算。

MoE-finetuning(MoE微调):

在混合阶段之后,模型会进行MoE微调,这是一个继续训练的过程,旨在优化路由器的选择策略,并调整模型权重以提高整体性能。这一阶段,模型在所有用于训练专家模型的数据上进行微调,使得路由器能够学习如何在不同领域间有效地混合专家的前馈子层。

4. 实验分析

4.1 实验设置

实验基于Llama-2 7B模型。研究者们创建了三个专家模型,分别在数学、编程和维基百科数据集上进行训练。此外,还包括了原始的Llama-2 7B模型作为一个"通才"专家。这些专家模型被合并到一个MoE模型中,并在所有数据源上进行了进一步的微调。

4.2 主要结果

4.2.1 整体性能

  • 专家模型的专业化:实验结果显示,每个专家模型在其专长的领域内表现出色。例如,数学专家在GSM8K和MATH等数学任务上取得了显著的进步,而编程专家在HumanEval和MBPP等编程任务上表现优异。
  • BTX模型的全面提升:BTX模型在所有领域任务上都有所提升,尤其是在数学和编程任务上,接近或超过了专门化模型的性能。这表明BTX方法能够有效地结合专家模型的专业化知识,同时保持或提升在其他任务上的性能。
  • 与基线的比较:BTX模型与其他基线模型相比,包括Llama-2 7B、Dense模型、Sparse upcycling和Branch-Train-Merge (BTM) 方法,展现了更好的性能。特别是在数学和编程领域,BTX模型的性能显著优于BTM模型,这表明MoE微调阶段学习令牌级路由的好处。

4.2.2 计算效率

  • 与Dense和BTM的比较:尽管BTX在MoE训练阶段使用了较少的训练预算,但其在一般能力上的提升与Dense和BTM相比更为显著。
  • 与Sparse upcycling的比较:作为一种特殊的BTX(没有专家训练阶段),Sparse upcycling在相同的或更大的计算预算下表现不如BTX,这表明专家并行训练的计算效率优势。

4.3 消融研究和分析

4.3.1 BTX训练的消融研究

  • 路由方法:研究了不同的路由方法,包括Switch、Sample Top-1和Top-2,以及是否使用负载平衡(load balancing)。结果表明,Top-2路由与负载平衡相结合在大多数任务上表现最佳。
  • 训练策略:测试了不同的BTX训练变体,如不使用负载平衡、冻结专家的前馈模块等。这些消融实验有助于理解不同设计选择对模型性能的影响。

4.3.2 路由分析

  • 路由决策:分析了不同领域任务的令牌路由决策,发现Top-2路由与负载平衡确保了专家间的负载分布更均匀。
  • 专家参与度:在数学和编程任务中,GSM8K任务倾向于选择编程和通才专家,而MATH任务更多依赖于数学专家。这表明不同任务根据其内容和难度选择了最合适的专家。

5 结论

本文提出的Branch-Train-MiX (BTX) 方法的主要目的是为了解决大型语言模型(LLMs)在多个专业领域内进行高效训练的问题。具体来说,BTX旨在实现以下几点:

  1. 提升多领域能力:通过训练专家模型来增强LLMs在特定领域(如编程、数学推理、世界知识等)的性能。
  2. 降低通信成本:通过并行异步训练专家模型,减少在分布式训练中保持多个模型副本同步所需的通信开销。
  3. 提高训练效率:并行训练允许模型利用更多的计算资源,从而加快训练速度。
  4. 增强模型通用性:通过MoE结构整合专家模型,保持一个统一的模型,以便进行进一步的微调和应用。
  5. 保持计算效率:尽管BTX模型拥有更多的参数,但它通过稀疏激活的方式保持了较低的推理计算成本。

相对于其他方法,BTX的优势包括:

  1. 专家训练的并行性:与同步训练相比,BTX的分支训练阶段可以异步进行,这意味着每个专家模型可以独立于其他模型训练,从而提高了训练过程的鲁棒性和效率。
  2. 细粒度的MoE混合:BTX不是简单地混合专家模型的最终输出,而是在Transformer的每个层内进行更细粒度的混合,这允许模型在更深层次上利用专家知识。
  3. 路由机制的优化:BTX使用路由器网络来动态选择最合适的专家前馈子层,这种基于令牌的路由方法比传统的基于任务的路由方法更加灵活和高效。
  4. 统一的微调能力:与Branch-Train-Merge (BTM) 方法相比,BTX通过MoE微调阶段学习令牌级路由,最终得到一个可以像任何标准LLM一样进行微调的统一模型。
  5. 性能与效率的平衡:BTX在保持较低推理计算成本的同时,实现了在多个专业领域内的性能提升,这使得它在计算效率和任务性能之间取得了良好的平衡。
  6. 减少过拟合风险:由于BTX方法在多个领域上训练专家模型,这有助于模型学习到更泛化的特征表示,从而减少在特定领域数据上过拟合的风险。
相关推荐
机器之心11 分钟前
出海应用也能享受高速稳定的 DeepSeek-R1?亚马逊云科技出手了
人工智能·openai
机器之心13 分钟前
一家高校实验室,走出 12 家明星 AI 初创公司!Pieter Abbeel:我的 NB 学生们
人工智能·openai
阿正的梦工坊14 分钟前
卷积神经网络(CNN):深度解析其原理与特性
人工智能·神经网络·cnn
nenchoumi311920 分钟前
AutoGen学习笔记系列(十七)Examples - Literature Review
人工智能·笔记·python·学习·语言模型
机器之心23 分钟前
FP8 模型不再挑卡!DeepSeek 推理成本减半速度翻番,清华团队开源「赤兔」推理引擎
人工智能·openai
智联视频超融合平台1 小时前
网络视频监控平台在医疗领域的应用
网络·人工智能·音视频·健康医疗·视频编解码
嘻嘻哈哈开森1 小时前
AI 代理框架深度对比分享:Agno、OpenManus 和 OWL
人工智能·架构
WenGyyyL1 小时前
使用OpenCV和MediaPipe库——抽烟检测(姿态监控)
人工智能·opencv·计算机视觉
赛卡1 小时前
自动驾驶中间件技术辨析:ROS、Apex.Grace、DDS、AutoSAR和AutoSAR Adaptive
人工智能·中间件·自动驾驶
大模型真好玩1 小时前
大模型私人定制:5分钟教你不写一行代码微调构建属于你的大模型(使用llama-factory微调Qwen大模型)
人工智能·deepseek