HybridFRSM 顺序头对比实验报告

在 DSpark-RWKV 的 Stage 1 合成任务框架下,将 HybridFRSM(快线性递推 + 慢内容门控) 适配为

DSpark 顺序头,与 NoHead / GRU / RWKV-7 进行等参数量等架构公平对比。

1. 实验目标

验证 HybridFRSM 作为 DSpark 半自回归推测解码的顺序头(Sequential Head),在 block 内多 token 预测任务上:

  1. 能否达到与 GRU/RWKV-7 同等或更高的 top-1 接受率;
  2. 参数量严格对齐的前提下,优势是否仍然成立(排除"靠堆参数取胜"的疑问);
  3. 在需要同时记忆多个 key 的更难任务上,FRSM 的快慢尺度分离架构是否有结构性优势。

2. 参考代码与来源

组件 来源 说明
DSpark-RWKV Stage 1 框架 https://github.com/cgisky1980/dspark-rwkv stage1_experiment.py:提供 NoHead / GruHead / Rwkv7HeadV2、数据生成、训练与评估 run()
原 DSpark (DeepSpec) https://github.com/deepseek-ai/DeepSpec DSpark 半自回归推测解码原始论文实现(DeepSeek × 北京大学)
RWKV-7 参考 https://github.com/BlinkDL/RWKV-LM RWKV-7 Delta Rule / DPLR 公式参考
HybridFRSM 实现 本地F:\OpenASH2605\frsm_linear.py 快尺度 parallel scan + 慢尺度内容门控的分形递归状态机
本实验测试代码 本地F:\dspark-rwkv\stage1_frsm_compare.py FrsmHead 适配器 + 等参对比脚本(完整源码见文末附录)

3. 实验环境

  • Python 解释器F:\OpenASH\.venv\Scripts\python.exe(Python 3.13.3)
  • PyTorch :2.12.0 + cu130,CUDA 可用(GPU 运行)
  • 操作系统 :Windows / PowerShell(运行时设 PYTHONIOENCODING=utf-8python -u 实时输出)
  • 随机种子 :每个变体训练前 torch.manual_seed(42)

4. 方法

4.1 顺序头接口

所有头实现统一接口,嵌入 DSpark 的 Draft 框架:

python 复制代码
def forward_block(self, base_logits, prev_token_ids, hidden_states) -> logits
# base_logits: (B, T, V)    目标模型 lm_head 输出
# prev_token_ids: (B, T)    teacher forcing 的上一 token(位置 k 用 tokens[k-1])
# hidden_states:  (B, T, H) 目标模型隐状态(含 key 的弱信号 + 强噪声)
# 返回: base_logits + bias

4.2 FrsmHead 适配器设计

HybridFRSM 原为序列处理器(输入 (B,T,D) 特征 → 输出 (B,T,D)),通过 FrsmHead 适配为顺序头:

  • 输入融合:in_proj(hidden_states) + token_emb(prev_token_ids)(与 GRU/RWKV 同样利用上一 token 信息);
  • 内部:FRSM 快尺度 parallel scan(全局线性递推,O(log T) 训练)+ 慢尺度内容门控(每 K=1 步更新一次,选择性地将候选值写入长期状态);
  • 输出:w_out(feat) → vocab 维 bias,叠加到 base_logits

4.3 等参数量对齐策略(关键)

GRU/RWKV 的 embedding 与 w_out 均走 rank=32 维(参数 VOCAB*rank*2 ≈ 16K)。

若 FRSM 直接用 DHID=64 做 embedding,参数量会翻倍,对比不公平。

解决:FrsmHead 引入独立 frsm_dim 控制 FRSM 内部维度与 embedding 维度,并加 in_proj: DHID→frsm_dim

配置 frsm_dim 参数量 对标
GRU rank=32 28,768 ---
RWKV-7 (shift+LN) rank=32 41,632 ---
HybridFRSM d32 32 43,153 ≈ RWKV(差 1.5K,等参)
HybridFRSM d64 64 134,433 上限参照

frsm_dim=32 时 FRSM 与 RWKV-7 参数量几乎相同(43.2K vs 41.6K),构成严格等参数量对比

4.4 任务与超参

  • VOCAB=256, DHID=64, RANK=32, BLOCK=8N_TRAIN=N_EVAL=2000 步,LR=3e-3batch=256
  • 单 keyblock[k] = (key + k) % V(只需记忆 1 个 key)
  • 双 keyblock[0]=a, block[1]=b, block[k]=(a+b+k) % V(需同时记忆 2 个 key,更难)
  • 主干 hidden 含强噪声(NOISE=0.8 单key / 1.2 双key),lm_head 无法直接精确定位 key,必须靠顺序头的递归状态。

