挑战Transformer的Mamba是什么来头?作者博士论文理清SSM进化路径

对 SSM 感兴趣的研究者不妨读一下这篇博士论文。

在大模型领域,Transformer 凭一己之力撑起了整个江山。但随着模型规模的扩展和需要处理的序列不断变长,Transformer 的局限性也逐渐凸显,比如其自注意力机制的计算量会随着上下文长度的增加呈平方级增长。为了克服这些缺陷,研究者们开发出了很多注意力机制的高效变体,但收效甚微。

最近,一项名为「Mamba」的研究似乎打破了这一局面,它在语言建模方面可以媲美甚至击败 Transformer。这都要归功于作者提出的一种新架构 ------ 选择性状态空间模型( selective state space model),该架构是 Mamba 论文作者 Albert Gu 此前主导研发的 S4 架构(Structured State Spaces for Sequence Modeling )的一个简单泛化。

在 Mamba 论文发布后,很多研究者都对 SSM(state space model)、S4 等相关研究产生了好奇。其中,有位研究者表示自己要在飞机上把这些论文都读一下。对此,Albert Gu 给出了更好的建议:他的博士论文其实把这些进展都梳理了一下,读起来可能更有条理。

在论文摘要中,作者写到,序列模型是深度学习模型的支柱,已在科学应用领域取得了广泛成功。然而,现有的方法需要针对不同的任务、模态和能力进行广泛的专业化;存在计算效率瓶颈;难以对更复杂的序列数据(如涉及长依赖关系时)进行建模。因此,继续开发对一般序列进行建模的原则性和实用性方法仍然具有根本性的重要意义。

论文链接:stacks.stanford.edu/file/druid:...

作者在论文中阐述了一种使用状态空间模型进行深度序列建模的新方法,这是一种灵活的方法,具有理论基础,计算效率高,并能在各种数据模态和应用中取得强大的结果。

首先,作者介绍了一类具有众多表征和属性的模型,它们概括了标准深度序列模型(如循环神经网络和卷积神经网络)的优势。然而,作者表明计算这些模型可能具有挑战性,并开发了在当前硬件上运行非常快速的新型结构化状态空间,无论是在扩展到长序列时还是在自回归推理等其他设置中都是如此。最后,他们提出了一个用于对连续信号进行增量建模的新颖数学框架,该框架可与状态空间模型相结合,为其赋予原则性的状态表示,并提高其对长程依赖关系的建模能力。总之,这一类新方法为机器学习模型提供了有效而多用途的构建模块,特别是在大规模处理通用序列数据方面。

以下是论文各部分简介。

深度序列模型

针对序列数据的深度学习模型可被视为围绕循环、卷积或注意力等简单机制建立的序列到序列转换。

这些基元(primitive)可以被纳入标准的深度神经网络架构,形成主要的深度序列模型系列:循环神经网络(RNN)、卷积神经网络(CNN)和 Transformers,它们表达了强大的参数化变换,可以使用标准的深度学习技术(如梯度下降反向传播)进行学习。图 1.1 和定义 1.1 展示了本论文中使用的序列模型抽象,第 2.1 节将结合实例对其进行更正式的定义。

定义 1.1(非正式)。作者使用序列模型来指代在序列 y = f_θ(x) 上的参数化映射,其中输入和输出 x、y 是 R^D 中长度为 L 的特征向量序列,θ 是通过梯度下降学习的参数。

上述每个模型系列都为机器学习带来了巨大的成功:例如,RNN 为机器翻译带来了深度学习,CNN 是第一个神经音频生成模型,而 Transformers 则彻底改变了 NLP 的广阔领域。

不过,这些模型也有其序列机制所遗留的折衷问题。例如,RNN 对于序列数据来说是一个天然的有状态模型,每个时间步只需要恒定的计算 / 存储,但训练速度慢,而且存在优化困难(如梯度消失问题),这限制了它们处理长序列的能力。CNN 专注于局部上下文,编码 shift equivariance 等特性,并具有快速、可并行训练的特点,但其序列推理成本较高,且上下文长度受到固有限制。Transformers 因其处理长程依赖关系的能力和可并行性而获得巨大成功,但在序列长度上存在二次扩展问题。另一个最新的模型系列是神经微分方程(NDE),这是一种有理论基础的数学模型,理论上可以解决连续时间问题和长期依赖关系,但效率非常低。

这些问题显示了深度序列模型面临的三大挑战。

挑战一:通用能力

