RWKV-7 架构理解

阅读之前你可以前往 RWKV wiki 了解一些关于 RWKV 的基本知识,不过他们的 wiki 似乎没有对模型架构的详细介绍,于是便有了这篇文章。

RWKV-7 的核心:动态状态演化机制

RWKV-V7 的动态状态演化机制可以通俗理解为 "在线学习上下文关系的动态记忆更新" 。它的核心思想是:通过实时计算和更新一个内部状态(state)来动态捕捉上下文中 key 和 value 的关联关系,并利用这个状态处理当前输入的 query(在 RWKV 中是 r)以生成输出

1. 状态(state)的本质

在 RWKV 中,state 是一个三维张量(B, H, N, N),其中:

  • B:批量大小(batch size)
  • H:注意力头数量(head count)
  • N:每个头的维度(head size)

state 的作用是 维护一个动态的"知识库",记录历史输入的 key(k)和 value(v)之间的关系。它类似于传统 RNN 的隐藏状态,但更复杂,因为它显式建模了 key-value 的交互。

2. 状态更新公式与代码实现

官方公式:

St=St−1⋅diag(wt)+(St−1at⊤bt+vt⊤kt)

在代码中(RWKV7_OP 函数),这个公式被拆解为:
state = state * w + state @ a @ b + v @ k

该公式对应 RWKV-V7 的 动态状态演化机制 ,其核心思想是通过 在线学习 维护一个低秩状态矩阵 S_t,以捕捉上下文中的 key-value 关系。公式分为三部分:

  1. 权重衰减项(遗忘旧信息)

    state * w

    • 作用:控制历史信息的遗忘速率。
    • 数学形式S_t = S_{t-1} ⋅ diag(w_t),其中 w_t 是负值(通过 exp(-exp(...)) 生成)。
    • 思想 :类似 RNN 的遗忘门,w 越小(更负),遗忘越多,确保模型专注于当前上下文。
  2. 学习率调整项(动态修正知识)

    state @ aa @ bb

    • 作用:根据当前输入调整已有知识的权重。
    • 数学形式S_t = S_{t-1} ⋅ a_t^⊤ b_t,其中 ab 是动态学习率参数(a = -k, b = k ⋅ η)。
    • 思想:模拟梯度下降中的学习率乘以梯度方向,使状态更新更适应当前输入的特征分布。
  3. 新信息注入项(更新知识库)

    vv @ kk

    • 作用 :将当前 token 的 key-value 对(k_t, v_t)外积加入状态。
    • 数学形式S_t += v_t^⊤ ⋅ k_t
    • 思想:直接注入新信息,确保模型能实时捕捉上下文中的新关联(如实体关系、语义依赖)

GPT 与 RNN 模式

在进一步了解之前需要知道的是,RWKV 分为两种模式,一种是 GPT 模式,一种是 RNN 模式,区别在于:

GPT 模式rwkv_v7_demo.py

  • 计算方式 :采用 Transformer 的 GPT 模式,一次性处理整个输入序列(类似标准 Transformer 的前向传播)。
  • 状态管理:无需显式维护隐藏状态(state),直接通过注意力机制处理上下文。
  • 适用场景 :适合 **预填充(prefill)**阶段(即处理初始提示词),但自回归生成时效率较低。
  • 性能特点
    • 预填充阶段可以并行计算所有 token 的表示。
    • 自回归生成时需重新计算所有 token 的表示(无状态缓存),导致速度较慢。

RNN 模式rwkv_v7_demo_rnn.py

  • 计算方式 :采用 RNN 模式,逐个 token 处理,显式维护隐藏状态(state)。
  • 状态管理 :通过 state 变量保存每个时间步的中间状态(如 att_kvffn_x_prev 等),避免重复计算。
  • 适用场景 :适合 自回归生成(逐 token 生成),效率更高。
  • 性能特点
    • 预填充阶段效率低(需逐 token 处理)。
    • 生成阶段只需计算当前 token 的表示,速度显著优于 GPT 模式。

此外,官方还提供了结合两种模式的 fast 模式: rwkv_v7_demo_fast.py

GPT 模式

Time mix(RWKV_Tmix_x070)

Time mix 是整个 RWKV-7 的核心,也是最复杂的部分。

