20260605
最近读的一篇《Long Context Pre-Training with Lighthouse Attention》
1、要解决的核心问题是SDPA(缩放点积注意力)在训练长上下文的因果转换器(transformers : 当前大模型的基础架构,核心是自注意力机制)时的瓶颈问题:时间和空间复杂度是平方复杂度的(quadratic)。论文提出了一种Lighthouse Attention来解决这个问题。
2、Lighthouse Attention的核心是,是一种基于选择的分层注意力算法,这个算法包裹了一个普通的SDPA。并且算法只在训练时使用,在训练结束前可以很容易移除,从而在推理使用时,效果与算法无关,达到普通SDPA训练出来的效果。
这个过程使用的分层选择器是梯度无关的,使我们免于去处理复杂和低效的反向传播核(backward pass kernel)
这个算法的贡献主要包括三个方面:
(1)一个次平方的分层前处理和后处理步骤,对序列进行适应性地压缩和解压缩。
(2)一个对称压缩策略,可以同时池化(pool)Q(queries)、K(keys)、V(values)------在池化的同时保留序列从左到右的因果性,极大地提高了并发性。
(3)一个两阶段训练方法,大部分时间使用Lighthouse Attention进行预训练,并且在训练的尾声使用一个简短的训练恢复全注意力模型。
最后通过一个小规模的LLM预训练实验,对本文提出的方法和全注意力训练(full attention)进行验证对比,可以看到本方法在恢复阶段后实现了更快的总训练时间和更低的最终损失。
3、前面《大模型架构》里面简单了解过注意力机制的概念,QKV的计算,每个token映射到一个Q、K、V,QKV的维度跟token维度相等,token经过三个独立的可训练权重分别映射到Q、K、V。经过注意力训练之后,token转换为经过注意力机制计算后的向量,这个向量融合了上下文语义(计算了上下文关系),维度和token相等(同形状)。
缩放点积注意力SDPA的计算公式: Attention(Q,K,V) = softmax(QK^T/sqrt(d))*V 其中d是维度,除以根号维度是为了稳住梯度,防止softmax坍缩。其中Q是当前词,K是所有词的K,V也是所有词的V。从注意力公式可以看出将当前词和其他所有词的特征进行了运算,从而融合了与所有词的关系,最终输出的向量维度跟原token相同。这个计算方式就有一个问题,上下文太长时,总的token数量是很多的。于是提出了稀疏注意力和滑动窗口,只计算部分词之间的关系。但是滑动窗口会丢失一部分上下文关系,稀疏注意力如何选择锚点(锚点是指只有这部分token会计算与全部词的关系,其他非锚点只计算前后一部分关系,最简单的锚点选择方式是等间隔选取)?本文的lighthouse Attention就是稀疏注意力的一种,研究的是如何在保留上下文关系的同时提高性能。
注意力机制是大模型一个很重要的基础,通过这个算法融合了上下文关系,使得上下文关系可以训练,而不只是训练独立的各个词源。谷歌2017年的论文奠定了AI大模型的基础,那篇论文叫做《Attention Is All You Need》。
4、论文中还提到一种注意力类型,相比SDPA也更高效,叫做闪存注意力(FlashAttention),但是没有根本解决问题。方式是把Q K V切成小块,分块算,从而不会在运算过程产生NN矩阵(完整的QK就是个N*N矩阵,N是上下文序列的词元数量),从而每小块可以适配显存,可以利用高速SRAM进行运算。涉及到反向重算,计算量上升,但是显存要求下降,速度也更快,支持的序列长度更长。
有许多不同的算法和处理方式对稠密注意力进行了优化,包括deepseek的方式DSA,HISA等。可以再分别去看看相关的研究。文章说这些方法有两个缺点:非对称性,只对K和V进行池化(pool),而Q没有,所以代表性不够全面?架构纠缠,选择器耦合到了注意力核中,导致经过当前流行的张量核GPU加速的、精心优化过的稠密注意力核无法被再使用。每个稀疏算法都要有自己的核(kernel)。
5、进一步介绍论文的工作内容:Lighthouse Attention是一种基于选择的分层注意力,通过一个多层金字塔对Q、K、V对称池化,使用一个无参数打分器对金字塔的每个节点进行双向打分,然后使用一个分块双调排序方式选择top-k节点。被选择的节点形成一个稠密、连续的子序列。top-k这一步不可微,梯度直接回流不受影响,整个过程没有加入辅助参数或其他损失。有两个结论:对称的金字塔是对背景的完整的、多角度的表示,而没有被压缩。存储注意力的尺寸可以压缩到O(NlogN),其中N是上下文序列长度。
再次总结三个主要贡献:
(1)一个为了长上下文训练而设计的基于选择的分层注意力,采用了对称Q/K/V池化,双向Top-k选择,在一个聚集的子序列存储FlashAttention,将稀疏逻辑隔离在注意力核(attention kernel)之外
(2)融合的GPU核,超长上下文场景下,这个设计显著加快
(3)我们的认知里训练时分层方法里最具有实践的标准:Lighthouse预训练后的密集SDPA与从头开始密集训练的基线,性能相当。
后面是各级的具体实现,总共四个点
6、构建金字塔:金字塔总共有L层,最底层是原始序列,池化窗口大小为1。设定一个池化因子p,往上每一层的窗口大小就是下面一层窗口大小*P,窗口不会重叠,窗口框住下面一层的元素,做平均池化,算平均值,作为这一层的窗口的Q、K、V(重新进行了计算,所以每个窗口框住的不管是原始序列的几个元素,算出来都是一个节点,一个QKV)。
7、评分和选择:每个金字塔节点计算出两个标量分数,一个作为query,一个作为key,在第0层(最底层),使用L2标准化计算标量分数,L2标准化是向量的长度,向量所有值平方和的平方根。其中query是用Q向量的L2标准化计算,key是用K向量的L2标准化计算。
往上粒度更大的层级,则直接取下一层窗口内的最大值,范围是当前节点窗口框住的下一层的节点,取其中的最大值。
整个金字塔所有的标量分数放在一起,通过分块双调核选择最大的K个节点,选中的节点是把金字塔节点原始的QKV往后传(不是传标量分数值)。粒度最大的层级我们总是全部保留,因为代价小,而且可以保证原始原素最少有一次参与注意力过程。
8、聚集序列注意力:通过L函数选择topk形成一个连续子序列之后,取对应的QKV,总共取S个节点(topk取了S个),其中S的长度是最上层全部节点共N/p^(L-1)个,其他每一层取pk个节点,共(L-1)*pk个。
最后形成的QKV进行一次SDPA注意力计算,并且加一个下三角Mask矩阵掩码,确保每个词不会看到未来的内容,只与前面的内容有关。
9、反向散射重建(Scatter-back Reconstruction):注意力运算的结果O~长度S,要重建完整的N-token输出O。
反向构建的第j个元素如何构建?先设定一个滑动窗口,与构建金字塔时的窗口规则相同,只是范围是向右移动p^l-1,也就是移动一个窗口的大小。如果元素j落到这个窗口里,则每一层这个窗口的左边窗口的O~求和得到元素j的最终输出O。这样保证了每个元素的总结不会包括未来的内容,因为都是左边的窗口求和。(虽然我想到,这样的方式可能会导致最前面的元素没有人参与构建,但是最前面的元素本身就没有太多关联关系,后面才要依赖前面建立的关系去预测。)
10、以上就是整个过程的4个关键设计。后面还有一些核设计和实验内容没有看完,先总结这么多。
整个过程看到其实就是在尝试建模,理解真实世界的关联,然后通过一些线性非线性的关系对元素之间建立连接,然后通过训练来找到一个最佳参数。关系的剖析和建模的想象力都很重要。想起来大学时候的数学建模比赛,那时候觉得是天马行空,可是世界就是这么天马行空,尽可能把自己理解到的世界关系表达出来,不管通过什么方式,线性的、指数的、对数的、积分、微分,这个世界可能真的就是这样,被函数关系构建起来的。
后面看《Attention is All You Need》。知道了原来多头注意力不是随便对矩阵做分块的,是直接训练多个W映射,将输入映射到多个低纬度的QKV,所以虽然是低纬度的QKV,但还是包括了输入的全部内容,然后通过不同的W参数,构建出来多块QKV可能包含了原始输入的很多不同层面的特征。很妙。