深度学习的一个广泛目标是开发可用于各种问题的通用构建模块。序列模型为解决许多此类问题提供了一个通用框架。它们可以应用于任何可被投射为序列的环境。然而,当前的模型通常仍需要大量的专业化能力,以解决特定任务和领域的问题,或针对特定的能力。各类模型的优势分析如下:

  • RNN:需要快速更新隐藏状态的有状态设置,例如在线处理任务和强化学习;

  • CNN:对音频、图像和视频等均匀采样的感知信号进行建模;

  • Transformers:对语言等领域中密集、复杂的交互进行建模;

  • NDE:处理非典型时间序列设置,如缺失或不规则采样数据。

反之,每个模型系列都可能在其不擅长的功能方面举步维艰。

挑战二:计算效率

在实践中应用深度序列模型需要计算其定义的函数(即参数化序列到序列映射),这可以有多种形式。在训练时,任务一般可以用整个输入序列的损失函数来表述,算法的核心问题是如何高效地计算前向传递。在推理时(训练完成后部署模型),设置可能会发生变化;例如,在在线处理或自回归生成设置中,输入每次只显示一个时间步,模型必须能够高效地按顺序处理这些输入。

这两种情况对不同的模型系列都提出了挑战。例如,RNN 本身是序列性的,很难在 GPU 和 TPU 等现代硬件加速器上进行训练,而并行性则能使其受益。另一方面,CNN 和 Transformers 则难以进行高效的自回归推理,因为它们不是有状态的;处理单个新输入的成本可能会与模型的整个上下文大小成比例关系。更奇特的模型可能会带来额外的功能,但通常会使其计算更加困难和缓慢(如需要调用昂贵的微分方程求解器)。

挑战三:长程依赖

现实世界中的序列数据可能需要推理数以万计的时间步骤。除了处理长输入所需的计算问题外,解决这一问题还需要能够对此类长程依赖(LRD)中存在的复杂交互进行建模。具体来说,困难可能来自于无法捕捉数据中的交互,比如模型的上下文窗口有限;也可能来自于优化问题,比如在循环模型中通过长计算图进行反向传播时的梯度消失问题。

由于效率、表达能力或训练能力方面的限制,长程依赖是序列模型长期以来面临的挑战。所有标准模型系列,如 NDE、RNN、CNN 和 Transformers,都包括许多旨在解决这些问题的专门变体。例如对抗梯度消失的正交和 Lipschitz RNN、增加上下文大小的空洞卷积,以及日益庞大的高效注意力变体系列,这些变体降低了对序列长度的二次依赖。然而,尽管这些解决方案都是针对长程依赖设计的,但在 Long Range Arena 等具有挑战性的基准测试中,它们的表现仍然不佳。

状态空间序列模型

本论文介绍了基于线性状态空间模型(SSM)的新系列深度序列模型。作者将 SSM 定义为一个简单的序列模型,它通过一个隐式的潜在状态 x (t)∈R^N 映射一个 1 维函数或序列

SSM 是一种基础科学模型,广泛应用于控制论、计算神经科学、信号处理等领域。广义上,SSM 一词指的是对潜变量如何在状态空间中演化进行建模的任何模型。这些广义的 SSM 有许多种,可以改变 x 的状态空间(如连续、离散或混合空间)、y 的观测空间、过渡动态、附加噪声过程或系统的线性度。SSM 在历史上通常指隐马尔可夫模型(HMM)和线性动力系统(LDS)的变体,如分层狄利克雷过程(HDP-HMM)和 Switching Linear Dynamical 系统(SLDS)。

方程(1.1)的状态空间模型在状态空间和动力学上都是连续的,并且是完全线性和确定性的,但还没有被用作定义 1.1 意义上的深度序列模型。本论文探讨了状态空间序列模型的诸多优点,以及如何利用它们来解决一般序列建模难题,同时克服其自身的局限性。

通用序列模型

SSM 是一种简单而基本的模型,具有许多丰富的特性。它们与 NDE、RNN 和 CNN 等模型族密切相关,实际上可以以多种形式编写,以实现通常需要专门模型才能实现的各种功能(挑战一)。

  • SSM 是连续的。SSM 本身是一个微分方程。因此,它可以执行连续时间模型的独特应用,如模拟连续过程、处理缺失数据,以及适应不同的采样率。

  • SSM 是循环的。可以使用标准技术将 SSM 离散化为线性 recurrence,并在推理过程中模拟为状态循环模型,每个时间步的内存和计算量保持不变。

  • SSM 是卷积系统。SSM 是线性时不变系统,可显式表示为连续卷积。此外,离散时间版本可以在使用离散卷积进行训练时并行化,从而实现高效训练。

因此,SSM 是一种通用序列模型,在并行和序列环境以及各种领域(如音频、视觉、时间序列)中都能高效运行。论文第 2 章介绍了 SSM 的背景,并阐述了状态空间序列模型的这些特性。

