【论文阅读】xLSTM: Extended Long Short-Term Memory

xLSTM: Extended Long Short-Term Memory

引用: Beck M, Pöppel K, Spanring M, et al. xLSTM: Extended Long Short-Term Memory[J]. arXiv preprint arXiv:2405.04517, 2024.

论文链接: [2405.04517] xLSTM: Extended Long Short-Term Memory (arxiv.org)

作者: Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter, Sepp Hochreiter

机构: ELLIS Unit, LIT AI Lab, Institute for Machine Learning, JKU Linz, Austria; NXAI Lab, Linz, Austria; NXAI GmbH, Linz, Austria

文章目录

  • [xLSTM: Extended Long Short-Term Memory](#xLSTM: Extended Long Short-Term Memory)
    • 摘要
    • 引言
    • xLSTM架构
      • [1. **sLSTM(Scalar LSTM)**](#1. sLSTM(Scalar LSTM))
      • [2. **mLSTM(Matrix LSTM)**](#2. mLSTM(Matrix LSTM))
      • [3. **xLSTM块(xLSTM Blocks)**](#3. xLSTM块(xLSTM Blocks))
      • [4. **xLSTM架构(xLSTM Architecture)**](#4. xLSTM架构(xLSTM Architecture))
    • 实验
    • 结论

摘要

  • 论文提出了xLSTM,一种扩展的长短期记忆网络,旨在解决传统LSTM的局限性,并在大规模参数下进行语言建模。
  • xLSTM引入了指数门控和适当的归一化与稳定技术,修改了LSTM记忆结构,包括标量记忆的sLSTM和完全可并行化的具有矩阵记忆和协方差更新规则的mLSTM。
  • 通过将这些LSTM变体集成到残差块中,构建了xLSTM架构,这些架构在性能和扩展性方面与最先进的Transformers和状态空间模型相媲美。

引言

  • LSTM自1990年代引入以来,在多个领域取得了成功,特别是在大型语言模型(LLMs)中。
  • 引入Transformer技术后,其并行化的自注意力机制使得LSTM在大规模应用中的性能受到挑战。
  • 论文提出了一个问题:当LSTM扩展到数十亿参数,并结合现代LLMs的最新技术,同时克服LSTM的已知限制时,我们能在语言建模中走多远?

xLSTM架构

1. sLSTM(Scalar LSTM)

指数门控是sLSTM中的一个创新点,它允许模型更有效地更新其记忆状态。在传统的LSTM中,门控机制通常涉及sigmoid函数,但在xLSTM中,输入门( i t i_t it)和遗忘门( f t f_t ft)可以具有指数激活函数:

c t = f t c t − 1 + i t z t c _ { t } = f _ { t } c _ { t - 1 } + i _ { t } z _ { t } ct=ftct−1+itzt

n t = f t n t − 1 + i t n _ { t } = f _ { t } n _ { t - 1 } + i _ { t } nt=ftnt−1+it

h t = o t h t ~ , h t ~ = o t / n t h _ { t } = o _ { t } \tilde{h _ { t }}, \quad \tilde{h _ { t }} = o _ { t } / n _ { t } ht=otht~,ht~=ot/nt

z t = φ ( z ~ t ) , z ~ t = w z T x t + r z h t − 1 + b z z _ { t } = \varphi ( \tilde { z } _ { t } ), \quad \tilde { z } _ { t } = w _ { z } ^ { T } x _ { t } + r _ { z } h _ { t - 1 } + b _ { z } zt=φ(z~t),z~t=wzTxt+rzht−1+bz

i t = e x p ( i ~ t ) , i ~ t = w i T x t + r i h t − 1 + b i i _ { t } = exp ( \tilde { i } _ { t } ), \quad \tilde { i } _ { t } = w _ { i } ^ { T } x _ { t } + r _ { i } h _ { t - 1 } + b _ { i } it=exp(i~t),i~t=wiTxt+riht−1+bi

f t = σ ( f ~ t ) O R e x p ( f ~ t ) , f ~ t = w f T x t + r f h t − 1 + b f f _ { t } = \sigma ( \tilde { f } _ { t } ) \quad OR \quad e x p ( \tilde { f } _ { t } ), \quad \tilde { f } _ { t } = w _ { f } ^ { T } x _ { t } + r _ { f } h _ { t - 1 } + b _ { f } ft=σ(f~t)ORexp(f~t),f~t=wfTxt+rfht−1+bf

o t = e x p ( o ~ t ) , o ~ t = w o T x t + r o h t − 1 + b o o _ { t } = exp ( \tilde { o } _ { t } ), \quad \tilde { o } _ { t } = w _ { o } ^ { T } x _ { t } + r _ { o } h _ { t - 1 } + b _ { o } ot=exp(o~t),o~t=woTxt+roht−1+bo

指数激活函数可能导致较大的值,从而导致溢出。因此,用一个额外的状态 m t m_t mt来稳定门:

m t = max ⁡ ( log ⁡ ( f t ) + m t − 1 , log ⁡ ( i t ) ) m _ { t } = \max ( \log ( f _ { t } ) + m _ { t - 1 } , \log ( i _ { t } ) ) mt=max(log(ft)+mt−1,log(it))

i t ′ = e x p ( log ⁡ ( i t ) − m t ) = e x p ( i ~ t − m t ) i _ { t } ^ { \prime } = e x p ( \log ( i _ { t } ) - m _ { t } ) = e x p ( \tilde { i } _ { t } - m _ { t } ) it′=exp(log(it)−mt)=exp(i~t−mt)

f t ′ = e x p ( log ⁡ ( f t ) + m t − 1 − m t ) f _ { t } ^ { \prime } = e x p ( \log ( f _ { t } ) + m _ { t - 1 } - m _ { t } ) ft′=exp(log(ft)+mt−1−mt)

其中,$m_t​ $是稳定状态,用于防止梯度爆炸。

同时,sLSTM引入了新的记忆混合技术,允许在多个内存单元之间进行更复杂的交互。多个存储器单元使得能够分别经由从隐藏状态向量 h h h到存储器单元输入 z z z和门 i i i、 f f f、 o o o的循环连接 R z R_z Rz、 R i R_i Ri、 R f R_f Rf、 R o R_o Ro进行存储器混合。sLSTM可以有多个头,每个头内混合内存,但不能跨头混合。

2. mLSTM(Matrix LSTM)

mLSTM使用矩阵记忆来增强存储容量,并通过协方差更新规则来存储关键值对。

C t = f t C t − 1 + i t v t k t T C _ { t } = f _ { t } C _ { t - 1 } + i _ { t } v _ { t } k _ { t } ^ { T } Ct=ftCt−1+itvtktT

n t = f t n t − 1 + i t k t n _ { t } = f _ { t } n _ { t - 1 } + i _ { t } k _ { t } nt=ftnt−1+itkt

h t = o t ⊙ h ~ t , h ~ t = C t q t / max ⁡ { ∣ n t T q t ∣ , 1 } h _ { t } = o _ { t } \odot \tilde { h } _ { t } , \quad \tilde { h } _ { t } = C _ { t } q _ { t } / \max \left\{ | n _ { t } ^ { T } q _ { t } | , 1 \right\} ht=ot⊙h~t,h~t=Ctqt/max{∣ntTqt∣,1}

q t = W q x t + b q q _ { t } = W _ { q } x _ { t } + b _ { q } qt=Wqxt+bq

k t = 1 d W k x t + b k k _ { t } = \frac { 1 } { \sqrt { d } } W _ { k } x _ { t } + b _ { k } kt=d 1Wkxt+bk

v t = W v x t + b v v _ { t } = W _ { v } x _ { t } + b _ { v } vt=Wvxt+bv

i t = e x p ( i ~ t ) , i ~ t = w i T x t + b i i _ { t } = e x p ( \tilde { i } _ { t } ) , \quad \tilde { i } _ { t } = w _ { i } ^ { T } x _ { t } + b _ { i } it=exp(i~t),i~t=wiTxt+bi

f t = σ ( f ~ t ) O R e x p ( f ~ t ) , f ~ t = w f T x t + b f f _ { t } = \sigma ( \tilde { f } _ { t } ) \quad OR \quad exp(\tilde { f } _ { t }), \quad \tilde { f } _ { t } = w _ { f } ^ { T } x _ { t } + b _ { f } ft=σ(f~t)ORexp(f~t),f~t=wfTxt+bf

o t = σ ( o ~ t ) , o ~ t = w o T x t + b o o _ { t } = \sigma ( \tilde { o } _ { t } ) , \quad \tilde { o } _ { t } = w _ { o } ^ { T } x _ { t } + b _ { o } ot=σ(o~t),o~t=woTxt+bo

3. xLSTM块(xLSTM Blocks)


xLSTM块结合了sLSTM和mLSTM的特性,并通过残差连接来进一步提高性能。对于残差sLSTM块,输入首先进入sLSTM,然后是一个门控的多层感知机(MLP)。对于残差mLSTM块,输入首先通过两个MLP,然后是mLSTM,通过卷积、可学习的跳跃连接和输出门。

4. xLSTM架构(xLSTM Architecture)

xLSTM架构通过残差堆叠xLSTM块来构建,利用了预层归一化(preLayerNorm)残差骨干网络。

实验



  • 论文在合成任务和长距离竞技场(Long Range Arena)上测试了xLSTM,并与其他方法进行了比较。
  • 在SlimPajama数据集上进行了语言建模实验,比较了不同方法的性能。
  • 进行了扩展实验,训练了更大的模型,并在更多的训练数据上评估了它们的扩展行为。

结论

  • xLSTM在语言建模方面的表现至少与当前的Transformer或状态空间模型相当。
  • xLSTM有潜力在强化学习、时间序列预测或物理系统建模等深度学习领域产生重大影响。
相关推荐
DevinLGT23 分钟前
6Pin Type-C Pin脚定义:【图文讲解】
人工智能·单片机·嵌入式硬件
宋一诺3327 分钟前
机器学习—高级优化方法
人工智能·机器学习
龙的爹233339 分钟前
论文 | The Capacity for Moral Self-Correction in LargeLanguage Models
人工智能·深度学习·机器学习·语言模型·自然语言处理·prompt
Mr.简锋42 分钟前
opencv视频读写
人工智能·opencv·音视频
Baihai_IDP42 分钟前
「混合专家模型」可视化指南:A Visual Guide to MoE
人工智能·llm·aigc
笨小古1 小时前
路径规划——RRT-Connect算法
算法·路径规划·导航
<但凡.1 小时前
编程之路,从0开始:知识补充篇
c语言·数据结构·算法
寰宇视讯1 小时前
“津彩嘉年,洽通天下” 2024中国天津投资贸易洽谈会火热启动 首届津彩生活嘉年华重磅来袭!
大数据·人工智能·生活
f狐0狸x1 小时前
【数据结构副本篇】顺序表 链表OJ
c语言·数据结构·算法·链表
Light601 小时前
低代码牵手 AI 接口:开启智能化开发新征程
人工智能·python·深度学习·低代码·链表·线性回归