Mamba架构讲解

简介

Mamba 是一种高效的深度学习序列建模架构,于2023年提出,基于选择性状态空间模型(Selective State Space Model),通过引入输入依赖 的动态机制,使模型能够有选择地处理和保留信息。相比传统的 Transformer,Mamba 具有线性时间复杂度O(L),能高效处理超长序列(如数万个 token),在语言建模、基因组学和音频处理等任务中表现出色,同时显著降低计算和内存开销,被视为下一代序列建模的重要方向之一。

原论文的引言部分:

基础模型(FMs)是指在海量数据上预训练后,再适配下游任务的大型模型,已成为现代机器学习中的有效范式。这些基础模型的骨干网络通常是序列模型,能够处理来自语言、图像、语音、音频、时间序列和基因组学等多个领域的任意输入序列。尽管这一概念并不局限于特定的模型架构,但现代基础模型主要基于一种序列模型:Transformer及其核心的注意力层。自注意力机制的有效性源于其能在上下文窗口内密集地传递信息,从而建模复杂数据。然而,这一特性也带来了根本性缺陷:无法建模有限窗口之外的信息,且计算复杂度随窗口长度呈二次方增长 。已有大量研究致力于开发更高效的注意力变体以克服这些缺陷,但往往以牺牲其核心有效性为代价。迄今为止,尚未有任何变体被证明能在跨领域大规模场景下取得理想的实证效果。

近年来,结构化状态空间序列模型(SSMs)已成为序列建模领域极具潜力的架构类别。这类模型可被解读为循环神经网络(RNNs)与卷积神经网络(CNNs)的结合,并从经典状态空间模型中汲取灵感。它们能以循环或卷积的形式高效计算,序列长度的计算复杂度呈线性或近线性 缩放。此外,在特定数据模态中,它们具备建模长程依赖的严谨机制,并在长程序列基准测试中占据主导地位。多种类型的状态空间模型已在音频和视觉等连续信号数据领域取得成功。然而,它们在建模文本等离散且信息密集型数据时效果欠佳。

我们提出了一类新的选择性状态空间模型,在多个方面对现有工作进行改进,旨在兼具 Transformer 的建模能力与序列长度线性缩放特性。

选择机制

首先,我们发现现有模型的一个关键局限:无法以输入依赖的方式高效筛选 数据(即聚焦或忽略特定输入)。基于选择性复制和归纳头(induction heads)等重要合成任务的直觉,我们通过将状态空间模型的参数设计为输入的函数,构建了一种简单的选择机制。这使得模型能够过滤无关信息,并无限期记忆相关信息。

硬件感知算法

这一简单改动给模型计算带来了技术挑战:事实上,所有现有状态空间模型都必须满足时间不变性和输入不变性,才能保证计算效率。我们通过一种硬件感知算法克服了这一问题 ------ 该算法以扫描(scan) 而非卷积的方式循环计算模型,但不会显式存储扩展状态,从而避免了 GPU 存储层次结构不同级别之间的 IO 访问。最终实现无论在理论上(序列长度线性缩放,而所有基于卷积的状态空间模型均为伪线性缩放)还是在现代硬件上(A100 GPU 上速度提升高达 3 倍),都优于以往方法。

架构设计

我们简化了现有的深度序列模型架构,将现有状态空间模型架构设计与 Transformer 的 MLP 块整合为单个块,构建出一种简洁且同质化的架构(Mamba),该架构融合了选择性状态空间。

选择性状态空间模型及其延伸的 Mamba 架构是完全循环的模型,具备作为通用基础模型骨干网络处理序列数据的关键特性:

(1)高质量:选择性使其在语言和基因组学等密集模态上表现出色;

(2)训练与推理高效:训练过程中计算和内存随序列长度线性缩放,推理时自回归展开模型仅需恒定的每步时间,无需缓存先前元素;

(3)长上下文支持:高质量与高效率的结合,使其在长达 100 万长度的序列数据上性能持续提升。

我们从预训练质量和特定领域任务性能两方面,在多种模态和场景下实证验证了 Mamba 作为通用序列基础模型骨干网络的潜力:

