Learning to (Learn at Test Time): RNNs with Expressive Hidden States 论文阅读

前言

这篇文章也是和TTT相关的(虽然是rnn),我看完一遍之后的感觉就是--类似于上次看的CV领域的那一篇TTT结构的方法--本质上就是用一个辅助模型来帮助主模型处理一些训练集里没有的东西,同时,辅助模型的训练方式也都是自监督学习,在cv里用图像翻转,在nlp里就当然使用字符串遮盖预测了,总之 ,文章链接:https://arxiv.org/pdf/2407.04620

正文

作者想要解决的局限:RNN优秀的复杂度但对长文本无法处理

我先分两个方向讲一下上面提到的复杂度和对长文本的局限:

计算复杂度方面

  1. 线性复杂度的含义
    计算复杂度衡量的是随着输入规模的变化,算法运算量增长的趋势。对于循环神经网络(RNN)层而言,它具有线性复杂度,这意味着当输入序列的长度(比如文本中词元的数量)不断增加时,其计算量大致是按照和输入长度成正比的关系增长的。例如,如果输入序列长度变为原来的两倍,那么计算量也大约会变为原来的两倍。相较于自注意力机制那种二次方复杂度(输入长度翻倍,计算量变为原来四倍)来说,从计算量随输入规模增长这个角度,RNN在处理较长序列时,不会像自注意力机制那样面临计算量急剧膨胀的问题,在计算资源消耗的控制上相对更有优势。

长文本语境表现方面

  1. 隐藏状态及其作用
    在RNN中,隐藏状态起着关键作用。它可以看作是对之前输入信息的一种汇总和编码,每一步处理新的输入时,都会基于当前输入和上一时刻的隐藏状态来更新当前时刻的隐藏状态,从而让网络能够"记住"之前的一些信息,以便对后续输入进行更好的处理。比如在处理自然语言文本时,隐藏状态可以捕捉到句子前文的语义、语法等相关特征,来辅助理解后面出现的词和整个句子的意思。
  2. 表达能力的局限
    然而,RNN的隐藏状态存在表达能力的限制。在长文本语境中,随着输入序列越来越长,隐藏状态需要编码和汇总的信息越来越多,但是其自身的结构和特性决定了它难以有效地对这些海量且复杂的信息进行充分表示。比如说,长文本中可能存在远距离的语义关联、复杂的逻辑结构以及多种层次的语法关系等,RNN的隐藏状态很难把这些丰富的信息都准确地涵盖进去,会出现信息丢失或者表示不全面的情况。
  3. 对整体性能的影响
    由于隐藏状态没办法很好地处理长序列中的复杂信息,这就导致了RNN在长文本语境下的整体性能受限。例如在一些自然语言处理任务中,像语言生成、文本分类等,如果文本很长,RNN可能就无法精准地把握全文的关键语义信息,进而生成不符合逻辑的文本内容,或者在分类任务中做出错误的判断。而且随着序列长度进一步增加,这种性能下降的问题会越发明显,和那些在长语境中表现更好的模型(如Transformer等基于自注意力机制的模型,能更好地捕捉长距离依赖关系)相比,RNN在处理长文本相关任务时就显得力不从心了。

上面这两条反映了现有 RNN 的尴尬现实。一方面,RNN 的主要优势(相对于 Transformer)在于其线性(相对于二次)复杂度。这种渐近优势只有在实践中针对长上下文才能实现,根据图 3,这在 8k 之后。另一方面,一旦上下文足够长,现有的 RNN(如 Mamba)难以真正利用所依赖的额外信息。

所以作者想要提出一种新的结构来,能够结合RNN的计算复杂度的优势又能解决他的长文本乏力问题:

TTT layers. Motivated by this observation, we design a new class of sequence modeling layers where the hidden state is a model, and the update rule is a step of self-supervised learning. Because the process of updating the hidden state on a test sequence is equivalent to training a model at test time, this new class of layers is called Test-Time Training (TTT) layers. We introduce two simple instantiations within this class: TTT-Linear and TTT-MLP, where the hidden state is a linear model and a two-layer MLP, respectively. TTT layers can be integrated into any network architecture and optimized end-to-end, similar to RNNs layers and self-attention.

