Mamba学习笔记(4)——Mamba核心

文章目录

  • [A Visual Guide to Mamba and State Space Models](#A Visual Guide to Mamba and State Space Models)
    • [第一部分:The Problem with Transformers](#第一部分:The Problem with Transformers)
    • [第二部分:The State Space Model(SSM)](#第二部分:The State Space Model(SSM))
      • [What is a State Space?](#What is a State Space?)
      • [What is a State Space Model?](#What is a State Space Model?)
    • [第三部分:Mamba - Selective SSM](#第三部分:Mamba - Selective SSM)
      • [What Problem does it attempt to Solve?](#What Problem does it attempt to Solve?)
      • [Selectively Retain Information](#Selectively Retain Information)
      • [The Scan Operation](#The Scan Operation)
      • [Hardware-aware Algorithm](#Hardware-aware Algorithm)
      • [The Mamba Block](#The Mamba Block)

👇👇👇👇👇💕💕💕💕💕👇👇👇👇👇
团队公众号 Deep Mapping 致力于为大家分享人工智能算法的基本原理与最新进展,探索深度学习与测绘行业的创新结合,展示团队的最新研究成果,并及时报道行业会议动态。每周更新,持续为你带来前沿科技资讯、实用技巧和深入解读。感谢每一位关注我们的读者,期待与大家一起见证更多精彩时刻!
👇👇👇👇👇💕💕💕💕💕👇👇👇👇👇

A Visual Guide to Mamba and State Space Models

Author:Maarten Grootendorst [Link]

第一部分:The Problem with Transformers

Transformer将任何文本输入视为由token组成的序列。

Transformer 的一个主要优点是,无论它收到什么输入,它都可以回顾序列中任何较早的token来得出其表示。

Transformer由两个结构组成,一组用于表示文本的编码器块和一组用于生成文本的解码器块

我们可以采用这种结构,仅使用解码器来创建生成模型。这个基于Transformer的模型,即Generative Pre-trained Transformers(GPT),使用解码器模块来完成一些输入文本。

一个解码器模块由两个主要组成部分构成:masked self-attentionfeed-forward neural network

它创建了一个矩阵,用于比较每个词元与之前所有词元的关系。矩阵中的权重由词元对彼此的相关性决定。

在训练的时候,这个矩阵是一次性创建的。在计算" name "和" is "之间的注意力之前,不需要先计算" My "和" name "之间的注意力。它实现了并行化,极大地加快了训练速度!

然而,这样做有一个缺陷,当生成下一个 token 时,我们需要重新计算整个序列的注意力,即使我们已经生成了一些 token。

生成长度为L的序列的标记大约需要 的计算,如果序列长度增加,则成本会很高。

需要重新计算整个序列是 Transformer 架构的一个主要瓶颈。让我们看看"经典"技术------RNN如何解决这个推理速度慢的问题。

RNN是一种基于序列的网络,它以序列中每个时间步的两个输入(即时间步t 的输入上一个时间步长 t-1的隐藏状态)生成下一个隐藏状态预测输出。RNN 具有循环机制,允许它们将信息从上一步传递到下一步。我们可以"展开"此可视化,使其更加明确。

生成输出时,RNN 仅需要考虑先前的隐藏状态当前输入。它可以避免重新计算所有先前的隐藏状态,而Transformer则无法避免。换句话说,RNN可以快速进行推理,因为它与序列长度成线性比例!理论上,它甚至可以具有无限的上下文长度。为了说明,我们将 RNN 应用于我们之前使用过的输入文本。

每个隐藏状态都是所有先前隐藏状态的集合,通常是一个compressed view

请注意,在生成名称" Maarten "时,最后一个隐藏状态不再包含有关单词" Hello "的信息。 RNN 往往会随着时间的推移而忘记信息,因为它们只考虑一个先前的状态。尽管 RNN 的训练和推理速度都很快,但它们缺乏 Transformer 模型所能提供的准确性。相反,我们研究State Space Models来有效地使用 RNN(有时使用卷积)。

第二部分:The State Space Model(SSM)

SSM 与 Transformer 和 RNN 类似,用于处理信息序列,例如文本和信号。在本节中,我们将介绍 SSM 的基础知识以及它们与文本数据的关系。

What is a State Space?

State Space是包含完整描述系统的最小变量数。它是一种通过定义系统可能的状态来以数学方式表示问题的方法。让我们稍微简化一下。想象一下,我们正在穿越迷宫。"State Space"是所有可能位置(状态)的地图。每个点代表迷宫中一个独特的位置,并带有具体的细节,比如你离出口有多远。

"state space representation"是此地图的简化描述。它显示您在哪里(当前状态)、您下一步可以去哪里(可能的未来状态)以及哪些变化会带您进入下一个状态(向右或向左)。

尽管State Space Models使用方程和矩阵来跟踪这种行为,但它只是一种跟踪您在哪里、您可以去哪里以及如何到达那里的方法。描述state的变量,在我们的例子中是 X 和 Y 坐标,以及到出口的距离,可以表示为"state vectors"。

听起来很熟悉?这是因为语言模型中的embeddingsvectors也经常用于描述输入序列的"state"。例如,当前位置的向量(state vector)可能看起来有点像这样:

在神经网络中,系统的"state"通常指的是其隐藏状态;而在大语言模型的上下文中,生成新词元(token)时,隐藏状态是最重要的方面之一。

What is a State Space Model?

SSM用于描述这些state表示,并根据输入预测系统的下一个state可能是什么。

通常,在时间步t,将输入序列𝑥(𝑡)映射到潜在状态表示ℎ(𝑡)并推导出预测的输出序列𝑦(𝑡)

  • 𝑥(𝑡)(例如,迷宫中向左和向下移动)
  • ℎ(𝑡)(例如,表示距出口的距离以及x/y坐标)
  • 𝑦(𝑡)(例如,再次向左移动以更快到达出口)

然而,与使用离散序列(如向左移动一次)不同,它将连续序列作为输入,并预测输出序列

SSM假设动态系统(例如在三维空间中移动的物体)可以通过其在时间tstate通过两个方程进行预测。

通过求解这些方程,我们假设可以揭示出统计原理,从而根据观察到的数据(输入序列和之前的状态)来预测系统的状态。[这些方程描述了如何通过当前状态(h(t))和输入(x(t))来预测系统的未来状态(h'(t))及输出(y(t))]

其目标是找到这样的状态表示h(t),使得我们能够从输入序列推导出输出序列。

这两个公式是SSM的核心
state equation描述了状态h(t)如何变化(through matrix A),以及输入x(t)如何影响状态h'(t) (through matrix B)

output equation描述了状态h(t)如何通过矩阵C转化为输出y(t),以及输入如何通过矩阵D影响输出y(t)

矩阵ABCD均为可学习参数矩阵。

通过可视化这两个方程,我们可以得到如下架构:

让我们一步步解析这一通用技术,了解这些矩阵如何影响学习过程。

假设我们有一个输入信号 x(t),该信号首先与矩阵B相乘,矩阵B描述了输入如何影响系统。

更新后的状态(类似于神经网络的隐藏状态)是一个潜在空间,包含了环境的核心"知识"。我们将状态与矩阵A相乘,矩阵A描述了所有内部状态之间的连接方式,反映了系统的底层动态变化。

你可能已经注意到,矩阵A在创建State Representation之前被应用,并在状态表示更新后也会继续更新。

接着,我们使用矩阵C来描述状态如何转换为输出。

最后,我们可以利用矩阵D来提供从输入直接到输出的信号,这通常也被称为跳跃连接(skip-connection)。这种连接允许输入直接影响输出,而不需要经过状态的中间处理步骤。

由于矩阵D类似于跳跃连接(skip-connection),因此SSM通常被认为是没有跳跃连接的如下形式。

回到我们简化的视角,现在我们可以将注意力集中在矩阵ABC上,它们构成了SSM的核心。

我们可以更新原始方程(并添加一些颜色标注),以表明每个矩阵的作用,和之前所做的一样。

这两个方程的目的是从观察到的数据中预测系统的状态。由于输入预计是连续的,因此SSM的主要表示形式是continuous-time representation

如果输入的是连续信号,解析上找到h(t)是很具有挑战性的。此外,由于我们通常输入的是离散信号(如文本序列),因此我们希望将模型离散化

为此,我们使用零阶保持(Zero-order hold, ZOH)技术。其工作原理如下:首先,每当接收到一个离散信号时,我们保持该信号的值,直到接收到新的离散信号为止。这个过程会生成一个SSM可以使用的连续信号

保持信号值的时间长度由一个新的可学习参数表示,称为步长(step size),通常记作Δ。这个参数决定了输入的分辨率。

现在我们已经为输入生成了连续信号,可以生成连续输出,并仅根据输入的时间步长对输出进行采样。

这些采样值就是我们的离散化输出!

从数学上讲,我们可以如下应用Zero-order hold

总之,它们使我们能够从连续 SSM转变为离散SSM,该SSM由一个公式表示,该公式不再是函数到函数,x(t) → y(t),而是序列到序列,x ₖ → y ₖ

这里,矩阵AB现在表示模型的离散参数。我们使用k而不是t来表示离散时间步长,以便在我们提到连续与离散SSM 时使其更加清晰。既然我们已经有了离散表示的公式,让我们探索一下如何实际计算该模型。

我们的离散化SSM使我们能够在特定的时间步长下而不是连续信号中来表述问题。递归方法在这里非常有用,正如我们之前在RNNs中看到的那样。如果我们考虑离散时间步而不是连续信号,我们可以用时间步来重新表述这个问题:

在每个时间步,我们计算当前输入 (B-xₖ) 如何影响之前的状态 (A-hₖ₋₁),然后计算预测的输出 (Chₖ)。

这个表示可能已经显得有些熟悉了!我们可以像之前在 RNN 中那样来处理它。

展开后可以表示为如下

注意我们如何可以使用这个离散化版本,运用 RNN 的基本方法。

我们可以使用的另一个 SSM Representation是卷积。请记住,在经典的图像识别任务中,我们应用过filters (kernels)来提取聚合特征:

由于我们处理的是文本而不是图像,因此我们需要采用一维的视角:

我们用来表示这个"filter"的(kernels)是从 SSM 公式中导出的:

让我们来探讨这个(kernels)在实际中的工作原理。就像卷积一样,我们可以使用SSM's kernels遍历每组词元,并计算输出:

这也说明了padding对输出可能产生的影响。我更改了padding顺序以改善可视化效果,我们通常在句子的末尾应用padding。

在下一步中,(kernels)会移动一次,执行计算的下一步:

在最后一步中,我们可以看到(kernels)的全部效果:

将 SSM 表示为卷积的一个主要好处是,它可以像CNN一样并行训练。然而,由于固定的核大小,它们的推理速度不如 RNN 快且无界。

这三种表示方式------连续、递归和卷积------各自都有不同的优缺点:

有趣的是,现在我们可以利用Recurrent SSM进行高效推理,并使用Convolutional SSM进行并行化训练

使用这些表示方式时,我们可以采用一个巧妙的技巧,即根据任务选择不同的表示。在训练时,我们使用可以并行化的卷积表示,而在推理时,我们使用高效的递归表示:

该模型被称为线性状态空间层(Linear State-Space Layer, LSSL)

这些表示具有一个重要的共同特性,即线性时不变性(Linear Time Invariance, LTI)。LTI 表明 SSM 的参数 ABC 在所有时间步内都是固定的。这意味着矩阵 A、B 和 C 对于 SSM 生成的每个词元都是相同的。换句话说,无论你给 SSM 提供什么序列,ABC 的值始终保持不变。我们得到的是一个与内容无关的静态表示。在探讨 Mamba 如何解决这一问题之前,让我们先研究这个谜题的最后一部分------矩阵A

可以说,SSM 公式中最重要的方面之一就是矩阵 A。正如我们之前在递归表示中看到的那样,它捕捉了前一个状态的信息,用来构建新的状态

本质上,矩阵 A 生成了隐藏状态

因此,创建矩阵 A 可能会影响记住少量先前词元与捕捉到迄今为止看到的每个词元之间的差异。特别是在递归表示的上下文中,因为它仅回顾前一个状态。那么,我们如何以保持较大记忆(上下文大小)的方式创建矩阵A呢?我们使用"饥饿的河马"(Hungry Hungry Hippo)或称HiPPO(高阶多项式投影算子)。HiPPO 尝试将迄今为止看到的所有输入信号压缩成一个系数向量。

它使用矩阵 A来构建一个状态表示,有效捕捉最近的词元并衰减较旧的词元。其公式可以表示如下:

假设我们有一个方阵 A,可以表示为:

使用HiPPO构建矩阵 A被证明比将其初始化为随机矩阵要好得多。因此,它能够更准确地重建较新的信号(最近的词元),而不是较旧的信号(最初的词元)。HiPPO 矩阵的核心思想是生成一个能够记忆其历史的隐藏状态。从数学上讲,它通过跟踪勒让德(Legendre polynomial)多项式系数来实现这一点,从而能够近似所有先前的历史。然后,HiPPO 被应用于我们之前看到的递归和卷积表示,以处理长程依赖性。结果是序列的结构状态空间(Structured State Space for Sequences, S4),这是一类能够有效处理长序列的SSM。[Link]

它由三个部分组成:

  • 状态空间模型(State Space Models, SSM)
  • HiPPO 用于处理长程依赖性(HiPPO for handling long-range dependencies)
  • 用于创建递归和卷积表示的离散化(Discretization for creating recurrent and convolution representations)

这类SSM根据选择的表示(递归或卷积)具有多个优点。它还可以处理长文本序列,并通过构建 HiPPO 矩阵有效地存储记忆。

第三部分:Mamba - Selective SSM

我们最终涵盖了理解 Mamba 特殊之处所需的所有基础知识。SSM可以用于建模文本序列,但仍然存在一系列我们希望避免的缺点。在本节中,我们将介绍 Mamba 的两个主要贡献:

  • 选择性扫描算法(selective scan algorithm),允许模型过滤(不)相关信息。
  • 硬件感知算法(hardware-aware algorithm),通过并行扫描、核融合和重新计算,允许高效存储(中间)结果。

这两者共同创造了selective SSM 或 S6 模型,可以像自注意力机制一样用于创建 Mamba块。在探讨这两个主要贡献之前,让我们先了解它们为何必要。

What Problem does it attempt to Solve?

SSM,甚至 S4(Structured State Space Model),在某些在语言建模和生成中至关重要的任务上表现不佳,即聚焦或忽略特定输入的能力。我们可以通过两个synthetic任务来说明这一点,即selective copyinginduction heads。在selective copying任务中,SSM 的目标是复制部分输入并按顺序输出:

然而,由于 SSM 是Linear Time Invariant的(无论是递归的还是卷积的),在这个任务中表现不佳。正如我们之前看到的,矩阵 ABC 对于 SSM 生成的每个词元都是相同的。因此,SSM 无法进行内容感知推理,因为由于固定的 A、B 和 C 矩阵,它将每个词元视为相同。这是一个问题,因为我们希望 SSM 能够推理输入(prompt)。

SSM 表现不佳的第二个任务是induction heads,其目标是reproduce patterns found in the input

在上述示例中,我们实际上是在执行one-shot prompting,试图"教会"模型在每个"Q:"之后提供"A:"的响应。然而,由于SSM 是时间不变的,它无法选择从历史中回忆哪些先前的词元。让我们通过关注矩阵B来说明这一点。无论输入x是什么,矩阵 B始终保持不变,因此独立于 x

同样,矩阵 AC 也在输入不变的情况下保持固定。这展示了SSM的静态特性

相比之下,这些任务对 Transformers 来说相对简单,因为它们根据输入序列动态地改变注意力。它们可以选择性地"查看"或"关注"序列的不同部分。SSM 在这些任务上的表现不佳,说明了时间不变 SSM 的根本问题:矩阵 ABC静态特性导致了内容感知的问题。

Selectively Retain Information

SSM 的递归表示创建了一个较小的状态,这种状态非常高效,因为它压缩了整个历史。然而,与不对历史进行压缩的 Transformer 模型(通过注意力矩阵)相比,它的能力要弱得多。Mamba 的目标是兼具两者的优势:一个状态,同时具备 Transformer 状态的强大能力。

如上所述,Mamba 通过将数据选择性地压缩到状态中来实现这一目标。当您有一个输入句子时,通常会有一些信息,例如停用词,并没有太大意义。为了选择性地压缩信息,我们需要使参数依赖于输入。为此,让我们首先探讨在训练过程中SSM 中输入和输出的维度

在S4中,矩阵 ABC输入无关,因为它们的维度 ND静态的,不会发生变化。

相反,Mamba 通过将输入的序列长度批大小纳入考虑,使矩阵 BC 甚至步长 依赖于输入:

这意味着对于每个输入词元,我们现在拥有不同的 BC 矩阵,这解决了内容感知的问题![注意:矩阵 A 保持不变,因为我们希望状态本身保持静态,但其受影响的方式(通过 B 和 C)是动态的。]

它们共同选择在隐藏状态保留什么,忽略什么,因为它们现在依赖于输入。较小的步长导致忽略特定的词,而更多地使用之前的上下文;而较大的步长则更关注输入词而非上下文:

The Scan Operation

由于这些矩阵现在是动态的,因此无法使用卷积表示进行计算,因为卷积表示假设使用固定的核。我们只能使用递归表示,这样就失去了卷积所提供的并行化。为了启用并行化,让我们探讨如何使用递归计算输出:

每个状态是前一个状态(乘以 A)当前输入(乘以 B)的和。这被称为scan operation ,可以通过 for 循环轻松计算。相反,由于每个状态只能在我们拥有前一个状态的情况下计算,因此并行化似乎是不可能的。然而,Mamba 通过并行扫描算法使这一切成为可能。它假设我们执行操作的顺序并不重要,这利用了结合性属性。因此,我们可以将序列分成几部分进行计算,并迭代地将它们组合在一起:

动态矩阵 B 和 C 以及并行扫描算法共同构成了选择性扫描算法,以表示使用递归表示的动态性快速性

Hardware-aware Algorithm

近期 GPU 的一个缺点是其小型但高效的 SRAM 与大型但稍微不那么高效的 DRAM 之间的传输(IO)速度有限。频繁在 SRAM 和 DRAM 之间复制信息会成为瓶颈。

在 GPU 中,SRAM(静态随机存取存储器)DRAM(动态随机存取存储器) 是两种不同类型的内存,各自具有不同的特性和用途。SRAM 通常用于缓存,提供快速的数据访问,而 DRAM 则用于存储大量的数据和程序,支持更复杂的计算任务。由于两者的特性,GPU 在处理任务时需要频繁地在 SRAM 和 DRAM 之间进行数据传输,这种传输速度的差异可能会导致性能瓶颈。

  • SRAM(静态随机存取存储器)

速度 :SRAM 的访问速度非常快,通常用于缓存和寄存器。
存储容量 :SRAM 通常容量较小,适合用于高速缓存(如 L1、L2、L3 缓存)。
功耗 :相对较低,但仍高于 DRAM。
结构:SRAM 使用多个晶体管构成一个存储单元,数据保持不需要周期性刷新。

  • DRAM(动态随机存取存储器)

速度 :虽然 DRAM 的速度比 SRAM 慢,但仍然非常快速,适合大规模数据存储。
存储容量 :DRAM 通常具有较大的存储容量,广泛用于主内存(如 GPU 的显存)。
功耗 :相对较高,特别是在频繁读写的情况下。
结构:DRAM 由一个晶体管和一个电容器构成,每个存储单元的数据需要定期刷新,以防止数据丢失。

Mamba 像 Flash Attention 一样,减少我们从 DRAM 到 SRAM 以及反向传输的次数。它通过核融合(kernel fusion)来实现这一点,允许模型防止写入中间结果,并持续进行计算,直到完成。

我们可以通过可视化 Mamba 的基础架构来查看 DRAM 和 SRAM 分配的具体实例:

在这里,以下内容被融合为一个内核:

  • 离散化步骤和步长∆
  • 选择性扫描算法
  • 与 C 的乘法

硬件感知算法的最后一部分是重新计算。中间状态不被保存,但在向后传递中计算梯度时是必需的。相反,作者在向后传递过程中重新计算这些中间状态。尽管这看起来可能效率不高,但与从相对较慢的 DRAM 中读取所有这些中间状态相比,成本要低得多。我们现在已经覆盖了其架构的所有组件,以下图像来自其文章:

该架构通常被称为selective SSM或 S6 模型,因为它本质上是使用选择性扫描算法计算的 S4 模型。

The Mamba Block

到目前为止,我们探讨的selective SSM可以作为一个模块实现,类似于我们在解码器模块中表示自注意力的方式。

与解码器类似,我们可以堆叠多个 Mamba 模块,并将它们的输出作为下一个 Mamba 模块的输入:

它首先进行线性投影(linear projection),以扩展input embeddings。然后,在Selective SSM 之前应用卷积,以防止独立的词元计算。

选择性 SSM 具有以下特性:

  • 通过离散化创建的递归 SSM
  • 在矩阵 A 上进行 HiPPO 初始化,以捕获长距离依赖关系
  • 选择性扫描算法,用于选择性压缩信息
  • 硬件感知算法,以加速计算

当我们查看代码实现时,可以对这个架构进行进一步扩展,并探索一个端到端示例的样子:

注意到一些变化,比如加入了归一化层和用于选择输出词元的softmax。当我们将所有部分组合在一起时,我们可以实现快速推理和训练,甚至获得无限的上下文。使用这种架构,作者发现它的表现与同规模的 Transformer 模型相当,有时甚至超越它们!

相关推荐
blackA_3 小时前
数据库MySQL学习——day4(更多查询操作与更新数据)
数据库·学习·mysql
梁下轻语的秋缘5 小时前
每日c/c++题 备战蓝桥杯(P1049 [NOIP 2001 普及组] 装箱问题)
c语言·c++·学习·蓝桥杯
刘婉晴5 小时前
【信息安全工程师备考笔记】第三章 密码学基本理论
笔记·安全·密码学
球求了5 小时前
C++:继承机制详解
开发语言·c++·学习
时光追逐者6 小时前
MongoDB从入门到实战之MongoDB快速入门(附带学习路线图)
数据库·学习·mongodb
一弓虽6 小时前
SpringBoot 学习
java·spring boot·后端·学习
晓数7 小时前
【硬核干货】JetBrains AI Assistant 干货笔记
人工智能·笔记·jetbrains·ai assistant
我的golang之路果然有问题8 小时前
速成GO访问sql,个人笔记
经验分享·笔记·后端·sql·golang·go·database
genggeng不会代码8 小时前
用于协同显著目标检测的小组协作学习 2021 GCoNet(总结)
学习
lwewan8 小时前
26考研——存储系统(3)
c语言·笔记·考研