(1)合成任务:在选择性复制和归纳头这类被认为对大型语言模型至关重要的合成任务上,Mamba 不仅能轻松解决,还能无限外推解决方案(超过 100 万个 token);

(2)音频和基因组学:在音频波形和 DNA 序列建模任务中,Mamba 在预训练质量和下游指标上均优于 SaShiMi、Hyena 和 Transformer 等现有最先进模型(例如,在一个具有挑战性的语音生成数据集上,将 FID 值降低一半以上)。在这两种场景下,随着上下文长度增加至百万级,其性能持续提升;

(3)语言建模:Mamba 是首个真正实现 Transformer 级性能的线性时间序列模型,无论是预训练困惑度还是下游评估均表现优异。在参数规模扩展至 10 亿的实验中,我们发现 Mamba 的性能超过了众多基准模型,包括基于 LLaMa 的强现代 Transformer 训练方案。Mamba 语言模型的生成吞吐量是同等规模 Transformer 的 5 倍,且 Mamba-3B 的质量与两倍于其规模的 Transformer 相当(例如,在常识推理任务上,平均得分比 Pythia-3B 高出 4 个百分点,甚至超过 Pythia-7B)。

框架介绍

带硬件感知状态扩展的选择性状态空间模型

传统状态空间模型(SSM)为了避免存储庞大的隐状态,会固定参数(不随输入变化)并走高效计算路径,但这样缺乏内容推理能力;这个新模型给 SSM 加了 "选择机制"------ 让参数随输入动态变化,能根据当前内容选信息;同时通过硬件优化(只在 GPU 高速存储里处理隐状态),既解决了传统 SSM 的局限,又没牺牲效率,是 Mamba 的核心模块结构。

引入选择机制的动机

我们认为序列建模的核心问题是将上下文压缩 到更小的状态中。事实上,我们可以从这一角度理解主流序列模型的权衡取舍。例如,注意力机制之所以高效且低效,是因为它完全不压缩上下文。这一点可从自回归推理需显式存储整个上下文(即 KV 缓存)看出,这直接导致了 Transformer 推理的线性时间复杂度和训练的二次时间复杂度。另一方面,循环模型之所以高效,是因为它们具有有限状态,意味着推理时间恒定、训练时间线性。然而,其有效性受限于状态对上下文的压缩效果。

为理解这一原理,我们聚焦于两个合成任务示例:

(1)**选择性复制任务(Selective Copying):**对经典复制任务进行修改,改变需记忆 token 的位置。该任务需要基于内容的推理,以记忆相关 token(彩色)并过滤无关 token(白色);

(2)**归纳头任务(Induction Heads):**一种知名机制,被认为是解释大型语言模型上下文学习能力的关键。该任务需要基于上下文的推理,以在适当的上下文中生成正确输出(黑色)。

这些任务揭示了线性时间不变模型的失效模式。从循环视角 来看,其恒定的动态特性无法让它们从上下文中筛选正确信息,也无法以输入依赖的方式影响沿序列传递的隐状态。从卷积视角来看,已知全局卷积可解决标准复制任务,因为该任务仅需时间感知,但由于缺乏内容感知,它们难以解决选择性复制任务。更具体地说,输入与输出之间的间隔是变化的,无法通过静态卷积核建模。

总之,序列模型的效率与有效性权衡取决于状态压缩效果:高效模型必须具备小状态,而有效模型的状态必须包含上下文中的所有必要信息。由此,我们提出构建序列模型的核心原则:选择性,即基于上下文聚焦或过滤输入到序列状态的能力。具体而言,选择机制控制信息沿序列维度的传播或交互方式。

整合选择机制到模型

将选择机制整合到模型中的一种方法,是让影响序列交互的参数 (例如循环神经网络的循环动态或卷积神经网络的卷积核)依赖于输入

算法 1 和算法 2 展示了使用的主要选择机制。核心差异在于将 A、B、C 等多个参数设计为输入的函数,并相应调整张量形状。特别需要强调的是,这些参数现在包含长度维度 L,意味着模型从时间不变变为时变。这一变化导致模型不再等同于卷积。

克服现有方法局限的思路