RWKV_Tmix_x070 是 RWKV 模型中实现 Time Mixing 的核心模块,用于处理序列数据中的时序依赖关系。它是 RWKV-V7 架构中负责模拟注意力机制和动态状态更新的组件,类似于 Transformer 中的自注意力模块,但采用了更高效的线性复杂度设计。

复制代码
@MyFunction
def forward(self, x, v_first):
    # 获取输入张量的维度:批量大小 B、序列长度 T、特征维度 C
    B, T, C = x.size()
    
    # 获取注意力头数量 H 和每个头的维度 N(固定为 64)
    H = self.n_head
    N = self.head_size

    # 使用 time_shift 操作获取前一个时间步的信息,并与当前输入做差值
    xx = self.time_shift(x) - x  # 得到时间差分项,用于生成不同方向的输入

    # 根据时间差分项和可学习参数,生成不同方向的输入
    xr = x + xx * self.x_r   # receptance 输入方向
    xw = x + xx * self.x_w   # decay weight 输入方向
    xk = x + xx * self.x_k   # key 输入方向
    xv = x + xx * self.x_v   # value 输入方向
    xa = x + xx * self.x_a   # learning rate 输入方向
    xg = x + xx * self.x_g   # gate 输入方向

    # 通过线性层生成 r(receptance),类似于传统 attention 中的 query
    r = self.receptance(xr)

    # 构造 w(权重衰减项):控制 state 的遗忘速率
    # 使用 tanh 和 softplus 确保 w 值在 (-inf, -0.5) 范围内
    w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5

    # 通过线性层生成 key 向量
    k = self.key(xk)

    # 通过线性层生成 value 向量
    v = self.value(xv)

    # 如果是第一层,则保存当前 v 作为 v_first(初始 value)
    if self.layer_id == 0:
        v_first = v
    else:
        # 否则使用门控机制对 v 进行残差连接,保留一部分初始信息
        v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)

    # 生成动态学习率 a,控制 state 更新的学习率(类似在线梯度下降中的学习率)
    a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)

    # 生成门控参数 g,控制最终输出的激活强度
    g = torch.sigmoid(xg @ self.g1) @ self.g2

    # 对 key 进行缩放和 L2 归一化,增强数值稳定性
    kk = k * self.k_k
    kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)

    # 调整 key 的尺度,结合学习率 a 和可学习参数 self.k_a
    k = k * (1 + (a - 1) * self.k_a)

    # 调用 RWKV7_OP 函数进行状态更新和 attention 输出计算
    x = RWKV7_OP(r, w, k, v, -kk, kk * a)

    # 使用 GroupNorm 对输出进行归一化
    x = self.ln_x(x.view(B * T, C)).view(B, T, C)

    # 添加一个额外的残差连接,融合 r, k, v 的乘积项
    x = x + ((r.view(B, T, H, -1) *
              k.view(B, T, H, -1) *
              self.r_k).sum(dim=-1, keepdim=True) *
             v.view(B, T, H, -1)).view(B, T, C)

    # 最终输出经过门控 g 控制,并通过线性层输出
    x = self.output(x * g)

    # 返回输出和 v_first,供下一层使用
    return x, v_first

模块基于以下两个核心机制:

一. 时间差分与方向调整

首先,Time Shift 通过 time_shift(x) - x 得到当前 token 与前一 token 的差异,然后结合可学习参数生成多个"方向"的输入(如 xr, xw, xk, xv, xa, xg),分别用于构建不同的注意力相关参数。

具体来说,Time Shift 的作用是将输入张量在时间维度上向前移动一个单位, 这样在处理当前时间步 xt 时,模型可以获取到到上一个时间步 xt−1 的信息。

举个例子:

复制代码
# 假设原始输入是:
x = [
 [x_0_0, x_0_1, ..., x_0_C],   # 时间步 t=0
 [x_1_0, x_1_1, ..., x_1_C],   # 时间步 t=1
 ...
 [x_T_0, x_T_1, ..., x_T_C]    # 时间步 t=T
]

# Time Shift 后得到的是:
x_shifted = [
 [0, 0, ..., 0],               # 第一个位置用 0 填充(表示没有前一步)
 [x_0_0, x_0_1, ..., x_0_C],   # 第二个位置是 t=0 的值(相当于延迟了一个时间步)
 ...
 [x_{T-1}_0, ..., x_{T-1}_C]   # 最后一个位置是 t=T-1 的值
]

