新建模范式Mamba——“Selectivity is All You Need?”

目录

一、快速走进和理解Mamba建模架构

(一)从Transformer的统治地位谈起

(二)另一条道路:结构化状态空间模型(SSM)

[(三)Mamba 的核心创新:Selective SSM + 简洁架构](#(三)Mamba 的核心创新:Selective SSM + 简洁架构)

[1. 引入选择性机制(Selectivity)](#1. 引入选择性机制(Selectivity))

[2. 设计硬件友好的并行递归算法](#2. 设计硬件友好的并行递归算法)

[3. 极简神经网络架构](#3. 极简神经网络架构)

[二、State Space Models:结构化状态空间模型的前世今生](#二、State Space Models:结构化状态空间模型的前世今生)

[(一)从连续系统到离散建模:S4 的核心结构](#(一)从连续系统到离散建模:S4 的核心结构)

(二)离散化(Discretization):从连续动态到可微计算图

[(三)高维状态空间与 GPU 亲和性](#(三)高维状态空间与 GPU 亲和性)

[(四)SSM 家族谱:S4 是谁的"儿子"?谁又是它的"继承人"?](#(四)SSM 家族谱:S4 是谁的“儿子”?谁又是它的“继承人”?)

[三、选择性状态空间模型(Selective State Space Models)](#三、选择性状态空间模型(Selective State Space Models))

(一)动机:选择性是一种压缩机制

[(二)引入选择机制到 SSM 中](#(二)引入选择机制到 SSM 中)

[(三)高效实现选择性 SSM](#(三)高效实现选择性 SSM)

[(四)一个极简的 SSM 架构](#(四)一个极简的 SSM 架构)

[1. 与门控机制的关系:RNN 门控机制其实是一种选择机制的特例](#1. 与门控机制的关系:RNN 门控机制其实是一种选择机制的特例)

[2. 对选择机制的三种直觉解释](#2. 对选择机制的三种直觉解释)

[2.1 可变时间间隔(Variable Spacing)](#2.1 可变时间间隔(Variable Spacing))

[2.2 上下文过滤(Filtering Context)](#2.2 上下文过滤(Filtering Context))

[2.3 边界重置(Boundary Resetting)](#2.3 边界重置(Boundary Resetting))

[3. 选择机制中各参数的解释与扩展](#3. 选择机制中各参数的解释与扩展)

Δ(Delta)控制输入选择程度

[A 控制动力系统的衰减/更新速率](#A 控制动力系统的衰减/更新速率)

[B 和 C 提供更精细的状态输入输出控制](#B 和 C 提供更精细的状态输入输出控制)

[四、实证评估(Empirical Evaluation)](#四、实证评估(Empirical Evaluation))

[(一)合成任务验证选择能力(Synthetic Tasks)](#(一)合成任务验证选择能力(Synthetic Tasks))

[1. 选择性拷贝任务(Selective Copying)](#1. 选择性拷贝任务(Selective Copying))

[2. 归纳头任务(Induction Heads)](#2. 归纳头任务(Induction Heads))

[(二)语言建模(Language Modeling)](#(二)语言建模(Language Modeling))

[1. Scaling Law 实验](#1. Scaling Law 实验)

[2. Zero-shot 下游评估](#2. Zero-shot 下游评估)

[(三) DNA Sequence Modeling](#(三) DNA Sequence Modeling)

[(四)Audio Waveform Generation](#(四)Audio Waveform Generation)

[(五) 计算效率评估](#(五) 计算效率评估)

(六)架构与机制消融实验

[(七)✅ 总结:Mamba 的实证贡献](#(七)✅ 总结:Mamba 的实证贡献)

[五、总结与展望:Mamba 的意义与未来方向](#五、总结与展望:Mamba 的意义与未来方向)


干货分享,感谢您的阅读!

基础模型现在推动着深度学习中大部分令人兴奋的应用,几乎普遍基于Transformer架构及其核心注意力模块。为了应对Transformer在长序列上的计算低效,已经提出了许多子平方时间架构,如线性注意力、门控卷积和递归模型,以及结构化状态空间模型(SSM)。然而,这些模型在语言等重要模态上,表现往往不如注意力机制。我们发现,诸如上述模型的一个关键弱点是它们无法进行基于内容的推理,并提出了若干改进。首先,简单地让SSM参数成为输入的函数,解决了其在离散模态中的弱点,允许模型根据当前token有选择性地在序列长度维度上传播或遗忘信息。其次,尽管这一改变阻止了高效卷积的使用,但我们设计了一个硬件感知的并行算法,支持递归模式。我们将这些选择性SSM集成到一个简化的端到端神经网络架构中,该架构无需注意力或甚至MLP模块(Mamba)。Mamba在推理速度上具有显著优势(比Transformer高出5倍吞吐量),且在序列长度上呈线性扩展,其在实际数据上的表现对于百万长度的序列有所提升。作为一种通用的序列模型骨架,Mamba在语言、音频和基因组学等多个模态上都达到了最先进的性能。在语言建模任务上,我们的Mamba-3B模型超越了相同规模的Transformer,并且在预训练和下游评估中与规模为其两倍的Transformer相当。

1 引言

基础模型(FMs),即在大量数据上预训练然后调整用于下游任务的大型模型,已经成为现代机器学习中的一种有效范式。这些基础模型的骨干通常是序列模型,能够处理来自各种领域的任意输入序列,如语言、图像、语音、音频、时间序列和基因组学(Brown et al. 2020; Dosovitskiy et al. 2020; Ismail Fawaz et al. 2019; Oord et al. 2016; Poli et al. 2023; Sutskever, Vinyals, 和 Quoc V Le 2014)。尽管这一概念对模型架构的选择没有特定要求,但现代基础模型主要基于单一类型的序列模型:Transformer(Vaswani et al. 2017)及其核心注意力层(Bahdanau, Cho, 和 Bengio 2015)。自注意力机制的有效性归功于其能够在上下文窗口内密集地路由信息,从而使其能够建模复杂的数据。然而,这一特性带来了根本性的缺陷:无法建模窗口外的信息,并且在窗口长度上具有平方级的计算复杂度。大量的研究已经提出了更高效的注意力变体来克服这些缺点(Tay, Dehghani, Bahri, 等,2022),但往往以牺牲其有效性的某些特性为代价。到目前为止,尚未有这些变体在各个领域的规模上被证明具有实证效果。

最近,结构化状态空间序列模型(SSMs)(Gu, Goel, 和 Ré 2022; Gu, Johnson, Goel, 等,2021)作为一种有前景的序列建模架构出现。这些模型可以被解释为递归神经网络(RNNs)和卷积神经网络(CNNs)的结合,受到经典状态空间模型(Kalman 1960)的启发。这类模型能够非常高效地进行计算,无论是以递归方式还是卷积方式,都在序列长度上实现线性或接近线性的扩展。此外,它们具有建模长程依赖的原则性机制(Gu, Dao, 等,2020),并且在长程领域基准如长程竞技场(Long Range Arena)中占据主导地位(Tay, Dehghani, Abnar, 等,2021)。许多种类的SSM(Gu, Goel, 和 Ré 2022; Gu, Gupta, 等,2022; Gupta, Gu, 和 Berant 2022; Y. Li 等,2023; Ma 等,2023; Orvieto 等,2023; Smith, Warrington, 和 Linderman 2023)在处理连续信号数据如音频和视觉等领域取得了成功(Goel et al. 2022; Nguyen, Goel, 等,2022; Saon, Gupta, 和 Cui 2023)。然而,在处理离散和信息密集型数据如文本时,它们的效果较差。

我们提出了一类新的选择性状态空间模型,通过在多个方面改进,实现在序列长度上线性扩展的同时,具备Transformer的建模能力。

选择机制

首先,我们识别了先前模型的一个关键限制:能够根据输入有效地选择数据(即关注或忽略特定的输入)。基于在一些重要的合成任务(如选择性复制和归纳头)中的直觉,我们通过基于输入对SSM参数进行参数化来设计一个简单的选择机制。这使得模型能够过滤掉无关的信息,并永远记住相关信息。

硬件感知算法

这一简单的改变为模型的计算带来了技术挑战;事实上,所有先前的SSM模型必须是时间和输入不变的,以保持计算的高效性。我们通过一个硬件感知的算法来克服这一挑战,该算法通过扫描而非卷积递归计算模型,但不会实现扩展状态,以避免在GPU内存层次结构之间进行I/O访问。结果,新的实现比以前的方法更快,理论上(在序列长度上线性扩展,相较于所有基于卷积的SSM的伪线性扩展)和在现代硬件上(在A100 GPU上速度提高最多3倍)。

架构

我们通过将先前SSM架构的设计(Dao, Fu, Saab, 等,2023)与Transformer的MLP模块结合,简化了先前的深度序列模型架构,从而得到了一个简单且同质的架构设计(Mamba),并融入了选择性状态空间。

选择性SSM,以及由此扩展的Mamba架构,是完全递归的模型,具有一些使其适合作为操作序列的通用基础模型骨干的关键属性。

(i) 高质量:选择性使得其在像语言和基因组学等密集模态上表现出色。

(ii) 快速训练和推理:在训练过程中,计算和内存在序列长度上线性扩展,且在推理时每一步仅需常数时间,因为它不需要缓存之前的元素。

(iii) 长上下文:质量和效率的结合带来了在实际数据上的性能提升,序列长度可扩展到1M。

我们通过实验证明了Mamba作为通用序列基础模型骨干的潜力,展示了其在预训练质量和领域特定任务表现方面,在多个模态和设置下的优势:

合成任务。 在一些重要的合成任务,如复制和归纳头(这些任务被认为是大型语言模型的关键任务)上,Mamba不仅能够轻松解决它们,而且能够无限地推断解决方案(>1M tokens)。

音频和基因组学。 Mamba在音频波形和DNA序列建模任务上超越了先前的最先进模型,如SaShiMi、Hyena和Transformers,在预训练质量和下游评估指标(例如,在一个具有挑战性的语音生成数据集上,FID降低了一半以上)上均表现优异。在这两种设置中,其性能随着上下文长度的增加,能够处理百万长度的序列。

语言建模。 Mamba是首个真正实现Transformer级别性能的线性时间序列模型,无论是在预训练困惑度还是下游评估中。通过扩展到1B参数的规模,我们展示了Mamba超越了大量基准,包括基于LLaMa(Touvron et al. 2023)的现代Transformer训练方法。我们的Mamba语言模型相比相同规模的Transformer具有5倍的生成吞吐量,且Mamba-3B的质量与其两倍规模的Transformer相当(例如,在常识推理上比Pythia-3B高出4个点,甚至超越了Pythia-7B)。

模型代码和预训练检查点已开源,地址为:https://github.com/state-spaces/mamba。

2 状态空间模型

结构化状态空间序列模型(S4)是近年来为深度学习提出的一类序列模型,广泛与递归神经网络(RNNs)、卷积神经网络(CNNs)和经典的状态空间模型相关。它们的灵感来自于特定的连续系统(1),该系统通过一个隐式的潜在状态 ℎ(𝑡) ∈ ℝ^𝑁,将一维函数或序列 𝑥(𝑡) ∈ ℝ 映射到 𝑦(𝑡) ∈ ℝ。

具体而言,S4 模型通过四个参数(Δ, 𝑨, 𝑩, 𝑪)定义,这些参数在两个阶段中定义了一个序列到序列的转换:

离散化

第一阶段将"连续参数"(Δ, 𝑨, 𝑩)转化为"离散参数"(𝑨, 𝑩),通过固定的公式 𝑨 = 𝑓𝐴(Δ, 𝑨) 和 𝑩 = 𝑓𝐵(Δ, 𝑨, 𝑩),其中对(𝑓𝐴, 𝑓𝐵)的组合称为离散化规则。可以使用不同的规则,比如在方程(4)中定义的零阶保持(ZOH):

离散化与连续时间系统有深刻的联系,这使它们具备额外的属性,如分辨率不变性(Nguyen, Goel, 等,2022),并自动确保模型是适当归一化的(Gu, Johnson, Timalsina, 等,2023;Orvieto 等,2023)。它还与RNN的门控机制相关(Gu, Gulcehre, 等,2020;Tallec 和 Ollivier,2018),我们将在第3.5节中回顾这一点。然而,从机械的角度来看,离散化可以简单地视为在SSM的前向计算图中的第一步。

SSM的其他变体可以跳过离散化步骤,直接对(𝑨, 𝑩)进行参数化(Zhang 等,2023),这可能更容易理解。

计算

在参数从(Δ, 𝑨, 𝑩, 𝑪)转化为(𝑨, 𝑩, 𝑪)后,模型可以通过两种方式计算:要么作为线性递归(2),要么作为全局卷积(3)。

通常,模型在训练时使用卷积模式(3),以便高效并行训练(在该模式下,整个输入序列会一次性看到);而在推理时则切换到递归模式(2),以便高效的自回归推理(此时输入一次看到一个时间步)。

线性时间不变性(LTI)

方程(1)到(3)的一个重要性质是,模型的动态是时间不变的。换句话说,(Δ, 𝑨, 𝑩, 𝑪),因此(𝑨, 𝑩)在所有时间步长中都是固定的。这个属性称为线性时间不变性(LTI),它与递归和卷积深刻相关。非正式地,我们认为LTI SSM相当于任何线性递归(2a)或卷积(3b),并将LTI作为这些模型类别的总称。

到目前为止,所有结构化SSM都是LTI的(例如,作为卷积计算),因为基本的效率约束(在第3.3节中讨论)。然而,这项工作的核心见解是,LTI模型在建模某些类型的数据时存在根本性的局限性,我们的技术贡献在于移除LTI约束,同时克服效率瓶颈。

结构与维度

最后,我们注意到,结构化SSM被称为"结构化",是因为高效地计算它们还需要在𝑨矩阵上施加结构。最常见的结构形式是对角形式(Gu, Gupta, 等,2022;Gupta, Gu, 和 Berant,2022;Smith, Warrington, 和 Linderman,2023),我们也使用这种形式。

在这种情况下,𝑨 ∈ ℝ^(𝑁 × 𝑁),𝑩 ∈ ℝ^(𝑁 × 1),𝑪 ∈ ℝ^(1 × 𝑁)矩阵可以由𝑁个数字表示。为了对批量大小为𝐵、长度为𝐿、具有𝐷通道的输入序列𝑥进行操作,SSM会独立地应用于每个通道。请注意,在这种情况下,总的隐藏状态的维度是每个输入的𝐷𝑁,并且在序列长度上计算它需要𝑂(𝐵𝐿𝐷𝑁)的时间和内存;这是第3.3节中解决的基本效率瓶颈的根源。

一般状态空间模型

我们注意到,状态空间模型这一术语具有非常广泛的意义,简单地表示任何具有潜在状态的递归过程。它已经被用来指代不同学科中的许多不同概念,包括马尔科夫决策过程(MDP)(强化学习(Hafner et al. 2020))、动态因果建模(DCM)(计算神经科学(Friston, Harrison, 和 Penny 2003))、卡尔曼滤波器(控制(Kalman 1960))、隐马尔可夫模型(HMM)和线性动力学系统(LDS)(机器学习),以及大型递归(有时也包括卷积)模型(深度学习)。

在本文中,我们将"SSM"这一术语专门用于指代结构化SSM类或S4模型(Gu, Goel, 和 Ré,2022;Gu, Gupta, 等,2022;Gupta, Gu, 和 Berant,2022;Hasani 等,2023;Ma 等,2023;Smith, Warrington, 和 Linderman,2023),并将这些术语互换使用。为了方便起见,我们还可能包括这类模型的衍生物,例如那些侧重于线性递归或全局卷积视角的模型(Y. Li 等,2023;Orvieto 等,2023;Poli 等,2023),并在必要时澄清细微差别。

SSM架构

SSM是独立的序列变换,可以被集成到端到端的神经网络架构中(我们有时也称SSM架构为SSNNs,它们与CNN层的关系类似,作为SSM层的实现)。我们讨论一些最著名的SSM架构,其中许多也将作为我们的主要基准。

线性注意力 (Katharopoulos 等,2020)是自注意力的近似,它涉及到一个递归过程,可以视为一种退化的线性SSM。

H3 (Dao, Fu, Saab, 等,2023)将这种递归推广到使用S4;它可以被视为一种架构,其中一个SSM被两个门控连接夹在中间(图3)。H3还在主SSM层之前插入了一个标准的局部卷积,将其框定为一个移位SSM。

Hyena (Poli 等,2023)使用与H3相同的架构,但将S4层替换为一个MLP参数化的全局卷积(Romero 等,2021)。

RetNet (Y. Sun 等,2023)在架构中增加了一个附加门控,并使用了一个更简单的SSM,允许一种替代的并行计算路径,使用多头注意力(MHA)的变体替代卷积。

RWKV(B. Peng 等,2023)是最近为语言建模设计的RNN,基于另一种线性注意力近似------无注意力Transformer(S. Zhai 等,2021)。它的主要"WKV"机制涉及LTI递归,并且可以视为两个SSM的比值。

其他与SSM及其架构密切相关的内容将在扩展的相关工作部分(附录B)中进一步讨论。我们特别突出S5(Smith, Warrington, 和 Linderman,2023)、QRNN(Bradbury 等,2016)和SRU(Lei 等,2017),它们是与我们的核心选择性SSM最为相关的方法。

3 选择性状态空间模型

我们通过合成任务的直觉来激发选择机制(第3.1节),然后解释如何将该机制融入状态空间模型(第3.2节)。由此产生的时变SSM无法使用卷积,这提出了如何高效计算它们的技术挑战。我们通过利用现代硬件上的内存层次结构来克服这一挑战(第3.3节)。接着,我们描述了一种简单的SSM架构,不需要注意力机制或MLP模块(第3.4节)。最后,我们讨论了选择机制的其他一些特性(第3.5节)。

3.1 动机:选择作为压缩的一种方式

我们认为序列建模的一个根本问题是将上下文压缩到一个更小的状态中。事实上,我们可以从这个角度来看待流行序列模型的权衡。例如,注意力机制既有效又低效,因为它明确地不压缩上下文。这可以从自回归推理需要显式地存储整个上下文(即KV缓存)这一事实中看出,这直接导致了Transformer的线性时间推理和二次时间训练的慢速性。另一方面,递归模型之所以高效,是因为它们具有有限的状态,这意味着常数时间推理和线性时间训练。然而,它们的有效性受到该状态压缩上下文的能力的限制。

为了理解这个原理,我们集中讨论两个合成任务的例子(图2)。

  • 选择性复制任务:通过改变要记住的标记位置,修改了流行的复制任务(Arjovsky, Shah, 和 Bengio 2016)。它需要内容感知推理,能够记住相关标记(彩色的)并过滤掉不相关的标记(白色的)。

  • 归纳头任务:这是一个广为人知的机制,假设能够解释大多数大语言模型(LLM)的上下文学习能力(Olsson 等,2022)。它需要上下文感知推理,以便在适当的上下文中知道何时生成正确的输出(黑色的)。

这些任务揭示了LTI模型的失败模式。从递归的角度来看,它们的常数动态(例如方程(2)中的(𝑨, 𝑩)转移)无法让它们从上下文中选择正确的信息,或以输入依赖的方式影响沿序列传递的隐藏状态。从卷积的角度来看,已知全局卷积可以解决经典的复制任务(Romero 等,2021),因为它只需要时间感知,但它们在处理选择性复制任务时遇到困难,因为缺乏内容感知(图2)。更具体地说,输入到输出之间的间隔是变化的,而静态卷积核无法建模这种变化。

总之,序列模型的效率与有效性权衡由它们压缩状态的能力来表征:高效模型必须有一个较小的状态,而有效模型必须有一个包含上下文中所有必要信息的状态。因此,我们提出构建序列模型的一个根本原则是选择性:即有上下文感知的能力,能够聚焦或过滤输入到序列状态中的信息。特别地,选择机制控制信息沿序列维度的传播或交互(有关更多讨论,请参见第3.5节)。

3.2 通过选择机制改进SSM

将选择机制引入模型的一种方法是使其影响序列中交互的参数(例如RNN的递归动态或CNN的卷积核)依赖于输入。

算法1和算法2展示了我们使用的主要选择机制。其主要区别仅仅在于使几个参数Δ、𝑩、𝑪成为输入的函数,并且在整个过程中相应地改变了张量的形状。特别地,我们强调这些参数现在具有一个长度维度𝐿,这意味着模型从时间不变变为时间变化。(请注意,形状注解已在第2节中描述。)这使得其与卷积(3)之间的等价性丧失,从而对效率产生影响,接下来我们将讨论这一点。

我们特别选择𝑠𝐵(𝑥) = Linear𝑁(𝑥),𝑠𝐶(𝑥) = Linear𝑁(𝑥),𝑠Δ(𝑥) = Broadcast𝐷(Linear1(𝑥)),以及𝜏Δ = softplus,其中Linear𝑑表示参数化的投影到维度𝑑。选择𝑠Δ和𝜏Δ的原因是它们与第3.5节中解释的RNN门控机制的连接。

3.3 选择性SSM的高效实现

硬件友好的原语,如卷积(Krizhevsky, Sutskever, 和 Hinton 2012)和注意力机制(Bahdanau, Cho, 和 Bengio 2015;Vaswani 等,2017),已广泛应用。我们在此的目标是使选择性SSM在现代硬件(如GPU)上同样高效。选择机制是相当自然的,早期的工作尝试将选择的特殊情况融入其中,例如让Δ在递归SSM中随时间变化(Gu, Dao 等,2020)。然而,正如之前所提到的,SSM使用中的一个核心限制是其计算效率,这也是为什么S4及其所有衍生物使用LTI(非选择性)模型,通常以全局卷积的形式出现。

3.3.1 先前模型的动机

我们首先重新审视这一动机,并概述我们克服先前方法局限性的方式。

  • 从高层次来看,递归模型如SSM始终在表现力和速度之间进行权衡:正如第3.1节所讨论的,具有更大隐藏状态维度的模型应当更有效,但更慢。因此,我们希望在不增加速度和内存成本的情况下,最大化隐藏状态维度。

  • 请注意,递归模式比卷积模式更灵活,因为后者(3)是通过扩展前者(2)得出的(Gu, Goel, 和 Ré 2022;Gu, Johnson, Goel 等,2021)。然而,这将要求计算并物化潜在状态ℎ,其形状为(B, L, D, N),远大于输入𝑥和输出𝑦的形状(B, L, D),差距为𝑁(SSM状态维度)倍。因此,引入了更高效的卷积模式,该模式可以绕过状态计算并物化一个大小仅为(B, L, D)的卷积核(3a)。

  • 先前的LTI状态空间模型利用递归-卷积双重形式,通过一个因子𝑁(≈ 10 − 100)增加有效状态维度,比传统的RNN大得多,而不会产生效率损失。

3.3.2 选择性扫描概述:硬件感知的状态扩展

选择机制的设计旨在克服LTI模型的局限性;与此同时,我们需要重新审视SSM的计算问题。我们通过三种经典技术来解决这个问题:核融合、并行扫描和重计算。我们做出了两个主要观察:

  • 朴素的递归计算需要𝑂(𝐵𝐿𝐷𝑁) FLOP,而卷积计算需要𝑂(𝐵𝐿𝐷 log(𝐿)) FLOP,前者具有较低的常数因子。因此,对于长序列和不是特别大的状态维度𝑁,递归模式实际上可以使用更少的FLOP。

  • 两个挑战是递归的顺序性质和大的内存使用。为了应对后者,就像卷积模式一样,我们可以尝试不实际物化完整的状态ℎ。

主要思想是利用现代加速器(GPU)的特性,仅在内存层次结构中更高效的级别中物化状态ℎ。特别地,大多数操作(除了矩阵乘法)都受内存带宽的限制(Dao, Fu, Ermon 等,2022;Ivanov 等,2021;Williams, Waterman, 和 Patterson 2009)。这包括我们的扫描操作,我们使用核融合来减少内存I/O的数量,从而显著加速相较于标准实现。

具体来说,我们不在GPU HBM(高带宽内存)中准备扫描输入(𝑨, 𝑩),而是将SSM参数(Δ, 𝑨, 𝑩, 𝑪)直接从较慢的HBM加载到快速SRAM,在SRAM中执行离散化和递归操作,然后将最终的输出(大小为(B, L, D))写回到HBM中。

为了避免顺序递归,我们观察到,尽管递归不是线性的,它仍然可以使用高效的并行扫描算法(Blelloch 1990;Martin 和 Cundy 2018;Smith, Warrington, 和 Linderman 2023)进行并行化。

最后,我们还必须避免保存反向传播所需的中间状态。我们巧妙地应用了经典的重计算技术,以减少内存需求:中间状态不会被存储,而是在反向传播时,在将输入从HBM加载到SRAM时重新计算。因此,融合的选择性扫描层具有与优化后的Transformer实现(使用FlashAttention)相同的内存需求。

融合核和重计算的详细信息见附录D。完整的选择性SSM层和算法如图1所示。

3.4 简化的SSM架构

与结构化SSM一样,选择性SSM是独立的序列转换,可以灵活地集成到神经网络中。H3架构是最著名的SSM架构的基础(第2节),这些架构通常由一个受到线性注意力启发的模块和一个MLP(多层感知机)模块交替组成。我们通过将这两个组件合并成一个,简化了这一架构,并将其均匀地堆叠(图3)。这一灵感来源于门控注意力单元(GAU)(Hua 等,2022),其为注意力机制做了类似的处理。

这一架构通过可控的扩展因子𝐸扩展模型维度𝐷。对于每个模块,大多数参数(3𝐸𝐷²)位于线性投影中(输入投影为2𝐸𝐷²,输出投影为𝐸𝐷²),而内部SSM的贡献较小。与此相比,SSM参数的数量(Δ、𝑩、𝑪的投影以及矩阵𝑨)要小得多。我们重复该模块,并与标准的归一化和残差连接交替使用,形成Mamba架构。在实验中,我们始终固定𝐸 = 2,并使用两层该模块来匹配Transformer的交错MHA(多头注意力)和MLP模块的12𝐷²参数。我们使用SiLU / Swish激活函数(Hendrycks 和 Gimpel,2016;Ramachandran, Zoph 和 Quoc V Le,2017),这样可以使得门控MLP变为流行的"SwiGLU"变体(Chowdhery 等,2023;Dauphin 等,2017;Shazeer,2020;Touvron 等,2023)。

最后,我们还使用了一个可选的归一化层(我们选择LayerNorm(J. L. Ba, Kiros, 和 Hinton,2016)),其灵感来源于RetNet在类似位置使用归一化层的做法(Y. Sun 等,2023)。

3.5 选择机制的特性

选择机制是一个更广泛的概念,可以以不同的方式应用,例如应用于更传统的RNN或CNN,或者应用于不同的参数(例如算法2中的𝑨),或者使用不同的变换𝑠(𝑥)。

3.5.1 与门控机制的关系

我们强调最重要的关系:RNN的经典门控机制是我们为SSM设计的选择机制的一个实例。我们注意到,RNN门控与连续时间系统离散化之间的联系是非常明确的(Funahashi 和 Nakamura,1993;Tallec 和 Ollivier,2018)。实际上,定理1是对Gu, Johnson, Goel等(2021年,第3.1引理)的改进,扩展到了ZOH离散化和输入依赖门控(证明见附录C)。更广泛地说,SSM中的Δ可以看作是RNN门控机制的广义版本。根据以往的研究,我们采纳了将SSM离散化视为启发式门控机制的理论基础的观点。

定理 1. 当 𝑁 = 1,𝑨 = −1,𝑩 = 1,𝑠Δ = Linear(𝑥),且 𝜏Δ = softplus 时,选择性SSM的递归(算法2)可以表示为:

如第3.2节所述,我们对𝑠Δ、𝜏Δ的具体选择正是基于这一联系。特别地,请注意,如果给定的输入𝑥𝑡应该被完全忽略(如在合成任务中所需),所有𝐷个通道都应忽略该输入,因此我们会将输入投影到1维,然后再用Δ进行重复/广播。

3.5.2 选择机制的解释

我们阐述了选择机制的三个特定机制效应。

可变间隔

选择性允许过滤掉可能出现在感兴趣的输入之间的无关噪声标记。这在选择性复制任务中有所体现,但在常见的数据模态中普遍存在,特别是对于离散数据------例如,语言填充词如"呃"。这一特性出现的原因是,模型可以机械地过滤掉任何特定的输入𝑥𝑡,举例来说,在门控RNN的情况下(定理1)当𝑔𝑡 → 0时。

过滤上下文

有实验证明,尽管更多的上下文应该严格提升性能,但许多序列模型在增加上下文长度时并未改善性能(F. Shi等,2023)。一个解释是,许多序列模型在必要时无法有效地忽略无关的上下文;一个直观的例子是全局卷积(以及一般的LTI模型)。另一方面,选择性模型可以随时重置其状态,以去除多余的历史信息,因此从原则上讲,随着上下文长度的增加,它们的性能会单调改善(例如,第4.3.2节)。

边界重置

在多个独立序列拼接在一起的设置中,Transformer可以通过实例化特定的注意力掩码来保持它们的独立性,而LTI模型则会在序列之间泄漏信息。选择性SSM也可以在边界处重置其状态(例如,Δ𝑡 → ∞,或者定理1中当𝑔𝑡 → 1时)。这些设置可能是人为的(例如,为了提高硬件利用率而将文档打包在一起),也可能是自然的(例如,强化学习中的episode边界(Lu等,2023))。

此外,我们还详细阐述了每个选择性参数的效应。

Δ的解释

通常,Δ控制着在多大程度上聚焦或忽略当前的输入𝑥𝑡。它是RNN门控(例如定理1中的𝑔𝑡)的推广:从机械角度来看,较大的Δ重置状态ℎ并专注于当前输入𝑥,而较小的Δ则保持状态并忽略当前输入。SSM(1)和(2)可以被解释为一个通过时间步长Δ离散化的连续系统,在这个背景下,直觉是大Δ → ∞代表系统长时间关注当前输入(从而"选择"它并忘记其当前状态),而小Δ → 0则代表一个被忽略的瞬时输入。

𝑨的解释

我们指出,虽然𝑨参数也可以是选择性的,但它最终只通过与Δ的交互(通过𝑨 = exp(Δ𝑨)离散化(4))影响模型。因此,Δ的选择性足以确保(𝑨, 𝑩)的选择性,并且是主要的改进来源。我们假设使𝑨具有选择性,除了(或代替)Δ,也会有类似的性能,因此为了简便,我们没有包含这一部分。

𝑩和𝑪的解释

如第3.1节所讨论的,选择性最重要的特性是过滤掉无关信息,以便将序列模型的上下文压缩成一个高效的状态。在SSM中,将𝑩和𝑪修改为选择性可以更精细地控制是否将输入𝑥𝑡传递到状态ℎ𝑡中,或将状态传递到输出𝑦𝑡中。这可以解释为,模型根据内容(输入)和上下文(隐藏状态)分别调节递归动态。

3.6 其他模型细节

实数与复数

大多数先前的SSM使用复数作为状态ℎ的一部分,这是在许多感知模态(例如音频、视频)中实现良好性能所必需的(Gu, Goel, 和 Ré 2022)。然而,有实验证明,在某些设置中,完全实值的SSM似乎也能很好地工作,甚至表现得更好(Ma et al. 2023)。我们默认使用实数值,这在除一个任务外的所有任务中都表现良好;我们假设复数和实数的权衡与数据模态中的连续-离散谱有关,其中复数对于连续模态(例如音频、视频)有帮助,但对于离散模态(例如文本、DNA)则没有。

初始化

大多数先前的SSM还建议使用特殊的初始化,特别是在复数值的情况下,这在低数据的情况下有助于提高模型表现。我们在复数情况下的默认初始化是S4D-Lin,在实数情况下是S4D-Real(Gu, Gupta等,2022),这些初始化基于HIPPO理论(Gu, Dao等,2020)。这些初始化将𝑨的第𝑛个元素定义为 −1/2 + 𝑛𝑖 和 −(𝑛 + 1),分别适用于复数和实数情况。然而,我们预计在大数据和实值SSM的情况下,许多初始化方法都能正常工作;一些消融实验将在第4.6节讨论。

Δ的参数化

我们将Δ的选择性调整定义为𝑠Δ(𝑥) = Broadcast𝐷 (Linear1 (𝑥)),这基于Δ的机械原理(第3.5节)。我们观察到,这可以从维度1推广到更大的维度R。我们将其设置为D的一个小分数,这与块中的主要线性投影相比使用了极少量的参数。我们还注意到,广播操作可以被视为另一种线性投影,初始化为特定的1和0的模式;如果这个投影是可训练的,那么它会导致替代的𝑠Δ(𝑥) = Linear𝐷 (Linear𝑅 (𝑥)),这可以视为一个低秩投影。

在我们的实验中,Δ参数(可以视为偏置项)初始化为

遵循先前在SSM中的工作(Gu, Johnson, Timalsina等,2023)。

备注 3.1

为了简便起见,在我们的实验结果中,我们有时将选择性SSM简写为S6模型,因为它们是带有选择机制并通过扫描计算的S4模型。

4. 实证评估

在第4.1节中,我们测试了Mamba解决第3.1节中提出的两个人工任务的能力。随后,我们在三个领域进行了评估,每个领域都进行了自回归预训练以及下游任务的评估。

  • 第4.2节:语言模型预训练(规模法则),以及零-shot下游评估。

  • 第4.3节:DNA序列预训练,并在长序列分类任务上进行微调。

  • 第4.4节:音频波形预训练,以及自回归生成的语音片段的质量。

最后,第4.5节展示了Mamba在训练和推理时的计算效率,第4.6节则对架构和选择性SSM的各个组件进行了消融实验。

4.1 人工任务

这些任务的完整实验细节,包括任务细节和训练协议,见附录E.1。

4.1.1 选择性复制任务

复制任务是最为研究的序列建模任务之一,最初是为测试递归模型的记忆能力而设计的。如第3.1节所讨论,LTI(线性递归和全局卷积)SSM可以轻松解决这个任务,通过仅仅跟踪时间,而不是考虑数据内容;例如,通过构建一个恰到好处的卷积核(见图2)。这一点在早期的全局卷积工作中得到了明确验证(Romero等,2021)。选择性复制任务通过随机化标记之间的间距来防止这种简化方案。注意,这个任务之前已作为去噪任务(Jing等,2019)引入。

许多先前的研究认为,添加架构门控(乘性交互)可以赋予模型"数据依赖性"并解决相关任务(Dao, Fu, Saab等,2023;Poli等,2023)。然而,我们认为这个解释在直觉上是不充分的,因为这种门控机制并没有沿着序列轴进行交互,无法影响标记之间的间距。特别地,架构门控并不是选择机制的一个实例(见附录A)。

表1确认了,像H3和Mamba这样的门控架构只在一定程度上提高了性能,而选择性机制(将S4修改为S6)轻松地解决了这一任务,尤其是在与这些更强大的架构结合时。

4.1.2 诱导头任务

诱导头(Olsson等,2022)是一个简单的任务,从机制解释的角度(Elhage等,2021)来看,令人惊讶的是它能预测LLM的上下文学习能力。它要求模型执行联想回忆和复制:例如,如果模型在序列中看到一个二元组"Harry Potter",那么下次"Harry"出现在同一序列中时,模型应该能够通过历史记录预测"Potter"。

数据集

我们在诱导头任务上训练了一个2层模型,序列长度为256,词汇表大小为16,这与先前在该任务上的工作(Dao, Fu, Saab等,2023)相当,但序列长度更长。我们还通过在从2^6 = 64到2^20 = 1048576的多个序列长度上进行评估,研究了模型的泛化和外推能力。

模型

根据诱导头任务的既定工作,我们使用了2层模型,这使得注意力机制能够从机制上解决诱导头任务(Olsson等,2022)。我们测试了多头注意力(8个头,采用不同的位置信息编码)和SSM变体。我们为Mamba设置了模型维度𝐷为64,其他模型为128。

结果

表2显示,Mamba---或者更准确地说,它的选择性SSM层---能够完美地解决这个任务,因为它能够选择性地记住相关的标记,同时忽略其间的所有其他内容。它在序列长度达到百万级时(比训练时看到的序列长4000倍)完美地泛化,而没有其他方法能突破2倍的限制。

在所有为注意力模型设计的位置信息编码变体中,xPos(为长度外推设计)略优于其他变体;同时注意,由于内存限制,所有注意力模型的测试仅限于序列长度2^14 = 16384。与其他SSM相比,H3和Hyena表现相似,这与Poli等(2023)的研究发现相反。

4.2 语言建模

我们在标准的自回归语言建模任务上评估了Mamba架构,并与其他架构进行了比较,包括预训练指标(困惑度)和零-shot评估。我们将模型的大小(深度和宽度)设置为与GPT3的规格相匹配。我们使用了Pile数据集(L. Gao, Biderman等,2020),并遵循了Brown等(2020)中描述的训练方案。所有训练细节见附录E.2。

4.2.1 扩展规律

作为基准,我们将Mamba与标准的Transformer架构(GPT3架构)进行比较,并与我们所知的最强Transformer配方(以下简称Transformer++)进行比较,该配方基于PaLM和LLaMa架构(例如,旋转嵌入、SwiGLU MLP、RMSNorm代替LayerNorm、无线性偏置和更高的学习率)。我们还将Mamba与其他最近的次二次架构进行比较(见图4)。所有模型的详细信息见附录E.2。

图4展示了在标准Chinchilla(Hoffmann等,2022)协议下的扩展规律,模型参数从约125M到约1.3B不等。Mamba是首个在性能上与非常强大的Transformer配方(Transformer++)相匹配的无注意力模型,特别是在序列长度增长时。(我们注意到,由于缺乏高效的实现,RWKV和RetNet基准(之前的强递归模型,也可以解释为SSM)在上下文长度为8k时的完整结果缺失,导致内存不足或不现实的计算需求。)

4.2.2 下游评估

表3展示了Mamba在一系列流行的零-shot下游评估任务中的表现。我们与这些尺寸下最知名的开源模型进行了比较,最重要的是Pythia(Biderman等,2023)和RWKV(B. Peng等,2023),这些模型在与我们的模型相同的分词器、数据集和训练长度(300B tokens)上进行训练。(请注意,Mamba和Pythia的训练上下文长度为2048,而RWKV的训练上下文长度为1024。)

4.3 DNA建模

受大规模语言模型成功的启发,近期研究探索了将基础模型范式应用于基因组学。DNA被比作语言,因为它由具有有限词汇的离散符号序列组成。DNA建模还以其需要长程依赖性而著称(Avsec等,2021)。我们探讨了Mamba作为基础模型(FM)骨干进行预训练和微调,采用了与最近的DNA长序列模型相关的设置(Nguyen, Poli等,2023)。特别地,我们关注了跨模型大小和序列长度的扩展规律(见图5),以及一个需要长上下文的困难下游合成分类任务(见图6)。

在预训练方面,我们大体遵循标准的因果语言建模(下一个标记预测)设置,训练和模型细节请参见附录E.2。对于数据集,我们大致遵循HyenaDNA(Nguyen, Poli等,2023)的设置,该方法使用HG38数据集进行预训练,该数据集包含一个单一的人类基因组,约有45亿个标记(DNA碱基对)用于训练集。

4.3.1 扩展性:模型大小

在本实验中,我们研究了基因组基础模型在不同模型骨干下的扩展性(见图5左)。

训练

为了使基线模型受益,我们使用短序列长度1024进行训练;如4.3.2节所示,我们预计在更长的序列长度下,Mamba的表现将更加突出。我们固定了一个全局批量大小为1024,每个批次大约包含2^20 ≈ 1M个标记。模型训练了10K梯度步骤,总共使用了10B个标记。

结果

图5(左)显示,Mamba的预训练困惑度随着模型大小的增加而平稳下降,并且Mamba在扩展性上优于HyenaDNA和Transformer++。例如,在最大的模型大小约40M参数时,曲线显示Mamba能够以大约3到4倍更少的参数匹配Transformer++和HyenaDNA模型的表现。

4.3.2 扩展性:上下文长度

在接下来的DNA实验中,我们研究了模型在序列长度方面的扩展性。我们只比较HyenaDNA和Mamba模型,因为在较长的序列长度下,二次方注意力变得非常昂贵。我们在不同的序列长度上进行预训练,分别是2^10 = 1024、2^12 = 4096、2^14 = 16384、2^16 = 65536、2^18 = 262144、2^20 = 1048576。我们固定了一个模型大小为6层,每层宽度128(大约1.3M-1.4M参数)。模型训练了20K梯度步骤,总共使用了大约330B个标记。较长的序列长度采用了与(Nguyen, Poli等,2023)类似的序列长度预热策略。

结果

图5(右)显示,Mamba能够有效利用更长的上下文,甚至在极长的1M序列长度下,其预训练困惑度随着上下文的增加而逐渐下降。另一方面,HyenaDNA模型在序列长度增加时表现变差。这与我们在3.5节中讨论的选择机制属性是吻合的。特别地,LTI模型不能选择性地忽略信息;从卷积的角度来看,一个非常长的卷积核会聚合来自长序列的所有信息,而这些信息可能是非常嘈杂的。需要注意的是,尽管HyenaDNA声称随着更长上下文而表现更好,但他们的结果并没有控制计算时间。

4.3.3 合成物种分类

我们在一个下游任务上评估模型,该任务是通过随机抽取它们DNA的连续片段对5种不同物种进行分类。这个任务改编自HyenaDNA,后者使用了{人类、狐猴、老鼠、猪、河马}作为物种集合。我们通过将任务改为在五种大猩猩物种之间进行分类,增加了任务的难度,物种包括{人类、黑猩猩、猩猩、猩红猩猩、倭黑猩猩},这些物种已知DNA相似度高达99%。

4.4 音频建模与生成

对于音频波形模态,我们主要与SaShiMi架构和训练协议进行比较(Goel等,2022)。该模型包括:

  1. 一个U-Net骨架,具有两个池化阶段,每个阶段通过因子𝑝来加倍模型维度𝐷,

  2. 在每个阶段交替使用S4和MLP块。

我们考虑将S4+MLP块替换为Mamba块。实验的详细信息见附录E.4。

4.4.1 长上下文自回归预训练

我们在YouTubeMix数据集(DeepSound,2017)上评估了预训练质量(自回归下一个样本预测),该数据集是一个标准的钢琴音乐数据集,先前的工作也使用了该数据集,包含4小时的钢琴独奏音乐,采样率为16000 Hz。预训练的细节大体遵循标准语言建模设置(见4.2节)。图7评估了将训练序列长度从2^13 = 8192增加到2^20 ≈ 10^6的影响,同时保持计算量不变。(由于数据整理方式的边缘情况,可能会导致扩展曲线中出现弯曲。例如,只有一分钟长的片段可用,因此最大序列长度实际上由60秒·16000Hz = 960000个样本所限制。)

Mamba和SaShiMi(S4+MLP)基线在更长的上下文长度下均表现出持续改进;Mamba在整个过程中表现更好,且在较长序列长度下差距扩大。主要度量是每字节比特数(BPB),它是标准负对数似然(NLL)损失的常数因子log(2),用于预训练其他模态。

我们注意到一个重要细节:这是本文中唯一一个从实数参数化切换到复数参数化的实验(见3.6节)。我们在附录E.4中展示了其他消融实验。

4.4.2 自回归语音生成

SC09是一个基准语音生成数据集(Donahue,McAuley和Puckette,2019;Warden,2018),包含以16000 Hz采样的1秒片段,内容是数字"零"到"九",具有高度可变的特征。我们大体上遵循Goel等(2022)提出的自回归训练设置和生成协议。

表4展示了Mamba-UNet模型与Goel等(2022)提出的各种基线模型的自动化评估指标:WaveNet(Oord等,2016)、SampleRNN(Mehri等,2017)、WaveGAN(Donahue,McAuley和Puckette,2019)、DiffWave(Z. Kong等,2021)和SaShiMi。一小型Mamba模型超越了最先进的(且更大规模的)GAN和扩散模型。在与基线模型参数相匹配的更大模型中,生成的音频在保真度指标上得到了显著提升。

表5以小型Mamba模型为基础,研究了外部阶段和中心阶段不同架构的组合。结果显示,在外部块中,Mamba始终优于S4+MLP,而在中心块中,Mamba > S4+MLP > MHA+MLP。

4.5 速度和内存基准测试

我们基准测试了SSM扫描操作(状态扩展𝑁 = 16)的速度,以及Mamba的端到端推理吞吐量,结果见图8。

我们的高效SSM扫描在序列长度超过2K时,比我们所知的最佳注意力实现(FlashAttention-2(Dao,2024))更快,且比标准的PyTorch扫描实现快20-40倍。Mamba的推理吞吐量比同等规模的Transformer高出4-5倍,因为它无需KV缓存,能够使用更大的批次大小。例如,一个6.9B参数的Mamba(未经训练)在推理吞吐量上将超过一个5倍更小的1.3B参数Transformer。详细信息见附录E.5,其中还包括内存消耗的基准测试。

4.6 模型消融实验

我们对模型的各个组件进行了详细的消融实验,重点研究语言建模任务中的设置(模型规模约为350M,使用Chinchilla令牌计数,与图4相同的设置)。

4.6.1 架构

表6研究了架构(块)及其内部SSM层(图3)的效果。我们发现:

  • 在之前的非选择性(LTI)SSM中,这些模型等同于全局卷积,其性能非常相似。

  • 将之前工作的复数值S4变体替换为实数值变体对性能影响不大,这表明(至少对于语言建模任务)考虑硬件效率时,实数值SSM可能是更好的选择。

  • 用选择性SSM(S6)替换这些模型会显著提高性能,验证了第3节中的动机。

  • Mamba架构的表现与H3架构相似(当使用选择性层时,似乎略有更好表现)。

我们还研究了将Mamba块与其他块(如MLP(传统架构)、MHA(混合注意力架构))交替使用的情况,详细信息请参见附录E.2.2。

4.6.2 选择性SSM

表7通过考虑选择性Δ、𝑩和𝑪参数(算法2)的不同组合,消融了选择性SSM层,结果显示Δ是最重要的参数,因为它与RNN门控机制(定理1)有关。

表8考虑了SSM的不同初始化,这在某些数据模态和设置中已被证明能产生较大差异(Gu, Goel, 和 Ré 2022;Gu, Gupta 等人 2022)。在语言建模任务中,我们发现较简单的实数值对角初始化(S4D-Real,第3行)比更常见的复数值参数化(S4D-Lin,第1行)表现更好。随机初始化也表现良好,这与先前工作的发现一致(Mehta 等人 2023)。

表9和表10分别考虑了Δ和(𝑩,𝑪)投影的维度变化。从静态到选择性投影的变化提供了最大的益处,而进一步增加维度通常会略微提高性能,并伴随着小幅的参数量增加。

特别值得注意的是,当状态大小𝑁增加时,选择性SSM表现出显著改进,在仅增加1%的参数量的情况下,困惑度改善超过1.0。这验证了我们在第3.1节和第3.3节中的核心动机。

5 讨论

我们讨论相关工作、局限性以及一些未来的研究方向。

相关工作:附录A讨论了选择机制与类似概念的关系。附录B提供了关于SSM和其他相关模型的扩展相关工作。

没有免费午餐:连续-离散谱:结构化SSM最初被定义为连续系统(1)的离散化,并且在感知信号等连续时间数据模态(例如音频、视频)上具有很强的归纳偏置。如第3.1节和第3.5节所讨论,选择机制克服了它们在离散模态(如文本和DNA)中的弱点;但反过来,这可能会妨碍它们在LTI SSM擅长的数据上的表现。我们在音频波形上的消融实验更详细地研究了这一权衡。

下游适配能力:基于Transformer的基础模型(特别是LLM)具有丰富的属性和与预训练模型的交互模式,如微调、适应、提示、上下文学习、指令调优、RLHF、量化等。我们特别感兴趣的是,像SSM这样的Transformer替代品是否具备类似的属性和适配能力。

扩展:我们的实证评估局限于较小的模型规模,低于大多数强大的开源LLM(例如Llama(Touvron 等人 2023))以及其他递归模型如RWKV(B. Peng 等人 2023)和RetNet(Y. Sun 等人 2023)的规模,这些模型已在7B参数及以上的规模上进行评估。仍需评估Mamba在这些更大规模上的表现是否仍然具有优势。我们还注意到,扩展SSM可能涉及进一步的工程挑战和模型调整,这些内容在本文中没有讨论。

6 结论

我们向结构化状态空间模型引入了一种选择机制,使其能够在序列长度上线性扩展的同时执行依赖上下文的推理。当该机制被融入一个简单的无注意力架构时,Mamba在多种领域中达到了最先进的成果,甚至超越了强大的Transformer模型的性能。我们对选择性状态空间模型在不同领域中构建基础模型的广泛应用感到兴奋,尤其是在需要长上下文的创新模态(如基因组学、音频和视频)中。我们的结果表明,Mamba是成为通用序列模型骨干的强有力候选者。

致谢

我们感谢Karan Goel、Arjun Desai和Kush Bhatia对草稿的有益反馈。

备注:本文是对《Mamba: Linear-Time Sequence Modeling with Selective State Spaces》为代表的多篇论文的直接翻译和解读。