总体而言,循环神经网络(如状态空间模型)始终需要在表达能力和速度之间权衡:隐状态维度越大的模型理论上效果越好,但速度越慢。因此,希望在不牺牲速度和内存的前提下,最大化隐状态维度;

循环模式比卷积模式更灵活,因为后者是通过展开前者推导而来。然而,这需要计算并存储形状为(B, L, D, N)的隐状态 h,其规模(乘以状态空间模型的状态维度 N)远大于形状为(B, L, D)的输入 x 和输出 y。因此,人们提出了更高效的卷积模式,无需计算状态,仅需存储大小为(B, L, D)的卷积核;

现有线性时间不变状态空间模型利用循环 - 卷积双重形式,将有效状态维度提升 N 倍(约 10-100),远大于传统循环神经网络,且无效率损失。

选择性扫描概述

选择机制的设计旨在克服线性时间不变模型的局限;同时,我们需要重新审视状态空间模型的计算问题。通过三种经典技术解决这一问题:核融合、并行扫描和重计算

朴素循环计算的计算量为 O (BLDN),而卷积计算的计算量为 O (BLD log (L)),且前者的常数因子更小。因此,对于长序列和不太大的状态维度 N,循环模式的计算量实际上更小; 面临的两大挑战是循环的顺序性和巨大的内存消耗。为解决后者,与卷积模式类似,可以尝试不直接存储完整状态 h。

核心思路是利用现代加速器(GPU)的特性,仅在更高效的存储层次结构中存储状态 h。具体而言,大多数操作(除矩阵乘法外)都受内存带宽限制。这包括扫描操作,通过核融合减少内存 IO,从而比标准实现显著提速。

具体来说,不将大小为(B, L, D, N)的扫描输入存储在 GPU 的高带宽内存(HBM)中,而是直接将状态空间模型参数(Δ、A、B、C)从慢速 HBM 加载到快速静态随机存取存储器(SRAM),在 SRAM 中执行离散化和循环计算,然后将大小为(B, L, D)的最终输出写回 HBM。

为避免顺序循环,发现尽管模型不再是线性的,但仍可通过高效并行扫描算法实现并行化。

最后,还需避免存储反向传播所需的中间状态。巧妙地应用经典的重计算技术以降低内存需求:不存储中间状态,而是在反向传播时,当输入从 HBM 加载到 SRAM 时重新计算中间状态。最终,融合的选择性扫描层与优化的 FlashAttention Transformer 实现具有相同的内存需求。

Mamba

H3(Hungry Hungry Hippos),它是此前基于状态空间模型(SSM)的代表性架构之一,核心设计是 "将线性注意力的循环逻辑与 SSM 结合":整体采用 "门控连接 + SSM 层 + 局部卷积" 的组合结构,具体来说,每个 H3 块会先通过一个 "移位 SSM"(本质是简单的局部卷积)处理局部序列信息,再接入核心的 S4 层(一种经典结构化 SSM)建模长程依赖,同时用门控机制(类似 RNN 的门控)控制信息的传递与过滤,最后还会与线性注意力的思路结合以增强灵活性。不过 H3 的局限在于架构相对复杂,且内部的 SSM 层是传统的 "线性时间不变(LTI)" 类型 ------ 参数不随输入变化,无法像 Mamba 那样根据内容动态选择信息,这也导致它在语言等离散模态上的表现不如 Transformer。

Gated MLP(门控多层感知机),它是 Transformer 架构中 MLP 块的常见增强形式,核心是 "用门控机制提升 MLP 的非线性表达能力":传统 MLP 通常是 "线性投影 + 激活函数" 的简单堆叠,而 Gated MLP 会在其中加入 "门控分支"------ 比如先将输入通过两个并行的线性层生成 "信号路" 和 "门控路",再通过逐元素相乘的方式让门控路动态控制信号路的信息权重(类似 "开关"),常见的 SwiGLU 就是 Gated MLP 的典型变体(用 SiLU 激活函数作为门控)。在 Transformer 中,Gated MLP 的作用是对注意力层输出的 "关联信息" 进行非线性变换,补充模型的局部特征建模能力,但它本身不具备长程依赖建模能力,必须依赖注意力层,这也导致 Transformer 整体仍受限于二次时间复杂度。

