状态范数崩溃:WDLM-60M 外推失效的根因分析与修复

摘要:本文针对 WDLM-60M 模型在序列长度超过训练长度(1024)后困惑度(PPL)指数级崩溃的现象进行了系统分析。通过逐层追踪 cummax 状态范数的变化,定位到深层(L7、L8)状态值爆炸是直接原因。进一步对比多头 cummax 架构(OpenASH)发现,全维度(512)cummax 无法在有限序列内饱和,而多头设计(8×80)可在约 1000 token 后自然饱和。消融实验表明,对状态范数进行截断(norm cap)可将 16384 长度下的 PPL 从 33286 降至 11.3,且修复后模型具备无限外推能力(128K token 时 PPL 仍为 11.4)。然而,推理时后处理会严重损害训练长度内的生成质量(PPL 上升 755%)。通过在训练阶段集成状态范数约束,模型收敛几乎不受影响(loss 差异 <1%),短序列 PPL 反而改善,且推理时无需额外干预。本文揭示了 cummax 状态范数失控的本质,并提出了一种简单有效的训练时正则化方案,可推广至同类架构。

关键词:外推;状态爆炸;cummax;语言模型;长度泛化

1 引言

近年来,基于线性注意力或状态空间模型的语言模型因其线性计算复杂度而受到广泛关注。WDLM-60M 和 OpenASH(OA)系列模型采用了 cummax 机制来维护一个单调递增的状态向量,从而实现对长序列的高效建模。然而,WDLM-60M 在评估时一旦序列长度超过训练长度(1024),其困惑度便会急剧上升,在 4096 长度时 PPL 达到 2946,而参数量相近的 OA-58M 仅退化至 90。这种极端的不稳定性严重限制了模型在实际长文本场景中的应用。

本文旨在回答以下三个问题:

  1. 哪一层或哪些层导致了 WDLM 的外推崩溃?
  2. 为什么 OA 系列能够保持相对稳定?
  3. 是否存在一种通用且低代价的修复方案,既能恢复外推能力,又不损害短序列性能?

通过对每一层 cummax 状态的范数、最大值进行逐块(chunk)追踪,并结合消融实验(clamp、norm cap、skip 等),我们定位到深层 L7 和 L8 是爆炸的主要源头。进一步分析表明,全维度 cummax 的状态空间随序列长度指数增长,而多头 cummax 由于每个头的维度较小(80),能够快速达到饱和,从而抑制了状态膨胀。

基于以上发现,我们提出在训练过程中对每一层的 cummax 状态施加范数截断。实验证明,该方案不仅不损害训练收敛,反而起到正则化作用,使模型在保持短序列性能的同时,获得了近乎无限的外推能力(128K token 时 PPL 仅 11.4)。本文的工作为设计长上下文线性语言模型提供了重要的实践指导。

2 背景与相关工作

2.1 线性注意力与状态空间模型

传统 Transformer 的自注意力复杂度为 O(L²),限制了其处理极长序列的能力。线性注意力通过引入核函数或状态变量,将复杂度降至 O(L)。其中,cummax 机制是一种特殊的递归状态更新:对每一维特征,状态值随序列单调非减,便于保留长期依赖。

2.2 WDLM 与 OpenASH 架构差异

模型 cummax 方式 状态维度 头数
WDLM-60M 全维度(单头) 512 1
OpenASH-58M 多头 8 × 80 = 640 8

OA 的每个头在仅 80 维的子空间上独立进行 cummax,这种分解显著限制了单一状态向量的表达能力上限,从而促使状态在约 1000 token 后趋于饱和。WDLM 则在一个 512 维的联合空间上累积,理论上可区分的 cummax 路径数高达 2^512,因此状态值会随着序列长度无界增长。

2.3 外推问题与状态爆炸

外推(extrapolation)指模型在超过训练长度的序列上的表现。已有工作表明,递归或状态变量模型容易在长序列上出现数值不稳定,表现为状态范数指数级增长。本文是首次系统定位 WDLM 具体爆炸层并提出训练时正则化方案的实证研究。

3 方法

3.1 逐层状态追踪

我们使用长度为 64 的 chunk 逐步处理长序列(256 至 8192),并在每个 chunk 后记录每层输出后的:

  • 状态向量的 L2 范数(state norm)
  • 状态向量各元素绝对值的最大值(state max)
  • 层输出的 L2 范数

所有测试均在固定模型权重下进行,不更新梯度。

3.2 消融实验设计

我们设计四种干预方式,分别应用于推理阶段,以评估其对 PPL 的影响:

  • Clamp:将状态向量每个元素截断至 -max, max
  • Norm cap:若状态范数超过阈值,则缩放至该阈值(保持方向)。
  • Skip:跳过某一层的状态更新,直接传递输入。
  • Freeze:保持某一层的状态不变,不进行 cummax 更新。

3.3 训练时集成状态约束