5. 实验结果

5.1 单 key 任务(简单,全模型收敛)

参数量 平均接受率
None(baseline) 0 0.0382
GRU(DSpark) 28,768 1.0000
RWKV-7(shift+LN) 41,632 1.0000
HybridFRSM(d32) 43,153 1.0000
HybridFRSM(d64) 134,433 1.0000

任务过简单,所有递归头均达 100%,无区分度。

5.2 双 key 任务(核心区分)

参数量 平均 位置0 位置1 位置2 位置3 位置4 位置5 位置6 位置7
None 0 0.0395 0.117 0.113 0.015 0.013 0.013 0.015 0.013 0.015
GRU 28,768 0.8349 0.980 0.449 0.262 0.989 1.000 1.000 1.000 1.000
RWKV-7 41,632 0.8373 0.946 0.315 0.477 0.966 1.000 0.998 0.999 0.999
HybridFRSM(d32) 43,153 0.9709 0.998 0.855 0.917 0.999 1.000 1.000 1.000 1.000
HybridFRSM(d64) 134,433 1.0000 1.000 1.000 1.000 1.000 1.000 1.000 1.000 1.000

5.3 收敛动态(双 key,每 200 步)

步数 GRU RWKV-7 FRSM d32 FRSM d64
200 0.662 0.075 0.764 0.832
600 0.759 0.678 0.847 0.998
1000 0.784 0.747 0.896 1.000
2000 0.838 0.838 0.974(仍上升) 1.000

6. 关键发现

6.1 等参数量下 FRSM 仍显著领先

HybridFRSM(d32, 43.2K) ≈ RWKV-7(41.6K),参数量几乎相同,但双 key 平均接受率 0.971 vs 0.837(+13.4%)

这排除了"FRSM 靠参数多取胜"的疑问,证明优势来自架构本身。

6.2 记忆瓶颈位置的结构性碾压

需同时持有两个 key 的位置 1 / 位置 2 是真正考验顺序头记忆能力的瓶颈:

位置 GRU RWKV-7 FRSM d32
位置1 0.449 0.315 0.855
位置2 0.262 0.477 0.917

FRSM 在这两个位置大幅领先(+0.4~0.6 绝对值)。这直接体现了慢尺度内容门控 的选择性长期记忆能力------

GRU 的标量门控和 RWKV-7 的 Delta Rule 矩阵状态在面对"多 key 并行保持"时均出现明显遗忘,而 FRSM 的

内容相关门控 MLP 能主动学习"何时写入",把第二个 key 稳定压入长期状态。

6.3 收敛更快更稳

  • FRSM d32 在 200 步即达 0.764(RWKV 同期仅 0.075,几乎未启动);
  • FRSM d64 在 800 步即收敛到 100%;
  • FRSM d32 在 2000 步末仍在上升(0.974),给足训练步数预计可逼近 1.0------潜力尚未榨干

6.4 快慢分离的收益确认

FRSM 的设计哲学是"快尺度负责即时预测(无门控开销),慢尺度负责选择性记忆(只需 1 个)"。

本实验中单慢尺度(num_slow=1)即足够在双 key 上击败 GRU/RWKV,验证了 frsm_linear.py 注释里

"快慢分离比纯门控快 4.7×、参数少 24%"的设计动机在顺序头场景同样成立。

7. 结论

  1. HybridFRSM 作为 DSpark 顺序头架构可行且领先:在等参数量(≈43K)下,双 key 接受率较 RWKV-7 高 13.4 个百分点。
  2. 优势来源于快慢尺度分离的架构设计,而非参数量:慢尺度的内容门控是处理"多 key 并行记忆"的关键机制。
  3. 收敛速度与稳定性均优于 GRU/RWKV-7,在 200 步内即可建立显著领先。
  4. 单 key 任务对所有递归头均过简单,建议后续在更长 block(BLOCK=16/32)或多 key(3+)任务上进一步拉开差距。

8. 复现方式

powershell 复制代码
$env:PYTHONIOENCODING="utf-8"
& "F:\OpenASH\.venv\Scripts\python.exe" -u "F:\dspark-rwkv\stage1_frsm_compare.py"

依赖:F:\OpenASH2605\frsm_linear.py(脚本通过 sys.path 自动导入,无需复制)。


附录:完整测试代码 stage1_frsm_compare.py