作者用这段话说明了他的方法---TTT层,这段话的核心思想就是:"将隐藏状态本身视为一个机器学习模型,并将更新规则视为自监督学习的一步。"这句话怎么理解呢?简单来说,为了在长文本中保持高效和表达力,我们需要一个更好的压缩启发式算法。具体来说,我们需要将数千甚至数百万个词元压缩成一个隐藏状态,该状态能够有效地捕捉其底层结构和关系。然后我们知道使用自监督训练的模型可以捕捉到其训练数据的底层结构和关系,大型语言模型本身就是很好的例子。它们通过自监督的下一个词预测任务进行训练,其权重可以被视为对互联网上现有知识的压缩存储形式。通过查询大型语言模型,我们可以从它们的权重中提取知识。更重要的是,大型语言模型通常表现出对现有知识之间语义联系的深刻理解,从而表达新的推理片段.

TTT层

核心思想阐述

  • 将隐藏状态作为机器学习模型:
    通常在传统模型里,隐藏状态更多是作为一种对过往输入信息进行汇总、编码的中间表示形式,起到传递信息的作用。但在这个新的思路中,把隐藏状态本身当作一个机器学习模型来看待,意味着它不再仅仅是被动承载信息,而是可以像常规的机器学习模型(如神经网络等)那样具备自主学习、对输入输出进行处理的能力。例如,它能够基于接收到的输入数据进行更复杂、灵活的特征提取和转换,就如同一个独立的小型预测模型在发挥作用。这里再讲的详细一点:
  1. 传统模型中隐藏状态的训练
    • 在传统的序列模型(如简单的RNN)中,隐藏状态确实是在训练过程中被更新的。但是,它的更新主要是基于当前输入和上一时刻的隐藏状态,通过一个相对固定的公式来实现。例如,在基本的RNN中, h t = f ( W x t + U h t − 1 + b ) h_t = f(Wx_t + Uh_{t - 1}+b) ht=f(Wxt+Uht−1+b),其中 h t h_t ht是当前时刻的隐藏状态, x t x_t xt是当前输入, h t − 1 h_{t-1} ht−1是上一时刻的隐藏状态, W W W、 U U U是权重矩阵, b b b是偏置项, f f f是激活函数。这个过程更像是一个信息的传递和简单的转换,隐藏状态主要是对过去输入信息的累积表示。
  2. 新模型中隐藏状态作为机器学习模型的区别
    • 主动学习能力
      • 在提出的新模型中,将隐藏状态本身视为一个机器学习模型,这意味着它具有更强的主动学习能力。它不再只是按照固定的公式被动地更新,而是像一个完整的机器学习模型一样,可以对输入进行更复杂的特征提取和处理。例如,隐藏状态可能会根据输入序列的不同模式,自主地调整自己的内部结构(类似于模型参数)来更好地适应这些模式。
    • 功能拓展
      • 传统隐藏状态主要是用于存储和传递信息,为后续的输出层提供一个中间的表示。而新的隐藏状态作为机器学习模型,可以直接对输入序列进行预测、分类等操作。比如,它可以在没有输出层参与的情况下,直接根据当前的输入和自身的状态,对序列中的下一个元素进行预测,这类似于一个独立的小型预测模型。
    • 学习方式的改变
      • 传统的隐藏状态更新是基于训练数据集中的标签信息(如果是有监督学习)或者基于重建输入(如果是无监督学习,如自编码器)。而新的隐藏状态作为机器学习模型,其更新规则是自监督学习的一步。这意味着它可以利用输入序列自身的结构特点,例如序列中的顺序关系、重复模式等,来自我更新和优化,而不依赖于外部提供的标签或者特定的重建目标。例如,它可以通过预测序列中的下一个元素,然后根据预测结果和实际的下一个元素之间的差异来进行自我更新,就像一个自监督的语言模型一样。
  • 更新规则作为自监督学习的一步:
    自监督学习是一种不需要人工标注大量数据标签,通过利用数据自身的结构特点来构造监督信息进行学习的方式。把隐藏状态的更新规则设定为自监督学习的一步,就是说隐藏状态在每一次更新的过程中,会像在自监督学习场景下那样,依据数据本身内在的关联和规律来进行自我调整、优化参数,以更好地适应输入序列的各种特征,使得隐藏状态能不断学习到更有效的信息表示方式,强化其对整个序列信息的编码能力。

自监督任务

Our key idea is to use self-supervised learning to compress the historic context

x1, . . . , xt

into a hidden state

st

, by making the context an unlabeled dataset and the state a model. Concretely, the hidden state

st

is now equivalent to

Wt

, the weights of a model

f

, which can be a linear model, a small neural network, or anything else. The output rule is simply:

以下是对这段话的详细理解:

整体核心思路

这段话阐述的核心想法是运用自监督学习的方式,对历史上下文信息(也就是序列中从 x 1 x_1 x1 到 x t x_t xt 这些元素所包含的信息)进行处理,将其压缩并整合进一个隐藏状态 s t s_t st 里。这里把历史上下文当作一个无标签的数据集来看待,同时将隐藏状态视为一个模型,通过这样的设定来实现信息的有效整合与表示。