接着利用 Time Shift 的结果 xx 我们就可以计算不同的注意力相关参数。

复制代码
xr = x + xx * self.x_r
xw = x + xx * self.x_w
xk = x + xx * self.x_k
xv = x + xx * self.x_v
xa = x + xx * self.x_a
xg = x + xx * self.x_g

这些新变量 xr, xw, xk, xv, xa, xg 分别用于生成注意力机制中的不同角色:

变量 对应角色 作用
xr receptance (r) 类似 query,在状态更新中起调节作用
xw decayweight (w) 控制遗忘速率
xk key (k) 用于构建上下文关联
xv value (v) 上下文信息载体
xa learning rate (a) 动态调整学习率
xg gate (g) 非线性激活门控

二.线性变换生成核心参数

复制代码
r = self.receptance(xr)
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
k = self.key(xk)
v = self.value(xv)
  • receptance(xr) : 生成 r(类似传统注意力中的 query,但用于控制当前状态如何与新的输入进行交互,决定哪些信息会被"接收"并用于输出计算)。
  • w : 权重衰减参数(控制遗忘速率),类似于 RNN 中的遗忘门。
    • 公式w = -softplus(-(...)) - 0.5,确保 w 为负值。
    • 作用:模拟指数衰减,遗忘旧信息。
  • key(xk) : 生成 k(key 向量,决定新信息如何与当前 state 进行交互)。
  • value(xv) : 生成 v(value 向量,表示输入的"值",包含实际的内容信息)。

三. 残差连接机制

复制代码
if self.layer_id == 0:
    v_first = v # store the v of the first layer
else:
    v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual

一、什么是残差连接?

✅ 简单定义:

残差连接(Residual Connection) 是指将某一层的输入直接加到该层的输出上。

数学表达为:

output=input+F(input)

其中 F(input) 是网络中某个函数或模块对输入的变换。


🎯 二、为什么需要残差连接?

残差连接最早出现在 ResNet中,用于解决深度神经网络中的两个关键问题:

  1. 梯度消失/爆炸
  • 在深层网络中,反向传播时梯度可能会指数级衰减或爆炸。
  • 残差连接可以缓解这个问题,使得梯度更容易流动。
  1. 信息丢失 / 语义漂移
  • 随着层数加深,原始输入的信息可能被逐渐稀释。
  • 残差连接可以让模型在处理过程中始终保留原始输入的信息。
  1. 训练稳定性提升

实验表明,加入残差连接后,深层网络更容易优化和收敛。

在长序列建模中,随着序列变长,中间 token 可能会逐渐丢失最初的上下文信息,通过 v_first 引入初始 token 的 value,可以让后续层仍然保有最初的信息。(v_first - v) 是一个残差项,使得当前值 v 可以根据初始值进行调整,这种设计有助于缓解梯度消失问题,提升训练稳定性。

主要公式可以拆解为三个主要部分:
v = v +
(v_first - v) *
torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)

1. v: 当前时间步的 value 向量

  • 是通过 self.value(xv) 得到的当前 token 的 value 表示。
  • 它表示模型对当前 token 的理解或编码。

2. v_first: 第一个时间步的 value

  • 来自第一个层的第一个 token 的 value(在 layer_id == 0 时被赋值)。
  • 目的是在整个序列处理过程中保留初始信息(类似 RNN 中的"初始状态"或 Transformer 中的 [CLS] token)。

3. (v_first - v): 初始值与当前值之间的差异

  • 这是一个残差项,表示当前 token 与初始 token 的语义差距。
  • 如果当前 token 与初始 token 差异较大,则这个值会更大,可能意味着需要更多的初始信息来修正当前值。

4. torch.sigmoid(...):门控函数

  • 将线性变换的结果映射到 [0, 1] 区间。
  • 控制 (v_first - v) 对当前 v 的影响程度。
  • 类似于 LSTM 中的遗忘门或 GRU 中的更新门。

5. self.v0 + (xv @ self.v1) @ self.v2:可学习的门控参数

  • xv: 经过 time_shift 处理后的输入,用于生成 value 的原始输入。
  • self.v1, self.v2: 两个低秩矩阵,构成 LoRA 结构(Low-Rank Adaptation),用于减少参数量。
  • 整个表达式是一个门控信号,根据当前输入动态决定是否引入初始信息。