不过,SSM 的通用性也有代价。原始 SSM 仍然面临两个额外挑战 ------ 也许比其他模型更严重 ------ 这阻碍了它们作为深度序列模型的使用。挑战包括:(1)一般 SSM 比同等大小的 RNN 和 CNN 慢得多;(2)它们在记忆长依赖关系时会很吃力,例如继承了 RNN 的梯度消失问题。

作者通过 SSM 的新算法和理论来应对这些挑战。

利用结构化 SSM 进行高效计算(S4)

遗憾的是,由于状态表示 x (t) ∈ R^N 对计算和内存的要求过高(挑战二),通用的 SSM 在实践中无法用作深度序列模型。

对于 SSM 的状态维度 N 和序列长度 L,仅计算完整的潜在状态 x 就需要 O (N^2L) 次运算和 O (NL) 的空间 ------ 与计算总体输出的 Ω(L + N) 下界相比。因此,对于合理大小的模型(例如 N ≈ 100),SSM 使用的内存要比同等大小的 RNN 或 CNN 多出几个数量级,因此作为通用序列建模解决方案,SSM 在计算上是不切实际的。

要克服这一计算瓶颈,就必须以一种适合高效算法的方式对状态矩阵 A 施加结构。作者介绍了具有各种形式结构矩阵 A 的结构化状态空间序列模型(S4)(或简称结构化状态空间)家族,以及能以任何表示形式(如循环或卷积)高效计算 S4 模型的新算法。

论文第 3 章介绍了这些高效 S4 模型的不同类型。第一种结构使用状态矩阵的对角参数化(diagonal parameterization),它非常简单、通用,足以表示几乎所有的 SSM。然后,作者通过允许低秩校正项对其进行推广,这对于捕捉后面介绍的一类特殊的 SSM 是必要的。通过结合众多技术思想,如生成函数、线性代数变换和结构矩阵乘法的结果,作者为这两种结构开发了时间复杂度为 和空间复杂度为 O (N + L) 的算法,这对于序列模型来说基本上是严密的。

使用 HIPPO 解决长程依赖关系

即使不考虑计算问题,基本的 SSM 在实验中仍然表现不佳,而且无法建模长程依赖关系(挑战三)。直观地说,其中一种解释是线性一阶 ODE 求解为指数函数,因此可能会出现梯度随序列长度呈指数级缩放的问题。这也可以从它们作为线性 recurrence 的解释中看出,这涉及到反复对一个 recurrent 矩阵进行幂运算,这就是众所周知的 RNN 梯度消失/爆炸问题的起因。

在第 4 章中,作者从 SSM 后退一步,转而研究如何从第一性原理出发,用循环模型对 LRD 进行建模。他们开发了一个名为 HIPPO 的数学框架,它形式化并解决了一个名为在线函数逼近(或记忆)的问题。这种方法旨在通过保持对连续函数历史的压缩,以逐步记忆连续函数。尽管这些方法的动机完全独立,但它们都是 SSM 的具体形式。这些最终的方法被证明是 SSM 的特定形式 ------ 尽管它们的动机是完全独立的。

论文第 5 章完善了这一框架,并将其与 SSM 抽象更严格地联系起来。它引入了一个正交 SSM 概念,广泛推广了 HIPPO,并推导出更多实例和理论结果,例如如何以原则性的方式初始化所有 SSM 参数。

HIPPO 概览

考虑一个输入函数 u (t)、一个固定的概率度量 ω(t),以及 N 个正交基函数(如多项式)组成的序列。在每个时间 t,u 在时间 t 之前的历史都可以投影到这个基上,从而得到一个系数向量 x (t)∈ R^N,这个向量代表了 u 的历史相对于所提供的度量 ω 的最佳近似值。函数 u (t)∈R 映射到系数 x (t)∈R^N 的映射被称为关于度量 ω 的高阶多项式投影算子 (HIPPO)。在很多情况下,在许多情况下,其形式为 x ′ (t) = Ax (t) + Bu (t),对于 (A, B) 有封闭形式的公式。

HIPPO 和 S4 的组合

HIPPO 提供了一个数学工具来构建具有重要属性的 SSM,而 S4 是关于计算表示的。第 6 章正式将两者联系起来,并说明它们可以结合起来,以获得两个世界的最佳效果。论文表明,HIPPO 生成的用于处理长程依赖关系的特殊矩阵实际上可以用第 3 章中开发的特定结构形式来编写。这就提供了结合 HIPPO 的 S4 的具体实例,从而产生了一个具有丰富功能、非常高效并擅长长程推理的通用序列模型。

应用、消融和扩展

通用序列建模功能