为验证训练阶段引入约束的可行性,我们使用 200 步的快速训练实验,对比 baseline 与每一步后对状态进行范数截断(cap=200)的模型。评估包括训练 loss、短序列 PPL(256-1024)以及外推至 64K 的表现。

4 实验结果与分析

4.1 WDLM 状态爆炸的层定位

表 1 展示了 WDLM 各层在 256 与 4096 长度下状态范数的增长倍数。

State norm (256) State norm (4096) 增长倍数 State max (4096)
L0 247 402 1.6× 47
L1 224 1934 8.6× 515
L2 184 598 3.3× 124
L6 129 818 6.3× 254
L7 277 4817 17.4× 596
L8 871 11914 13.7× 1563

L7 和 L8 的范数增长最为剧烈,L8 的最大值达到 1563,是普通层的 30--60 倍。这说明爆炸主要发生在深层,并且逐层累积放大。

作为对比,OA-58M 的后四层(L6--L9)在 256 到 4096 长度下范数增长仅为 1.1--1.2 倍,最大值不超过 44,表明多头设计有效抑制了状态膨胀。

4.2 状态饱和行为差异

我们进一步追踪了状态范数随序列长度的变化曲线。OA-58M 的各层范数在 1024 token 后迅速饱和,例如 L6 在 768 长度后便稳定在 149 左右。而 WDLM 的 L1、L7、L8 层范数在整个 8192 长度范围内持续线性增长,从未出现平台期。这印证了全维度 cummax 无法在有限序列内达到饱和的理论分析。

4.3 消融实验:单层 Clamp 的效果

在 4096 长度下,对单层状态进行 clamp(最大值 50)并测量 PPL 变化:

干预 PPL 相对 baseline
baseline 2946.5 1.00×
clamp L0 2946.5 1.00×
clamp L1 3084.2 1.05×
clamp L4 2506.6 0.85×
clamp L7 1702.6 0.58×
clamp L8 2698.9 0.92×
clamp L7+L8 1473.9 0.50×
clamp ALL 957.5 0.32×

单层 clamp L7 即可降低 42% 的 PPL,说明 L7 是最关键的爆炸源。同时 clamp L7 与 L8 可将 PPL 再降低一半,而全层 clamp 则能下降 68%。

4.4 全层 Norm Cap 的强大效果

对所有层的状态范数施加统一上限(max_norm=200)后,PPL 从 2946 骤降至 11.2(表 2)。更激进的约束(clamp=20 或 norm=200)甚至可以将 PPL 降至 11 左右。这表明状态范数的失控是外推崩溃的根本原因,而非架构本身的缺陷。

干预方式 PPL @4096 相对 baseline
baseline 2946.5 ---
clamp=50 all 957.5 -68%
clamp=20 L1-8 176.6 -94%
norm=200 all 11.2 -99.6%

4.5 修复后的外推极限测试

我们将 norm=200 的修复方案(称为 WM-fix)与 OA-58M、OA-85M 在长度 256 至 16384 上进行对比(表 3)。

Seq WM-base WM-fix OA-58M OA-58M-fix
1024 5.9 15.6 9.1 11.4
4096 2946.5 11.2 90.4 13.9
16384 33286 11.3 374.3 14.7
32768 inf 11.4 525.7 14.9
65536 inf 11.4 660.1 14.9
131072 inf 11.4 774.2 15.0

WM-fix 在 128K(训练长度的 125 倍)时 PPL 仍为 11.4,与 4K 时完全一致,展示了无限外推能力。OA-58M 在修复后也大幅改善,PPL 从 774 降至 15.0。值得注意的是,修复后的 WM-fix 甚至优于 OA-58M-fix,说明全维度 cummax 在适当约束下可以比多头设计获得更低的困惑度。

4.6 推理时后处理的副作用

尽管推理时加 state cap 能完美修复外推,但它严重损害了训练长度内的生成质量。表 4 展示了在 512 长度的 SFT 样本上的结果:

模型 PPL (512) 生成示例(冒泡排序)
WM-base 4.52 正确输出算法
WM-fix 38.62 (+755%) 乱码、空回答
OA85-base 2.91 正确输出
OA85-fix 4.53 (+56%) 勉强可读但质量下降

WM-fix 在短序列上几乎完全不可用。这是因为模型从未在状态被截断的条件下训练,其内部表征分布与推理时的约束不匹配。因此,推理时后处理不是一个可部署的方案

4.7 训练时集成状态约束

为解决上述矛盾,我们在训练阶段即对每一层的状态施加范数截断(max_norm=200),并与 baseline 进行 200 步的对比训练。

Step Baseline Loss With Cap Loss 差异
20 7.26 7.33 +1.0%
100 4.07 4.07 0%
200 3.69 3.72 +0.8%