四. 动态学习率参数 a 和门控参数 g

复制代码
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
g = torch.sigmoid(xg @ self.g1) @ self.g2

a 是动态学习率参数,用于控制状态(state)更新的强度,使模型能够根据上下文自适应地调整记忆更新的幅度;

g 是门控参数,用于调节最终输出的激活程度,通过可学习的权重决定哪些信息应该被强调或抑制,从而增强模型的表达能力和非线性建模能力。两者共同提升了模型对长序列的建模效率和准确性。

五. 调用 RWKV7_OP 状态更新函数

复制代码
x = RWKV7_OP(r, w, k, v, -kk, kk*a)

六. 组合输出

复制代码
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
x = self.output(x * g)
  • self.ln_x : 对 x 进行分组归一化(GroupNorm),提升训练稳定性。
  • 残差连接 :将 r, k, v 的乘积加入输出。
    • 公式x += (r * k * r_k) · v,增强非线性表达。
  • 门控输出x = self.output(x * g),通过门控参数 g 调制输出。

通道混合 ChannelMix (RWKV_CMix_x070

复制代码
class RWKV_CMix_x070(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # 时间维度的位移操作
        with torch.no_grad():
            self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))  # 可学习的权重参数
        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)  # 第一阶段线性变换
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)  # 第二阶段线性变换

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x) - x  # 计算时间差分
        k = x + xx * self.x_k  # 调整输入特征
        k = torch.relu(self.key(k)) ** 2  # 非线性激活
        return self.value(k)  # 输出变换后的特征

关键组件解析

(1) 时间位移(Time Shift)

  • 使用 nn.ZeroPad2d((0, 0, 1, -1)) 对输入张量进行时间维度的位移(前移一位),生成 time_shift(x)
  • 计算差值 xx = time_shift(x) - x,捕捉序列中相邻位置的局部变化特征。

(2) 输入调整(Input Gating)

  • 引入可学习参数 self.x_k,通过 k = x + xx * self.x_k 调整输入特征,类似门控机制,控制历史信息的保留比例。

(3) 非线性变换

  • 第一阶段 :通过 key 线性层将输入从 n_embd 维度映射到更高维度 dim_ffn(通常为 4 * n_embd)。
  • 激活函数 :使用 ReLU 激活后平方(torch.relu(...) ** 2),增强非线性表达能力。
  • 第二阶段 :通过 value 线性层将特征映射回 n_embd 维度,形成最终输出。

与传统 FFN 的对比

特性 RWKV_CMix_x070 传统 Transformer FFN
激活函数 ReLU(x)^2 GELUSwish
线性层 两层(key+value 两层(linear + activation
时间依赖性 显式引入时间差分(time_shift
参数量 较低(依赖dim_ffn 较高(通常为
相关推荐
weixin_437497777 小时前
读书笔记:Context Engineering 2.0 (上)
人工智能·nlp
喝拿铁写前端7 小时前
前端开发者使用 AI 的能力层级——从表面使用到工程化能力的真正分水岭
前端·人工智能·程序员
goodfat7 小时前
Win11如何关闭自动更新 Win11暂停系统更新的设置方法【教程】
人工智能·禁止windows更新·win11优化工具
北京领雁科技8 小时前
领雁科技反洗钱案例白皮书暨人工智能在反洗钱系统中的深度应用
人工智能·科技·安全
落叶,听雪8 小时前
河南建站系统哪个好
大数据·人工智能·python
清月电子8 小时前
杰理AC109N系列AC1082 AC1074 AC1090 芯片停产替代及资料说明
人工智能·单片机·嵌入式硬件·物联网
Dev7z8 小时前
非线性MPC在自动驾驶路径跟踪与避障控制中的应用及Matlab实现
人工智能·matlab·自动驾驶
七月shi人8 小时前
AI浪潮下,前端路在何方
前端·人工智能·ai编程
橙汁味的风8 小时前
1隐马尔科夫模型HMM与条件随机场CRF
人工智能·深度学习·机器学习
itwangyang5209 小时前
AIDD-人工智能药物设计-AI 制药编码之战:预测癌症反应,选对方法是关键
人工智能