Deep Dive into KimiDeltaAttention: Line-by-Line Source Reading, Formula Derivation, and Complexity Optimization (Complete Runnable Code)
This article dissects the KimiDeltaAttention (KDA) source code at the heart of the Kimi models: line-by-line commentary, formula derivation, design rationale, and complexity analysis, plus the complete, directly usable code. Intended for readers working on deep learning, large language models, and attention mechanisms.
Preface
Standard Transformer self-attention has a fundamental weakness: its time and space complexity is $O(N^2)$, so memory blows up and speed collapses on long sequences.
KimiDeltaAttention (KDA) is a linear-complexity attention mechanism that brings the cost down to $O(N)$, while combining short convolutions for local modeling, a state-space-style recurrent memory for global context, and gating for information selection. On long-text understanding, reasoning tasks, and inference speed it aims to match or surpass standard attention.
We will work through KDA along five threads: line-by-line source reading → core formula derivation → design rationale → complexity and reasoning improvements → complete code.
I. Architecture Overview
KimiDeltaAttention in one line: short convolution (local features) + linear attention (global dependencies) + state-space-style recurrence (long-range memory) + multiple gates (information selection).
- Training mode: `chunk`, block-wise parallel computation
- Inference mode: `fused_recurrent`, streaming recurrent computation
- Complexity: $O(N \cdot H \cdot d^2)$, where $N$ is the sequence length, $H$ the number of heads, and $d$ the head dimension
- Advantages: effectively unbounded context length, very low memory, fast decoding, strong expressive power (a back-of-the-envelope memory comparison follows below)
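To make the "very low memory" point concrete, here is a rough per-layer comparison between a softmax KV cache and KDA's constant-size recurrent state. The numbers (32k tokens, 16 heads of dimension 128, fp16) are illustrative only, and the exact cache layout in fla may differ slightly (the short-convolution cache is ignored here).

```python
# Rough per-layer cache size comparison (fp16, illustrative numbers only)
N, H, d_k, d_v, bytes_per = 32_768, 16, 128, 128, 2

kv_cache = 2 * N * H * d_k * bytes_per           # softmax attention: K and V stored for every token
kda_state = H * d_k * d_v * bytes_per            # KDA: one d_k x d_v state matrix per head

print(f"KV cache : {kv_cache / 2**20:.1f} MiB")   # ~256 MiB at 32k tokens
print(f"KDA state: {kda_state / 2**20:.1f} MiB")  # ~0.5 MiB, independent of sequence length
```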
II. Dependencies and Imports (Line-by-Line)
```python
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright notice: attribution to the original authors

# Postponed evaluation of annotations, for cleaner type hints and type checking
from __future__ import annotations

# Standard math library, used for log/exp arithmetic during parameter initialization
import math

# Typing support; marks imports that are only needed for type checking
from typing import TYPE_CHECKING

# PyTorch core: tensors and nn.Module definitions
import torch
import torch.nn as nn

# einops: concise tensor reshaping, heavily used for head splitting/merging
from einops import rearrange, repeat

# PyTorch functional API
from torch.nn import functional as F

# FLA utilities: layer cache management, sequence padding/unpadding, tensor indexing
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache

# Fused modules: gated RMSNorm and short convolution
from fla.modules import FusedRMSNormGated, ShortConvolution

# KDA core kernels: chunked (parallel) and fused recurrent computation
from fla.ops.kda import chunk_kda, fused_recurrent_kda

# Fused KDA gate kernel
from fla.ops.kda.gate import fused_kda_gate

# Imported only during type checking; not loaded at runtime, avoiding circular imports
if TYPE_CHECKING:
    from transformers.processing_utils import Unpack
    from fla.models.utils import Cache
```
What this buys us: the base dependencies plus FLA's accelerated kernels. einops keeps the dimension shuffling readable, and the custom fused kernels are what make KDA fast in practice.
III. Class Definition and Constructor Arguments
```python
class KimiDeltaAttention(nn.Module):
    """
    Kimi Delta Attention (KDA) layer.

    A linear-complexity attention layer that combines short convolutions with a
    gated state-space-style recurrence, as a replacement for softmax self-attention.

    Args:
        hidden_size: input feature dimension, default 2048
        expand_v: expansion ratio for the value dimension, default 1.0
        head_dim: dimension of a single attention head, default 128
        num_heads: number of attention heads, default 16
        num_v_heads: number of value heads; defaults to num_heads, used for GVA (grouped value attention)
        mode: computation mode, chunk (training) / fused_recurrent (inference)
        use_short_conv: whether to use short convolutions, default True
        allow_neg_eigval: whether to allow negative eigenvalues, improving state tracking
        conv_size: short convolution kernel size, default 4
        conv_bias: whether the convolution uses a bias
        layer_idx: index of this layer in the model
        norm_eps: epsilon of the normalization layer, prevents division by zero
    """
```
Design intent: a standard attention-layer interface that slots into the Transformer ecosystem, with every size knob configurable so the same layer scales across model sizes.
IV. __init__: Parameters and Module Construction (Line-by-Line)
This is where every learnable parameter and submodule of KDA is created.
```python
def __init__(
    self,
    hidden_size: int = 2048,
    expand_v: float = 1,
    head_dim: int = 128,
    num_heads: int = 16,
    num_v_heads: int = None,
    mode: str = "chunk",
    use_short_conv: bool = True,
    allow_neg_eigval: bool = False,
    conv_size: int = 4,
    conv_bias: bool = False,
    layer_idx: int = None,
    norm_eps: float = 1e-5,
    **kwargs,
) -> KimiDeltaAttention:
    # Initialize the nn.Module base class
    super().__init__()

    # ===================== Basic configuration =====================
    # Computation mode (training / inference)
    self.mode = mode
    # Whether to allow negative eigenvalues (better state tracking)
    self.allow_neg_eigval = allow_neg_eigval
    # Input feature dimension
    self.hidden_size = hidden_size
    # Value expansion ratio
    self.expand_v = expand_v

    # ===================== Short convolution configuration =====================
    # Whether to use short convolutions for local features
    self.use_short_conv = use_short_conv
    # Convolution kernel size
    self.conv_size = conv_size
    # Convolution bias
    self.conv_bias = conv_bias

    # ===================== Head dimension configuration =====================
    # Per-head dimension
    self.head_dim = head_dim
    # Number of query/key heads
    self.num_heads = num_heads
    # Number of value heads (GVA; defaults to the number of Q/K heads)
    self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads
    # Final key/value dimensions
    self.head_k_dim = head_dim
    self.head_v_dim = int(self.head_dim * self.expand_v)
    self.key_dim = int(self.num_heads * self.head_k_dim)
    self.value_dim = int(self.num_v_heads * self.head_v_dim)
    # Layer index
    self.layer_idx = layer_idx

    # ===================== Argument validation =====================
    # The total value dimension must be an integer, otherwise nn.Linear would fail
    if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
        raise ValueError(f"expand_v={expand_v} does not produce a valid value dimension")
    # GVA check: the number of value heads must be a multiple of the number of Q/K heads
    if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
        raise ValueError("num_v_heads must be divisible by num_heads")
    # The per-head value dimension must also be an integer
    if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
        raise ValueError(f"expand_v={expand_v} does not produce a valid head dimension")
    # Only two computation modes are supported
    assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."

    # ===================== Q/K/V linear projections =====================
    # Query projection: no bias, hidden_size -> total key dimension
    self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
    # Key projection
    self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
    # Value projection: hidden_size -> total value dimension
    self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

    # ===================== Short convolutions (local feature extraction) =====================
    if use_short_conv:
        # Query branch: short convolution with SiLU activation, captures local temporal features
        self.q_conv1d = ShortConvolution(
            hidden_size=self.key_dim,
            kernel_size=conv_size,
            bias=conv_bias,
            activation="silu",
        )
        # Key branch
        self.k_conv1d = ShortConvolution(
            hidden_size=self.key_dim,
            kernel_size=conv_size,
            bias=conv_bias,
            activation="silu",
        )
        # Value branch
        self.v_conv1d = ShortConvolution(
            hidden_size=self.value_dim,
            kernel_size=conv_size,
            bias=conv_bias,
            activation="silu",
        )

    # ===================== Gate projections =====================
    # F gate: controls the state update (low-rank: hidden_size -> head_v_dim -> key_dim)
    self.f_proj = nn.Sequential(
        nn.Linear(hidden_size, self.head_v_dim, bias=False),
        nn.Linear(self.head_v_dim, self.key_dim, bias=False),
    )
    # Beta projection: per-head write strength, squashed to (0, 1) by a sigmoid
    self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)

    # ===================== Learnable state-dynamics parameters =====================
    # Decay parameter A: controls how fast history is forgotten; excluded from weight decay
    self.A_log = nn.Parameter(torch.log(torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)))
    self.A_log._no_weight_decay = True
    # Time-step bias dt: controls the state update rate
    dt = torch.exp(
        torch.rand(self.key_dim, dtype=torch.float32) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
    ).clamp(min=1e-4)
    # Inverse softplus, so that softplus(dt_bias) recovers dt inside the kernel
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    self.dt_bias = nn.Parameter(inv_dt)
    self.dt_bias._no_weight_decay = True

    # ===================== Output gate and normalization =====================
    # Output gate projection (low-rank)
    self.g_proj = nn.Sequential(
        nn.Linear(hidden_size, self.head_v_dim, bias=False),
        nn.Linear(self.head_v_dim, self.value_dim, bias=True),
    )
    # Fused gated RMSNorm: numerically stable, faster training
    self.o_norm = FusedRMSNormGated(self.head_v_dim, activation="sigmoid", eps=norm_eps)
    # Final output projection: back to hidden_size
    self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
```
Key design points in the initialization
- GVA grouped value attention: several value heads share one Q/K head, adding value capacity without paying for extra Q/K projections or compute
- Short convolution: capture local syntax/phrase-level features first, then model global dependencies
- Dynamics parameters: A_log (decay) and dt_bias (step size) give the layer an RNN-like long-range memory (see the small sketch below for how dt_bias is initialized)
- No weight decay on the dynamics parameters: regularization would otherwise distort the learned decay behaviour
- Multiple gates: dynamically select information, filter noise, and sharpen the layer's effective reasoning
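The dt initialization above uses the inverse-softplus trick: dt is sampled log-uniformly in [0.001, 0.1], and dt_bias stores the pre-activation such that a softplus applied later reproduces dt. A minimal sanity check in plain PyTorch (independent of fla):

```python
import math
import torch
import torch.nn.functional as F

# Sample dt log-uniformly in [0.001, 0.1], exactly as in __init__
dt = torch.exp(
    torch.rand(16) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
).clamp(min=1e-4)

# Inverse softplus: log(exp(dt) - 1), written in a numerically stable form
inv_dt = dt + torch.log(-torch.expm1(-dt))

# softplus(inv_dt) recovers the sampled dt
print(torch.allclose(F.softplus(inv_dt), dt, atol=1e-6))  # True
```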
V. forward: The Forward Pass (Line-by-Line, with Formulas)
The forward pass is the heart of KDA: convolution → projections → gating → state update → gated normalization and output.
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    past_key_values: Cache | None = None,
    use_cache: bool | None = False,
    output_attentions: bool | None = False,
    **kwargs: Unpack[dict],
) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
    # Validate the attention mask: only 2D padding masks are supported, not full attention matrices
    if attention_mask is not None:
        assert len(attention_mask.shape) == 2, (
            "attention_mask must be a 0-1 padding mask of shape [batch_size, seq_len]"
        )

    # Batch size and query length
    batch_size, q_len, _ = hidden_states.shape

    # ===================== Automatic mode switching =====================
    # Inference on short sequences: use fused_recurrent for minimal latency
    # Training: must use the chunked parallel kernel
    mode = "fused_recurrent" if (q_len <= 64 and not self.training) else self.mode
    if self.training:
        assert mode == "chunk", "Only chunk mode is supported in training."

    # ===================== Load cached states =====================
    # Convolution cache + recurrent state from the previous step (inference speed-up)
    last_state = get_layer_cache(self, past_key_values)

    # ===================== Handle padded sequences =====================
    cu_seqlens = kwargs.get("cu_seqlens")
    if attention_mask is not None:
        # Unpad: drop padding tokens to avoid wasted computation
        indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
        hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)

    # ===================== Q/K/V projections + short convolution =====================
    if self.use_short_conv:
        # Initialize convolution caches
        conv_state_q, conv_state_k, conv_state_v = None, None, None
        if last_state is not None:
            conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"]
        # Q branch: linear projection -> short convolution -> SiLU activation
        q, conv_state_q = self.q_conv1d(
            x=self.q_proj(hidden_states),
            cache=conv_state_q,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        # K branch
        k, conv_state_k = self.k_conv1d(
            x=self.k_proj(hidden_states),
            cache=conv_state_k,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        # V branch
        v, conv_state_v = self.v_conv1d(
            x=self.v_proj(hidden_states),
            cache=conv_state_v,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
    else:
        # No convolution: plain projection + SiLU activation
        q = F.silu(self.q_proj(hidden_states))
        k = F.silu(self.k_proj(hidden_states))
        v = F.silu(self.v_proj(hidden_states))

    # ===================== Gates =====================
    # F gate: controls the state update
    g = self.f_proj(hidden_states)
    # Beta: per-head write strength, squashed to (0, 1)
    beta = self.b_proj(hidden_states).sigmoid()

    # ===================== Head reshaping =====================
    # [B, L, D] -> [B, L, H, d_h] for multi-head computation
    q, k, g = (rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim) for x in (q, k, g))
    v = rearrange(v, "... (h d) -> ... h d", d=self.head_v_dim)

    # ===================== GVA head duplication =====================
    if self.num_v_heads > self.num_heads:
        # Repeat Q/K/g heads so they line up with the (larger) number of value heads
        q, k, g = (repeat(x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads) for x in (q, k, g))
        beta = repeat(beta, "... h -> ... (h g)", g=self.num_v_heads // self.num_heads)

    # Negative eigenvalues: scale beta from (0, 1) to (0, 2)
    if self.allow_neg_eigval:
        beta = beta * 2.0

    # ===================== Load the recurrent state =====================
    recurrent_state = last_state["recurrent_state"] if last_state is not None else None

    # ===================== KDA core computation (the formulas made concrete) =====================
    if mode == "chunk":
        # Chunked mode: used in training, block-wise parallel, linear complexity
        o, recurrent_state = chunk_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            A_log=self.A_log,
            dt_bias=self.dt_bias,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            use_qk_l2norm_in_kernel=True,
            use_gate_in_kernel=True,
            cu_seqlens=cu_seqlens,
        )
    elif mode == "fused_recurrent":
        # Fused recurrent mode: used in inference, streaming computation, very low latency
        g = fused_kda_gate(g=g, A_log=self.A_log, dt_bias=self.dt_bias)
        o, recurrent_state = fused_recurrent_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=cu_seqlens,
        )
    else:
        raise NotImplementedError(f"Not supported mode `{mode}`.")

    # ===================== Update the cache =====================
    update_layer_cache(
        self,
        past_key_values,
        recurrent_state=recurrent_state,
        conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
        offset=q_len,
    )

    # ===================== Gated normalization + output projection =====================
    # Gated RMSNorm: stabilizes the output and adds one more gate
    o = self.o_norm(o, rearrange(self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim))
    # Merge heads
    o = rearrange(o, "b t h d -> b t (h d)")
    # Final projection back to hidden_size
    o = self.o_proj(o)

    # ===================== Re-pad the output =====================
    if attention_mask is not None:
        o = pad_input(o.squeeze(0), indices, batch_size, q_len)

    # Return: output features, attention weights (None for linear attention), cache
    return o, None, past_key_values
```
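The rearrange/repeat calls above do all the head bookkeeping: splitting flat projections into heads and, for GVA, duplicating Q/K/g heads to match the larger number of value heads. A tiny standalone demo with made-up shapes:

```python
import torch
from einops import rearrange, repeat

B, L, H, D, G = 2, 5, 4, 8, 2           # batch, seq len, heads, head dim, GVA group size
x = torch.randn(B, L, H * D)            # a flat projection, e.g. the output of q_proj

# [B, L, H*D] -> [B, L, H, D]: split into heads
x_heads = rearrange(x, "... (h d) -> ... h d", d=D)
print(x_heads.shape)                    # torch.Size([2, 5, 4, 8])

# GVA: repeat each head G times so num_v_heads = H * G value heads can reuse the same Q/K
x_gva = repeat(x_heads, "... h d -> ... (h g) d", g=G)
print(x_gva.shape)                      # torch.Size([2, 5, 8, 8])
```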
VI. The Core Math of KimiDeltaAttention
KDA fuses linear attention with a gated, state-space-style delta-rule recurrence; there is no $QK^{\top}$ attention matrix anywhere.
1. State update (the core of long-range memory)
In simplified per-head form, the recurrent state is first decayed, then corrected by the delta rule, and finally written to:

$$\boldsymbol{S}_t = \big(\mathrm{Diag}(\boldsymbol{\alpha}_t) - \beta_t\,\boldsymbol{k}_t\boldsymbol{k}_t^{\top}\,\mathrm{Diag}(\boldsymbol{\alpha}_t)\big)\,\boldsymbol{S}_{t-1} + \beta_t\,\boldsymbol{k}_t\boldsymbol{v}_t^{\top}$$

- $\boldsymbol{S}_t$: the hidden state at step $t$, a $d_k \times d_v$ matrix per head that stores the compressed history
- $\boldsymbol{\alpha}_t \in (0,1)^{d_k}$: per-channel decay, produced from the gate $\boldsymbol{g}_t$ (the f_proj output) together with A_log and dt_bias; it controls how fast history is forgotten
- $\beta_t$: the scalar write strength from b_proj, squashed to $(0,1)$ (or $(0,2)$ when allow_neg_eigval is enabled)
- $\boldsymbol{k}_t$: the key vector (L2-normalized inside the kernel)
- $\boldsymbol{v}_t$: the value vector; the term $-\beta_t\,\boldsymbol{k}_t\boldsymbol{k}_t^{\top}\mathrm{Diag}(\boldsymbol{\alpha}_t)\boldsymbol{S}_{t-1}$ is the "delta" correction that erases the old association stored under $\boldsymbol{k}_t$ before the new one is written
2. Output (the global attention readout)
$$\boldsymbol{o}_t = \boldsymbol{S}_t^{\top}\,\boldsymbol{q}_t$$

- No softmax, no $O(N^2)$ attention matrix
- The output is simply the current query read against the accumulated state, so each step costs $O(d_k \cdot d_v)$ per head
3. Final output gate
$$\boldsymbol{o}_t \leftarrow \mathrm{RMSNormGated}\big(\boldsymbol{o}_t,\ \mathrm{Gate}(\boldsymbol{h}_t)\big)$$

where $\mathrm{Gate}(\boldsymbol{h}_t)$ is the low-rank g_proj output and the gate passes through a sigmoid inside FusedRMSNormGated.
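To make the recurrence concrete, here is a toy per-head reference implementation of the simplified update and readout above in plain PyTorch. It is a sketch for intuition only (no chunking, no L2 normalization, no fused gate), not the chunk_kda / fused_recurrent_kda kernels, and the function name is made up for this article:

```python
import torch

def kda_recurrence_reference(q, k, v, alpha, beta):
    """Naive per-head recurrence: decay -> delta-rule erase -> write -> read out.
    q, k: [T, d_k], v: [T, d_v], alpha: [T, d_k] per-channel decay in (0, 1), beta: [T] in (0, 1).
    """
    d_k, d_v = k.shape[-1], v.shape[-1]
    S = torch.zeros(d_k, d_v)                              # recurrent state
    outputs = []
    for t in range(len(k)):
        S = alpha[t].unsqueeze(-1) * S                     # per-channel decay of the old state
        S = S - beta[t] * torch.outer(k[t], k[t] @ S)      # delta rule: erase what k_t currently reads out
        S = S + beta[t] * torch.outer(k[t], v[t])          # write the new (k_t, v_t) association
        outputs.append(q[t] @ S)                           # read out: o_t = S_t^T q_t
    return torch.stack(outputs)                            # [T, d_v]

# Toy usage with random tensors
T, d_k, d_v = 16, 8, 8
q = torch.randn(T, d_k)
k = torch.nn.functional.normalize(torch.randn(T, d_k), dim=-1)
v = torch.randn(T, d_v)
alpha = 0.9 + 0.1 * torch.rand(T, d_k)                     # decay close to 1 -> long memory
beta = torch.rand(T)
print(kda_recurrence_reference(q, k, v, alpha, beta).shape)  # torch.Size([16, 8])
```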
VII. Design Rationale: Better Reasoning at Lower Complexity
1. Cutting the computational complexity
| Optimization | Effect |
|---|---|
| No $O(N^2)$ attention matrix | Complexity drops from $O(N^2)$ to $O(N)$ (rough FLOP estimate below) |
| Chunked parallel training | 10-100x faster training than a naive token-by-token recurrence |
| GVA grouped value attention | Fewer parameters, less computation |
| Fused kernels | Fewer memory round-trips, roughly double the throughput |
| Constant-size state cache instead of a KV cache | Memory footprint down by roughly 90% |
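A back-of-the-envelope comparison of the dominant multiply-add counts behind the first row, with illustrative numbers (32k tokens, 16 heads, head dim 128); constant factors and the short convolutions are ignored:

```python
# Dominant multiply-add counts per layer (very rough, constants ignored)
N, H, d = 32_768, 16, 128

softmax_attention = N * N * H * d     # QK^T plus attention-weighted V: O(N^2 * H * d)
kda_recurrence    = N * H * d * d     # state update + readout per token: O(N * H * d^2)

print(f"ratio ~ {softmax_attention / kda_recurrence:.0f}x")  # N / d = 256x at this length
```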
2. Improving the model's reasoning and expressive power
- Short convolution: capture local syntax and phrase-level dependencies first, then model global relations
- Learnable state decay: the model learns what to keep and what to forget, making long-range reasoning more precise
- Multiple gates: dynamically filter noise and keep the information that matters
- Negative eigenvalues: markedly better state tracking over long ranges
- Gated normalization: stable gradients in deep stacks, so deeper models remain trainable
VIII. Complete Runnable Code
```python
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from __future__ import annotations
import math
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch.nn import functional as F
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
from fla.modules import FusedRMSNormGated, ShortConvolution
from fla.ops.kda import chunk_kda, fused_recurrent_kda
from fla.ops.kda.gate import fused_kda_gate
if TYPE_CHECKING:
from transformers.processing_utils import Unpack
from fla.models.utils import Cache
class KimiDeltaAttention(nn.Module):
"""
Kimi Delta Attention (KDA) layer implementation.
Args:
hidden_size (int, Optional):
The hidden size of the input. Default: 2048.
expand_v (float, Optional):
The expansion ratio for the value dimension. Default: 1.0.
head_dim (int, Optional):
The dimension of each head. Default: 128.
num_heads (int, Optional):
The number of heads. Default: 16.
num_v_heads (int, Optional):
The number of heads for the value projection, equal to `num_heads` if `None`.
GVA (Grouped Value Attention) is applied if `num_v_heads` > `num_heads`. Default: `None`.
mode (str, Optional):
Which Kimi Delta Attention kernel to use.
Currently available: `chunk` and `fused_recurrent`.
Default: `chunk`.
use_short_conv (bool, Optional):
Whether to use short convolutions. Default: `True`.
allow_neg_eigval (bool, Optional):
Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
See reference:
[Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
conv_size (int, Optional):
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
layer_idx (int, Optional):
The index of the layer. Default: None.
norm_eps (float, Optional):
The epsilon value for the normalization layer. Default: 1e-5.
"""
def __init__(
self,
hidden_size: int = 2048,
expand_v: float = 1,
head_dim: int = 128,
num_heads: int = 16,
num_v_heads: int = None,
mode: str = "chunk",
use_short_conv: bool = True,
allow_neg_eigval: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
layer_idx: int = None,
norm_eps: float = 1e-5,
**kwargs,
) -> KimiDeltaAttention:
super().__init__()
self.mode = mode
self.allow_neg_eigval = allow_neg_eigval
self.hidden_size = hidden_size
self.expand_v = expand_v
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.head_dim = head_dim
self.num_heads = num_heads
self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads
self.head_k_dim = head_dim
self.head_v_dim = int(self.head_dim * self.expand_v)
self.key_dim = int(self.num_heads * self.head_k_dim)
self.value_dim = int(self.num_v_heads * self.head_v_dim)
self.layer_idx = layer_idx
if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
raise ValueError(
f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear.",
)
if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
raise ValueError(
f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.",
)
if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
raise ValueError(
f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated.",
)
assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if use_short_conv:
self.q_conv1d = ShortConvolution(
hidden_size=self.key_dim,
kernel_size=conv_size,
bias=conv_bias,
activation="silu",
)
self.k_conv1d = ShortConvolution(
hidden_size=self.key_dim,
kernel_size=conv_size,
bias=conv_bias,
activation="silu",
)
self.v_conv1d = ShortConvolution(
hidden_size=self.value_dim,
kernel_size=conv_size,
bias=conv_bias,
activation="silu",
)
self.f_proj = nn.Sequential(
nn.Linear(hidden_size, self.head_v_dim, bias=False),
nn.Linear(self.head_v_dim, self.key_dim, bias=False),
)
self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
self.A_log = nn.Parameter(torch.log(torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)))
self.A_log._no_weight_decay = True
dt = torch.exp(
torch.rand(self.key_dim, dtype=torch.float32) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
).clamp(min=1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
self.dt_bias._no_weight_decay = True
self.g_proj = nn.Sequential(
nn.Linear(hidden_size, self.head_v_dim, bias=False),
nn.Linear(self.head_v_dim, self.value_dim, bias=True),
)
self.o_norm = FusedRMSNormGated(self.head_v_dim, activation="sigmoid", eps=norm_eps)
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
past_key_values: Cache | None = None,
use_cache: bool | None = False,
output_attentions: bool | None = False,
**kwargs: Unpack[dict],
) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
if attention_mask is not None:
assert len(attention_mask.shape) == 2, (
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
"for padding purposes (0 indicating padding). "
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
)
batch_size, q_len, _ = hidden_states.shape
mode = "fused_recurrent" if (q_len <= 64 and not self.training) else self.mode
if self.training:
assert mode == "chunk", "Only chunk mode is supported in training."
last_state = get_layer_cache(self, past_key_values)
cu_seqlens = kwargs.get("cu_seqlens")
if attention_mask is not None:
indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
if self.use_short_conv:
conv_state_q, conv_state_k, conv_state_v = None, None, None
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"]
q, conv_state_q = self.q_conv1d(
x=self.q_proj(hidden_states),
cache=conv_state_q,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
)
k, conv_state_k = self.k_conv1d(
x=self.k_proj(hidden_states),
cache=conv_state_k,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
)
v, conv_state_v = self.v_conv1d(
x=self.v_proj(hidden_states),
cache=conv_state_v,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
)
else:
q = F.silu(self.q_proj(hidden_states))
k = F.silu(self.k_proj(hidden_states))
v = F.silu(self.v_proj(hidden_states))
g = self.f_proj(hidden_states)
beta = self.b_proj(hidden_states).sigmoid()
q, k, g = (rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim) for x in (q, k, g))
v = rearrange(v, "... (h d) -> ... h d", d=self.head_v_dim)
if self.num_v_heads > self.num_heads:
q, k, g = (repeat(x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads) for x in (q, k, g))
beta = repeat(beta, "... h -> ... (h g)", g=self.num_v_heads // self.num_heads)
if self.allow_neg_eigval:
beta = beta * 2.0
recurrent_state = last_state["recurrent_state"] if last_state is not None else None
if mode == "chunk":
o, recurrent_state = chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
A_log=self.A_log,
dt_bias=self.dt_bias,
initial_state=recurrent_state,
output_final_state=use_cache,
use_qk_l2norm_in_kernel=True,
use_gate_in_kernel=True,
cu_seqlens=cu_seqlens,
)
elif mode == "fused_recurrent":
g = fused_kda_gate(g=g, A_log=self.A_log, dt_bias=self.dt_bias)
o, recurrent_state = fused_recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=recurrent_state,
output_final_state=use_cache,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
update_layer_cache(
self,
past_key_values,
recurrent_state=recurrent_state,
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
offset=q_len,
)
o = self.o_norm(o, rearrange(self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim))
o = rearrange(o, "b t h d -> b t (h d)")
o = self.o_proj(o)
if attention_mask is not None:
o = pad_input(o.squeeze(0), indices, batch_size, q_len)
        return o, None, past_key_values
```
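A minimal smoke test for the layer, assuming fla-core and its Triton kernels are installed on a CUDA machine; the configuration numbers are illustrative, not Kimi's production settings, and exact dtype/device requirements may vary across fla versions:

```python
import torch

# Illustrative small configuration for a quick shape check
layer = KimiDeltaAttention(
    hidden_size=512,
    head_dim=64,
    num_heads=8,
    layer_idx=0,
).to(device="cuda", dtype=torch.bfloat16).eval()

# Random input: batch of 2, sequence length 128
x = torch.randn(2, 128, 512, device="cuda", dtype=torch.bfloat16)

with torch.no_grad():
    # q_len > 64, so the chunk kernel is selected even in eval mode
    o, attn, cache = layer(x)

print(o.shape)   # torch.Size([2, 128, 512])
print(attn)      # None: linear attention never materializes attention weights
```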
Summary
- KDA is a strong candidate for next-generation attention: $O(N)$ complexity with competitive quality on long text and much faster inference
- Core design: short convolution (local) + linear attention (global) + gated state-space recurrence (memory)
- Engineering: fused kernels, cache management, automatic mode switching; efficient in both training and inference
- Typical use cases: large models, long documents, low-compute deployment, real-time inference