第 7 章对 S4 方法在各种领域和任务中的应用进行了全面的实证验证。当 S4 方法被纳入一个通用的简单深度神经网络时,它在许多基准测试中推进了 SOTA。

特别的亮点和功能包括:

  • 通用序列建模。在不改变架构的情况下,S4 在语音分类方面超越了音频 CNN,在时间序列预测问题上优于专门的 Informer 模型,在序列 CIFAR 方面与 2-D ResNet 相媲美,准确率超过 90%。

  • 长程依赖。在针对高效序列模型的 LRA 基准测试中,S4 的速度与所有基线一样快,同时比所有 Transformer 变体的平均准确率高出 25% 以上。S4 是第一个解决了 LRA Path-X 任务(长度为 16384)这一难题的模型,准确率达到 96%,而之前所有工作的随机猜测准确率仅为 50%。

  • 采样分辨率变化。与专门的 NDE 方法一样,S4 无需再训练即可适应时间序列采样频率的变化。

  • 大规模生成建模与快速自回归生成。在 CIFAR-10 密度估计方面,S4 与最好的自回归模型(每维 2.85 比特)不相上下。在 WikiText-103 语言建模方面,S4 大幅缩小了与 Transformers 的差距(在 0.5 困惑度范围内),在无注意力模型中实现了 SOTA。与 RNN 一样,在 CIFAR-10/WikiText-103 上,S4 利用其潜在状态生成像素 /token 的速度比标准自回归模型快 60 倍。

理论消融

作者对 S4 的处理讨论了训练 SSM 的许多理论细节,例如如何仔细初始化每个参数以及如何纳入 HIPPO 框架。他们对这些细节进行了全面的实证分析和消融研究,验证了他们的 SSM 理论的各个方面。例如,他们验证了 HIPPO 大大提高了 SSM 的建模能力,在标准序列模型基准上的性能比原始 SSM 实例提高了 15%。在算法上,他们的 S4 算法比传统的 SSM 算法提高了几个数量级(例如,速度提高了 30 倍,内存使用量减少到 1/400)。

应用:音频波形生成

作为一种具有多种特性的序列建模基元,S4 可以被整合到不同的神经网络架构中,并以多种方式使用。第 8 章介绍了 S4 在原始音频波形生成中的应用,由于音频波形的采样率较高,这是一个具有挑战性的问题。这一章节介绍了围绕 S4 构建的 SaShiMi 多尺度架构,该架构在包括自回归和扩散在内的多种生成建模范式中,推动了无限制音频和语音生成技术的发展。该应用突显了 S4 的灵活功能,包括高效训练、快速自回归生成和用于连续信号建模的强大归纳偏置。

扩展:用于计算机视觉的多维信号

虽然作者主要关注一维序列,但某些形式的数据本身具有更高的维度,如图像(二维)和视频(三维)。序列模型的灵活性也适用于这些环境。第 9 章介绍了 S4ND,这是 S4 从一维到多维(N-D)信号的扩展。S4ND 继承了 S4 的特性,如直接对底层连续信号建模,并具有更好地处理输入分辨率变化等相关优势,是第一个在 ImageNet 等大型视觉任务中性能具有竞争力的连续模型。

更多细节请参考原论文。

最后,借机梳理介绍几篇 SSM 研究,供大家了解、学习。

论文一:Pretraining Without Attention

论文二:Mamba: Linear-Time Sequence Modeling with Selective State Spaces

围绕 Mamba,已经有一些语言模型发布,包括 mamba-130m, mamba-370m, mamba-790m, mamba-1.4b, mamba-2.8b。

HuggingFace 地址:huggingface.co/state-space...

也有人做出 Mamba-Chat:

Github 地址:github.com/havenhq/mam...

论文三:苹果等机构的论文 Diffusion Models Without Attention

论文四:Mamba 作者 Albert Gu 的博士论文 MODELING SEQUENCES WITH STRUCTURED STATE SPACES

论文地址:stacks.stanford.edu/file/druid:...

论文五:Long Range Language Modeling via Gated State Spaces 认为 Transformer 和 SSM 完全可以互补。

论文地址:arxiv.org/abs/2206.13...

论文六:DeepMind 的论文 Block-State Transformer

论文地址:arxiv.org/pdf/2306.09...

相关推荐
埃菲尔铁塔_CV算法11 分钟前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR12 分钟前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️18 分钟前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
好喜欢吃红柚子35 分钟前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python39 分钟前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯1 小时前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠1 小时前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost
Debroon1 小时前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
羊小猪~~1 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
AI小杨1 小时前
【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程
人工智能·opencv·计算机视觉·霍夫变换·车道线检测