Mamba块的设计正是在 H3 和 Gated MLP 的基础上做了 "整合与简化":它将 H3 中 "门控 + SSM" 的长程建模逻辑,与 Gated MLP 的非线性门控思路,合并成一个统一的 "Mamba 块"------ 去掉了 H3 中冗余的局部卷积和线性注意力组件,同时把传统 SSM 升级为 "选择性 SSM"(带硬件感知优化),让单个块既能像 H3 那样建模长程依赖,又能像 Gated MLP 那样提供强非线性表达,最终实现了架构简化与性能、效率的平衡。

具体流程是:

(1)输入序列通过一个线性层将维度扩展(比如从原维度 D 扩展到 2ED,E 是扩展因子),生成两路并行的信号;其中一路信号会经过 SiLU 激活函数,构成 "门控分支"(对应 MLP 的非线性 能力),另一路则进入 "选择性状态空间层"(Mamba 块里的 SSM,正是改进后的 "带硬件感知状态扩展的选择性状态空间模型",负责处理序列的长程依赖 ,同时通过输入依赖的选择机制动态调整参数).

(2)两路信号会通过 "逐元素相乘" 的方式完成门控融合 ------ 让模型根据输入内容,自主控制长程信息的传递权重。

(3)融合后的结果会通过线性层将维度还原,并加上残差连接和 LayerNorm 归一化,完成一个 Mamba 块的计算。

步骤 描述 组件
输入分割 输入被分为两路:一路用于主干路径,另一路用于门控路径。 底部两个绿色梯形(Linear projection)
局部特征提取 一维卷积操作,用于从输入序列中提取局部特征。 Conv(蓝色框)
非线性激活 在卷积后应用非线性激活函数(如SiLU),使后续SSM参数具有非线性特性。 σ(非线性激活)
SSM参数生成 利用卷积层的输出动态生成选择性状态空间模型(SSM)的参数(B, C, Δ)。 Conv(蓝色框)
全局建模 使用选择性SSM进行全局依赖关系建模,其参数根据当前输入动态调整。 SSM(蓝色框)
信息过滤 SSM的输出与门控信号相乘,控制哪些信息应该被保留或抑制。 ⊗(乘法)
线性变换 对经过门控的信息进行最终的线性变换,准备输出。 顶部绿色梯形(Linear projection)
残差连接 将处理后的信息与原始输入相加,形成残差连接,增强模型的学习能力。 残差连接

这种设计的优势在于 "兼顾简洁性、高效性与表达力":既用一个块替代了 Transformer 的双块结构,让架构更易堆叠;又以选择性状态空间层的 "线性时间复杂度" 替代了注意力的二次复杂度,同时保留了长程建模能力;还通过门控融合结合了 SSM 的长程建模与 MLP 的非线性表达,最终实现了 "线性效率" 与 "Transformer 级性能" 的平衡。

相关推荐
koo3642 小时前
pytorch深度学习笔记
pytorch·笔记·深度学习
java1234_小锋4 小时前
基于Python深度学习的车辆车牌识别系统(PyTorch2卷积神经网络CNN+OpenCV4实现)视频教程 - 裁剪和矫正车牌
python·深度学习·cnn·车牌识别
koo3644 小时前
pytorch深度学习笔记1
pytorch·笔记·深度学习
慕ゞ笙5 小时前
2025年Ubuntu24.04系统安装以及深度学习环境配置
人工智能·深度学习
java1234_小锋5 小时前
基于Python深度学习的车辆车牌识别系统(PyTorch2卷积神经网络CNN+OpenCV4实现)视频教程 - 车牌矩阵定位
python·深度学习·cnn·车牌识别
_codemonster8 小时前
深度学习实战(基于pytroch)系列(三十六)循环神经网络的pytorch简洁实现
pytorch·rnn·深度学习
自然语8 小时前
人工智能之数字生命-学习的过程
数据结构·人工智能·深度学习·学习·算法
Yuezero_8 小时前
Research Intern面试(一)——手敲LLM快速复习
pytorch·深度学习·算法
Coding茶水间9 小时前
基于深度学习的火焰检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