关于隐藏状态作为模型及权重的理解

  1. 隐藏状态与模型权重的等价关系
    具体来讲,此时的隐藏状态 s t s_t st 等同于 W t W_t Wt,这里的 W t W_t Wt 是一个模型 f f f 的权重。这是一种很独特的视角转换,以往隐藏状态更多是作为一种中间的向量表示,存储之前输入信息的汇总情况,但现在它被赋予了等同于模型权重的角色。也就是说,这个隐藏状态不再仅仅是简单承载信息,而是像模型的权重参数一样,能够决定模型(这里的模型就是 f f f )对输入进行怎样的变换、处理以及特征提取等操作。这里顺便比较一下传统情况下的区别:
比较维度 隐藏状态(传统情况) 权重(传统情况)
定义与本质 是对过往输入信息进行汇总和编码的中间向量表示,随着输入序列的推进不断更新,动态反映不同时刻输入信息的累积情况。 是模型中需要学习的参数,定义了不同层之间(如输入到隐藏层、隐藏层到输出层等)的转换关系,是相对固定的数值(训练时调整,训练后基本固定)。
更新机制 在每个时间步,根据当前输入以及上一时刻的隐藏状态,按照特定的计算规则(如在RNN中依据相应公式)进行更新,主要是为了整合新输入与已有信息。 通过训练过程,基于训练数据、损失函数,利用优化算法(如梯度下降等)来不断调整其数值,目的是最小化损失,使模型输出更符合预期,训练结束后在测试阶段基本不再改变。
功能作用 起到传递历史信息的作用,便于后续时间步利用这些信息进行进一步计算,辅助模型理解整个序列的特征,最终服务于输出结果的生成。 决定了输入数据经过模型各层时如何进行变换、进行何种程度的特征提取以及如何映射到输出空间等,从根本上控制着模型的计算流程和对输入的处理方式。
在序列处理中的变化情况 会随着输入序列的逐个时间步不断变化,时刻反映序列中不同位置输入信息的综合情况,对于长序列来说其变化贯穿整个序列处理过程。 在训练阶段按优化算法逐步调整,训练结束后面对不同的输入序列(在测试阶段)通常保持稳定,不会因输入序列中的不同元素而动态变化(除非特殊的可动态调整权重的架构,但不是常规情况)。
与模型输出的关系 间接影响输出,作为中间环节,先汇总信息后参与到后续输出的计算中,其状态好坏影响最终输出是否能准确反映序列特征。 直接决定输出,不同的权重取值会使得相同输入经过模型后产生不同的输出结果,权重是控制输出的关键因素之一。
  1. 模型 f f f 的多样性
    这个模型 f f f 可以有多种形式,它可以是一个简单的线性模型,例如线性回归那样的形式,只通过线性变换来处理输入;也可以是一个小型的神经网络,像包含几个隐藏层的多层感知机,能够进行更复杂、非线性的特征映射;甚至可以是其他任意符合需求的模型结构。这种灵活性意味着可以根据具体的任务场景、数据特点等选择合适的模型形式来构建隐藏状态,使其具备相应的信息处理能力。

输出规则

从直觉上讲,输出的结果只是对xt的预测,该预测是由f使用更新后的权重Wt进行的。更新规则是针对某个自监督损失ℓ进行梯度下降的一步:这里学习率是η.

在这里作者说:

From the compression point of view, every heuristic needs to decide which input to remember or forget. Our

W

remembers inputs that produce large gradients -- intuitively, inputs that make

W

learn a lot.

怎么理解呢?

从压缩角度出发的一般性考量

在处理序列数据并尝试对其进行压缩表示时(就像把一段长的历史上下文信息压缩进一个隐藏状态那样),各种方法或者启发式策略(heuristic)都面临一个关键问题,那就是需要去决定哪些输入信息应该被记住,哪些又该被忘掉。这是因为不可能无限制地把所有输入都完整保留下来,一方面计算资源和存储资源有限,另一方面也没必要全部保留,需要提取最关键、最有价值的信息进行存储,以形成一种有效的压缩表示,便于后续利用这个表示进行序列相关的任务处理,比如预测、分类等。

关于"W"记住特定输入的机制

这里提到的"W"(也就是前面所讲的等同于隐藏状态的模型权重)有着独特的信息选择方式。它会记住那些能够产生较大梯度的输入。从直观层面理解,梯度在机器学习中通常与模型的学习过程密切相关,梯度反映了损失函数相对于模型参数(在这里就是"W")的变化率。

