深度学习新架构Mamba
论文介绍 Mamba: Linear-Time Sequence Modeling with Selective State Spaces
关注微信公众号: DeepGoAI
项目地址:github.com/state-space... (已经6.3k+)
本文介绍了一种新的序列模型架构,名为Mamba,它通过选择性状态空间模型(Selective State Space Models, SSMs)来改进传统的状态空间模型。Mamba通过输入依赖的方式调整SSM参数,允许模型根据当前的数据选择性地传递或遗忘信息,从而解决了以前模型在处理离散和信息密集型数据(如文本)时的不足。此外,尽管这种改变使得模型不能使用高效的卷积计算,研究者设计了一种硬件感知的并行算法,以递归模式运行,使得Mamba在推理速度上比传统的Transformer快5倍,并且在序列长度上实现线性缩放。在语言模型、音频和基因组数据模型等多个领域,Mamba都取得了最先进的性能。
上图提供了结构化状态空间模型(SSM) 如何处理输入数据的概述。具体来说,它说明了输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 的每个通道是如何通过更高维度的潜在状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 独立映射到输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 的。例如,如果输入有 5 个通道( <math xmlns="http://www.w3.org/1998/Math/MathML"> D = 5 D = 5 </math>D=5),则通过维度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = 4 N = 4 </math>N=4 的潜在状态处理它们。先前实现的 SSM 由于其在处理大有效状态时的效率而被注意到,这个大有效状态是维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N,以及批量大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B 和序列长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 的乘积。这种效率是通过避免这个大状态的实体化(materialization of the large state)并改用依赖时间不变性的替代计算路径来实现的------意味着参数( <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ \Delta </math>Δ, <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A, <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B, <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C)随时间保持不变。文本提到了一个"选择机制",在此上下文中引入了输入依赖的动态。这意味着参数现在可能会根据输入变化,引入了更动态且可能更准确的时间数据建模。然而,这也需要一个"谨慎的硬件感知算法",因为增加的复杂性和计算要求。为确保模型可以在特定的硬件配置上有效执行而做的优化或考虑,例如提到的 GPU 及其 SRAM 和 HBM(高带宽内存)。
方法概述
结构化状态空间序列模型(S4)
结构化状态空间序列模型(Structured State Space Sequence Models,简称S4)是一种新兴的深度学习序列模型,它与循环神经网络(RNNs)、卷积神经网络(CNNs)及经典的状态空间模型相关联。这些模型受到了一种将一维函数或序列映射到另一个一维函数或序列通过隐式潜在状态的特定连续系统的启发。
具体来说,S4模型使用四个参数( <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ A \Delta A </math>ΔA, <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B, <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C)定义序列到序列的转换过程,分为两个阶段:
第一阶段(离散化):将"连续参数"( <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ A \Delta A </math>ΔA, <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B)转换为"离散参数"( <math xmlns="http://www.w3.org/1998/Math/MathML"> A ˉ \bar{A} </math>Aˉ, <math xmlns="http://www.w3.org/1998/Math/MathML"> B ˉ \bar{B} </math>Bˉ)。
第二阶段:通过离散化后的参数计算序列转换,可以通过线性递归或全局卷积两种方式实现
S4模型 核心公式
首先描述S4模型的核心动态,即如何从一个时间步骤到下一个时间步骤更新潜在状态,并如何从这个潜在状态生成输出序列: 公式(1a): <math xmlns="http://www.w3.org/1998/Math/MathML"> h ′ ( t ) = A h ( t − 1 ) + B x ( t ) h'(t) = A h(t-1) + B x(t) </math>h′(t)=Ah(t−1)+Bx(t)
这个公式描述了潜在状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h ( t ) h(t) </math>h(t)的更新规则,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A是状态转换矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B是输入到状态的转换矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> x ( t ) x(t) </math>x(t)是当前时间步的输入。这表明当前的潜在状态是由前一个时间步的潜在状态和当前输入的组合决定的。
公式(1b): <math xmlns="http://www.w3.org/1998/Math/MathML"> y ( t ) = C h ( t ) y(t) = C h(t) </math>y(t)=Ch(t)
这个公式描述了如何从潜在状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h ( t ) h(t) </math>h(t)生成输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ( t ) y(t) </math>y(t),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C是状态到输出的转换矩阵。这表明输出直接依赖于当前的潜在状态。
进一步,作者提供这两个公式(1a)和(1b)的另一种表示形式,强调了模型的递归特性:
公式(2a): <math xmlns="http://www.w3.org/1998/Math/MathML"> h t = A ˉ h t − 1 + B ˉ x t h_t = \bar{A} h_{t-1} + \bar{B} x_t </math>ht=Aˉht−1+Bˉxt
这是对(1a)的简化表示,强调了潜在状态是通过递归方式更新的,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t − 1 h_{t-1} </math>ht−1代表前一时间步的潜在状态。
公式(2b): <math xmlns="http://www.w3.org/1998/Math/MathML"> y t = C h t y_t = C h_t </math>yt=Cht
这是对(1b)的简化表示,说明了输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y t y_t </math>yt是当前潜在状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht的直接函数。
以下两个公式提供了通过全局卷积实现序列到序列转换的视角:
公式(3a): <math xmlns="http://www.w3.org/1998/Math/MathML"> K ˉ = ( C B ˉ , C A B ˉ , ... , C A ˉ k B ˉ , ... ) \bar{K} = (C\bar{B}, C\bar{AB}, \ldots, C\bar{A}^{k}\bar{B}, \ldots) </math>Kˉ=(CBˉ,CABˉ,...,CAˉkBˉ,...)
公式(3b): <math xmlns="http://www.w3.org/1998/Math/MathML"> y = x ∗ K ˉ y = x \ast \bar{K} </math>y=x∗Kˉ
公式(3a)和(3b) 展示了如何通过状态空间模型将输入转换成输出。公式(3a)描述了通过状态转换矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> A ˉ \bar{A} </math>Aˉ对输入矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> B ˉ \bar{B} </math>Bˉ的连续应用,结合输出矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C,构造出核 <math xmlns="http://www.w3.org/1998/Math/MathML"> K ˉ \bar{K} </math>Kˉ。这个过程有效地建立了一系列应用于输入的转换。公式(3b)展示了输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x与核 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K的卷积,说明了如何综合变换后的输入产生输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y。这种公式化的目的是为了模拟动态系统在将输入转换为输出时的行为,捕捉时间依赖性和系统参数对输出的影响。
离散化
首先将"连续参数"( <math xmlns="http://www.w3.org/1998/Math/MathML"> ∆ ∆ </math>∆、 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A、 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B)转换为"离散参数"( <math xmlns="http://www.w3.org/1998/Math/MathML"> A ˉ \bar{A} </math>Aˉ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> B ˉ \bar{B} </math>Bˉ)的过程,这一转换通过固定的公式完成,例如通过零阶保持(ZOH)规则,可以得到: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ A = e Δ A , B = ( Δ A ) − 1 ( e Δ A − I ) ⋅ Δ B ] [A = e^{\Delta A}, \quad B = (\Delta A)^{-1}(e^{\Delta A} - I) \cdot \Delta B] </math>[A=eΔA,B=(ΔA)−1(eΔA−I)⋅ΔB]
这种离散化与连续时间系统有深刻的联系,它为模型赋予了额外的属性,如分辨率不变性和自动的模型正则化等。此外,它还与RNNs的门控机制有关,但从机械角度来看,离散化可以简单视为SSM在前向传播中的计算图的第一步。
计算
参数从( <math xmlns="http://www.w3.org/1998/Math/MathML"> ∆ ∆ </math>∆、 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A、 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B、 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C)转换为( <math xmlns="http://www.w3.org/1998/Math/MathML"> A ˉ \bar{A} </math>Aˉ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> B ˉ \bar{B} </math>Bˉ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C)后,模型可以以两种方式计算:线性递归或全局卷积。模型通常在训练时使用卷积模式(以便于并行化处理),在自回归推理时切换到递归模式(逐时间步处理输入)。
线性时间不变性(LTI)
这些方程的一个重要属性是模型的动态性质在时间上是恒定的,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∆ ∆ </math>∆、 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A、 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B、 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C(包括 <math xmlns="http://www.w3.org/1998/Math/MathML"> A ˉ \bar{A} </math>Aˉ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> B ˉ \bar{B} </math>Bˉ)对所有时间步都是固定的。这种性质称为线性时间不变性,与递归和卷积深度相关联。从直观上讲,LTI SSMs等价于任何线性递归或卷积,并使用LTI作为这些模型类别的统称。
然而,所有的结构化SSMs都是LTI的,即通过卷积计算,这是由于基本的效率约束。但本介绍的工作的一个核心洞见是,LTI模型在模拟某些类型的数据时存在根本性限制,作者的技术贡献包括在克服效率瓶颈的同时移除LTI约束。
结构和维度
高效计算结构化SSMs也需要对 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A矩阵施加结构,最常见的结构形式是对角线。在这种情况下, <math xmlns="http://www.w3.org/1998/Math/MathML"> A ∈ R N × N , B ∈ R N × 1 , C ∈ R l × N A \in \mathbb{R}^{N \times N}, \quad B \in \mathbb{R}^{N \times 1}, \quad C \in \mathbb{R}^{l \times N} </math>A∈RN×N,B∈RN×1,C∈Rl×N矩阵可以用数字表示。为了处理具有通道数的输入序列,SSM独立应用于每个通道。在这种情况下,总的隐藏状态具有每个输入的维度,其计算跨序列长度需要的时间和内存是效率瓶颈的根源。
Selective State Space Models
文中进一步提出一种改进的状态空间模型(SSM),这种模型通过引入选择机制来增强其对输入序列的处理能力。这个选择机制可能允许模型根据输入的相关性来动态调整其关注点,从而提高处理长序列数据的效率和准确性。
上图描述了几个不同的任务,用于展示时间不变模型如线性递归和全局卷积的局限性,以及选择性模型在处理这些限制方面的优势。这些任务包括:标准复制任务(左图):这个任务在输入和输出元素之间有恒定的间距,可以通过时间不变模型,如线性递归和全局卷积轻松解决。 选择性复制任务(右上图):这个任务在输入之间有随机间距,需要时间变化模型能够根据内容选择性地记住或忽略输入。 归纳头任务(右下图):这是一种基于上下文检索答案的关联回忆示例,对于大型语言模型来说是一个关键能力。
从中可以发现在设计序列模型时,引入能够根据输入序列的内容和上下文进行动态调整的机制的重要性。这种机制不仅提高了模型在处理复杂序列时的灵活性和效率,而且对于提升模型在执行基于上下文的高级认知任务方面的能力也是至关重要的。
算法1和算法2展示了作者使用的主要选择机制。它们的关键区别在于,使得一些参数( <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ \Delta </math>Δ, <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B, <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C)成为输入的函数,并伴随整个过程中张量形状的相关变化。具体而言,这些参数现在具有长度维度,意味着模型从时间不变(time-invariant)变为时间变化(time-varying)。这一变化失去了与卷积(公式3)的等效性,对其效率有所影响。
算法1(SSM(S4)):展示了一个结构化的状态空间模型(SSM)的构建过程,它不包括选择机制。这是一个时间不变模型,可以通过递归或卷积的方式实现。 算法2(SSM + 选择(S6)):在算法1的基础上添加了选择机制,通过使 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ \Delta </math>Δ, <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B, <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C成为输入的函数,引入了时间变化特性。这种机制允许模型根据输入动态调整其行为,使其能够更灵活地处理复杂序列数据。
通过引入选择机制,算法2能够针对每个时间步的特定输入调整模型参数,从而使模型能够根据上下文内容进行适应性变化。这种时间变化的特性使得模型能够在序列建模中更有效地压缩上下文信息,提高了处理长序列和复杂依赖关系的能力。
改进状态空间模型
通过使模型的关键参数(如 RNN 的循环动态或 CNN 的卷积核)成为输入的函数,可以实现选择机制的集成。这种方法使得模型能够根据输入动态调整其行为,进而在序列中更有效地传播或忽略信息。
具体来说,算法1和算法2展示了采用的主要选择机制。关键的变化在于,通过使某些参数( <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ \Delta </math>Δ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B、 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C)成为输入的函数,以及相关的张量尺寸变化,模型从时间不变转变为时间变化。这意味着,与使用固定卷积核的模型相比,该方法允许模型根据输入内容的不同动态调整其内部参数,从而在处理序列数据时提供更大的灵活性。然而,这也意味着与卷积操作的等效性丧失,对模型的效率产生影响。 作者特别选择将参数化投影到维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 的函数作为 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ B \Delta B </math>ΔB 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ C \Delta C </math>ΔC,因为它与 RNN 的门控机制相关。通过使模型能够基于输入内容调整其状态转换的动态性,选择机制提高了模型的灵活性和处理序列数据的能力。
这些描述展示了 SSMs 在架构设计方面的创新和多样性,每种方法都试图通过不同的技术改进来提高模型的性能和效率。这里介绍了它是如何通过在 SSM 之间引入门控连接和局部卷积来增强模型处理能力的。这些架构的共同点在于它们都试图通过结构化的方法来捕获和处理序列数据的依赖关系,同时提高计算效率。
总结
这篇文章主要围绕结构化状态空间模型(Structured State Space Models, SSMs)展开,贡献可以从以下几个方面进行总结:
-
提出新的架构和方法:文章通过介绍不同的 SSM 架构(如 H3、Hyena、RetNet)和方法(如线性注意力),展示了在处理序列数据和时间序列预测方面的创新。这些方法试图通过特定的设计来提高模型的计算效率和性能,特别是在处理大规模数据集时。
-
硬件感知算法:文章强调了为实现高效计算而采用的硬件感知算法的重要性。通过考虑特定硬件(如 GPU)的特性,文章探讨了如何优化算法来利用这些硬件的高性能计算能力。这种方法有助于加速模型的训练和推理过程,使得复杂的模型在实际应用中更加可行。
-
输入依赖的动态:与传统的 SSMs 不同,这篇文章通过引入选择机制来添加输入依赖的动态,从而允许模型捕获和利用输入数据中的时间变化信息。这种动态性的引入旨在提高模型对时间序列数据的理解和预测能力,使模型能够更好地适应数据的变化和发展趋势。
-
提高模型性能和效率:通过上述创新,文章展示了如何在保持或提高模型性能的同时,减少计算和存储需求。这对于处理大规模数据集和长时间序列特别重要,因为这些场景通常需要大量的计算资源和时间。
总的来说,这篇文章通过引入新的架构设计、硬件感知算法优化和输入依赖的动态,为提高结构化状态空间模型在序列数据处理和时间序列预测方面的性能和效率做出了贡献。这些创新使得 SSMs 在面对大规模和复杂数据时更加高效和有效,拓宽了它们在不同领域应用的可能性。
鉴于篇幅,实验和其他更多细节请参阅论文原文