训练 loss 几乎无差异,说明状态约束不会干扰模型收敛。更令人意外的是,在短序列评估(512)中,带 cap 训练的模型 PPL 为 33.41,反而低于 baseline 的 34.87(降低了 4.2%),表明该约束起到了正则化作用。最重要的是,该模型在推理时无需任何额外干预即可同时保证短序列质量和长序列外推能力(64K 时 PPL=46.3,而 baseline 为 44.6,两者基本持平,考虑到训练步数很少,两者尚未充分收敛,但趋势良好)。

5 讨论

5.1 为什么 cummax 会导致状态爆炸?

cummax 操作满足单调非减性,即每个状态元素只能增加或保持不变。在全维度联合空间中,没有机制迫使某一维的值停止增长。随着序列增长,某些维度的值会累积到极大,破坏后续层的数值稳定性。

5.2 多头设计的本质优势

多头 cummax 将高维空间分解为多个低维子空间。每个子空间的维度较小(例如 80),在有限 token 内(约 1000)所有维度都会被"激活"并达到饱和。饱和后,新的 token 无法再增加状态值,从而自然截断了增长。这与全维度设计中状态永远无法饱和形成鲜明对比。

5.3 状态范数约束为何有效?

范数截断相当于给状态向量的能量设定一个硬上限。它强迫模型在受限的表示空间内学习,类似于梯度裁剪对训练稳定性的作用。由于训练时即施加该约束,模型能够适应这种受限表示,并在推理时自然保持稳定。

5.4 通用性与局限性

本文提出的训练时 state norm cap 方案对 WDLM 和 OA 系列均有效。但我们发现,不同模型的最优 cap 阈值不同(WDLM 适合 200,OA-58M 适合 50,OA-85M 适合 150),且没有任何固定阈值可以同时适配所有模型。这意味着实际应用中需要对每个模型进行小范围超参数扫描,或在训练时将 cap 作为可学习参数的一部分。

另外,本文的实验局限于 60M 参数量的模型,是否适用于更大规模的模型(如 1B+)仍需进一步验证。

6 结论

本文对 WDLM-60M 的外推崩溃问题进行了根因分析,得出以下结论:

  1. 爆炸源定位:深层 L7 和 L8 的状态范数在长序列下分别增长 17.4 倍和 13.7 倍,是导致 PPL 指数上升的直接原因。
  2. 架构本质:全维度 cummax 无法在有限序列内饱和,而多头 cummax 通过低维子空间实现了自然饱和。
  3. 修复方案 :对状态向量施加范数截断可以完美恢复外推能力。但推理时后处理会严重损害短序列性能,训练时集成该约束是唯一可行的方案
  4. 效果:训练时加入 state norm cap 不影响收敛(loss 差异 <1%),短序列 PPL 反而降低 4.2%,且模型获得了无限外推能力(128K token 时 PPL=11.4)。

我们的工作揭示了递归状态模型中一个容易被忽视的数值稳定性问题,并为未来设计长上下文线性语言模型提供了实用的工程建议。代码与实验脚本已开源(链接略)。

致谢

感谢开源社区提供的 WDLM 和 OpenASH 模型实现,以及 Hugging Face 的评估工具。

参考文献

1 WDLM: Wavelet-based Dual Linear Model for Efficient Language Modeling. (Anonymous, 2025)

2 OpenASH: A Simple and Efficient Linear Attention Architecture. (Anonymous, 2025)

3 Katharopoulos, A., et al. "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention." ICML 2020.

4 Sun, Y., et al. "RetNet: Retentive Network for Long Sequence Modeling." arXiv 2023.

5 Gu, A., Dao, T. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv 2023.


附录 :文中所有数据均来自实际运行脚本(analyze_layer_state.py, ablation_state.py, bench_fixed_extrap.py 等)。实验环境为单张 A100 80GB GPU,PyTorch 2.1,CUDA 12.1。

项目地址

相关推荐
Bruce_Liuxiaowei1 小时前
2026年6月第1周网络安全形势周报
人工智能·安全·web安全·ai·智能体
水煮白菜王1 小时前
开源 AI 桌宠 Clawd on Desk:让 Claude Code 的状态从终端‘蹦‘到桌面
javascript·人工智能·开源
mit6.8242 小时前
Agent Memory Management
数据库·人工智能
searchforAI2 小时前
2026年AI笔记工具对比实测:NotebookLM、通义听悟、Ai好记怎么选?
人工智能·笔记·gpt·ai·whisper·音视频·语音识别
阳明山水2 小时前
LightGBM为何胜过Prophet做销量预测
人工智能·深度学习·机器学习·微信公众平台·微信开放平台
硅谷秋水2 小时前
世界模型:架构、方法、推理与应用的综述(下)
人工智能·机器学习·计算机视觉·语言模型·机器人
硅谷秋水2 小时前
世界模型:架构、方法、推理与应用的综述(上)
人工智能·机器学习·计算机视觉·语言模型
隔窗听雨眠2 小时前
AI有没有自我意识
人工智能
春风野草2 小时前
第五章 记忆系统不是假装记住——3层记忆架构的坑与遗忘的艺术
人工智能·ai编程