当某个输入使得"W"产生较大的梯度时,意味着这个输入对于"W"(也就是这个隐藏状态所代表的模型权重)的调整、学习有着重要的推动作用。可以想象成这个输入带来了很强烈的"信号",促使模型(通过调整"W")去更好地适应数据、学到更多有用的知识或者模式。所以,"W"会倾向于记住这样的输入,把它们所包含的信息保留下来,整合进自身的表示当中,而对于那些产生梯度较小的输入,相对来说就不太重要,可能就会被"忽略"或者说忘掉,以此实现一种基于学习重要性的信息筛选和压缩机制,让"W"所代表的隐藏状态能够聚焦于那些对模型学习有较大价值的输入信息,进而更好地完成后续的序列处理任务。

具体操作

这个其实很暴力,就是简单的端到端---直接针对下一个词预测的最终目标优化自监督任务。

内外环的问题

We refer to training the larger network as the outer loop, and training

W

within each TTT layer as the inner loop. An important difference between the two nested learning problems is that the inner-loop gradient ∇ℓ is taken w.r.t.

W

, the parameters of

f

, while the outer-loop gradient is taken w.r.t the parameters of the rest of the network, which we will denote by

θrest

. Throughout this paper, outer-loop parameters are always denoted by

θ

with various subscripts.

以下是对这段话的详细理解:

内外环学习的概念界定

涉及到了两个不同层面的训练过程,分别被定义为外环(outer loop)和内环(inner loop)。

  • 外环(outer loop):这里把对整个更大规模网络进行训练的过程称作外环。从整体网络架构层面出发,针对包含了多个部分、多个层次以及众多参数的完整网络进行的一种全局的训练操作。这个完整网络中除了我们重点关注的 TTT(测试时训练)层里面的相关元素外,还有其他很多组成部分,它们共同协作来完成各种任务,比如处理输入序列、生成合适的输出等,而外环训练就是要调整这些众多组成部分对应的参数,让整个网络的性能达到最优。

  • 内环(inner loop):与之相对应,在每个 TTT 层内部对 (W) (前面提到过 (W) 等同于 TTT 层隐藏状态所代表的模型 (f) 的参数)进行训练的这个过程被称为内环。也就是说,在 TTT 层这个相对独立又关键的局部范围内,有它自己专门的训练机制,聚焦于对自身隐藏状态所对应的参数 (W) 进行调整优化,使其能够更好地实现如压缩历史上下文信息、进行有效的序列建模等功能。

区别

二者重要区别体现在计算梯度的对象上,内环梯度是关于 (W) (即 TTT 层里隐藏状态对应的模型 (f) 的参数)计算的,通过分析损失函数 (\ell) 相对于 (W) 的变化来指导调整 (W) 参数,助力模型 (f) 在 TTT 层内更好地完成如更新隐藏状态、编码序列信息等工作;外环梯度则针对网络其余部分用 (\theta_{rest}) 表示的参数来计算,通过分析损失函数相对于这些参数的变化去调整它们,促使整个网络从整体架构层面有效处理输入、输出任务.

尾声

其实论文还有很多部分是在讲他的并行计算和提升效率的问题,但是好像和我没啥关系,就先这样吧()

相关推荐
小嗷犬2 小时前
【论文笔记】VisionZip: Longer is Better but Not Necessary in Vision Language Models
论文阅读·人工智能·语言模型·大模型·多模态
星夜Zn19 小时前
小语言模型综述(A Survey of Small Language Models)-全文中文翻译
论文阅读·人工智能·深度学习·语言模型·小语言模型
25 Hz1 天前
Mind 爱好者周刊 第6期 | 关于假设检验的贝叶斯因子(含R包)、高阶冥想期间的神经现象学、大脑中广泛的 β 网络、视觉和听觉审美具有不同的神经机制……
论文阅读·学习
MorleyOlsen1 天前
【经典论文阅读】Latent Diffusion Models(LDM)
论文阅读
zenpluck1 天前
GS-SLAM论文阅读--RGBDS-SLAM
论文阅读
薛定谔的短耳猫2 天前
如何写出一篇好的论文?
论文阅读·毕业设计·论文笔记·毕设
行然梦实3 天前
毕设记录_论文阅读(动磁式音圈电机的开发与应用)_20241207
论文阅读·课程设计
体系结构论文研讨会3 天前
【论文阅读】对计算机体系结构研究的一点认识
论文阅读
HollowKnightZ3 天前
论文阅读笔记:Adaptive Rotated Convolution for Rotated Object Detection
论文阅读·笔记·目标检测