python 复制代码
"""DSpark-RWKV Stage1 · HybridFRSM vs GRU / RWKV-7 顺序头对比

把 frsm_linear.py 的 HybridFRSM(快线性递推 + 慢内容门控)适配为 DSpark 顺序头,
在 stage1 合成任务(单 key / 双 key)上与 NoHead / GRU / RWKV-7 对比接受率与参数量。

顺序头接口: forward_block(base_logits, prev_token_ids, hidden_states) -> logits
"""
import os
import sys

import torch
import torch.nn as nn

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, r"F:\OpenASH2605")

from stage1_experiment import (
    DEVICE, DHID, RANK, VOCAB, BLOCK,
    NoHead, GruHead, Rwkv7HeadV2, run,
    make_data_single, make_data_double,
)
from frsm_linear import HybridFRSM


class FrsmHead(nn.Module):
    """HybridFRSM 顺序头: 用 FRSM 状态机处理 hidden_states 序列, 输出 vocab 维 bias。

    输入融合: in_proj(hidden_states) + prev_token_emb (与 GRU/RWKV 同样利用上一 token)。
    FRSM 内部: 快尺度 parallel scan (全局线性递推) + 慢尺度内容门控 (每 K 步更新)。

    参数量对齐说明: GRU/RWKV 的 embedding 与 w_out 均走 rank 维 (VOCAB*rank*2)。
    本头用独立的 frsm_dim 控制 FRSM 内部维度与 embedding 维度,
    frsm_dim=rank 时参数量与 RWKV 同级, 实现等参数量公平对比。
    """
    def __init__(self, vocab_size, rank, hidden_size,
                 num_fast=3, num_slow=1, slow_update_freq=1, frsm_dim=None):
        super().__init__()
        frsm_dim = rank if frsm_dim is None else frsm_dim
        self.token_emb = nn.Embedding(vocab_size, frsm_dim)
        self.in_proj = nn.Linear(hidden_size, frsm_dim)
        self.frsm = HybridFRSM(
            d_model=frsm_dim,
            num_fast=num_fast,
            num_slow=num_slow,
            slow_update_freq=slow_update_freq,
        )
        self.w_out = nn.Linear(frsm_dim, vocab_size, bias=False)

    def forward_block(self, base_logits, prev_token_ids, hidden_states):
        feat = self.in_proj(hidden_states) + self.token_emb(prev_token_ids)
        out = self.frsm(feat)
        bias = self.w_out(out)
        return base_logits + bias


def run_task(task_name, make_data_fn):
    print(f"\n{'=' * 60}")
    print(f"任务: {task_name}")
    print(f"{'=' * 60}")
    train_data = make_data_fn(4096)
    variants = [
        ("None(baseline)",        lambda: NoHead(VOCAB, RANK, DHID)),
        ("GRU(DSpark)",           lambda: GruHead(VOCAB, RANK, DHID)),
        ("RWKV-7(shift+LN)",      lambda: Rwkv7HeadV2(VOCAB, RANK, DHID, use_shift=True, use_layernorm=True)),
        ("HybridFRSM(d32,~43K)",  lambda: FrsmHead(VOCAB, RANK, DHID, num_fast=3, num_slow=1, slow_update_freq=1, frsm_dim=32)),
        ("HybridFRSM(d64,~130K)", lambda: FrsmHead(VOCAB, RANK, DHID, num_fast=3, num_slow=1, slow_update_freq=1, frsm_dim=64)),
    ]
    results = {}
    for name, factory in variants:
        torch.manual_seed(42)
        pos_rate, avg = run(name, factory, train_data)
        n_params = sum(p.numel() for p in factory().parameters())
        results[name] = (pos_rate, avg, n_params)

    print(f"\n--- {task_name} 汇总 ---")
    print(f"{'头':<24} {'平均':<8} {'参数':<10} 各位置")
    for name, (pos_rate, avg, n_params) in results.items():
        print(f"{name:<24} {avg:.4f}   {n_params:>8,}  {[f'{r:.3f}' for r in pos_rate]}")
    return results


def main():
    print(f"VOCAB={VOCAB} DHID={DHID} RANK={RANK} BLOCK={BLOCK}  device={DEVICE}")
    print("对比: None / GRU / RWKV-7 / HybridFRSM(d32,等参) / HybridFRSM(d64,上限)")
    run_task("单 key: block[k]=(key+k)%V", lambda n: make_data_single(n))
    run_task("双 key: block[0]=a,block[1]=b,block[k]=(a+b+k)%V", lambda n: make_data_double(n))


if __name__ == "__main__":
    main()