Transformer Attention 加速器

算法原理说明 --- Transformer Attention 加速器

本文档说明加速器实现的数学原理、数值处理策略、以及从算法到硬件流水线的映射关系。

面向读者:需要理解"为什么这么算""硬件为什么这么排"的验证/移植工程师。


1. Attention 数学定义

单头 Self-Attention 的标准公式:

Attention ( Q , K , V ) = softmax  ⁣ ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

其中输入序列 X ∈ R R O W × d m o d e l X \in \mathbb{R}^{ROW \times d_{model}} X∈RROW×dmodel,经三个投影矩阵得到:

Q = X W Q , K = X W K , V = X W V Q = X W_Q,\quad K = X W_K,\quad V = X W_V Q=XWQ,K=XWK,V=XWV

本项目参数:ROW = 512(序列长度),DK = 64(每头维度)。

整条计算链路分五步:

步骤 运算 输出维度
① 投影 Q , K , V = X W Q / K / V Q,K,V = X W_{Q/K/V} Q,K,V=XWQ/K/V 512×64(×3)
② 打分 S = Q K T S = QK^{T} S=QKT 512×512
③ 缩放 S ′ = S ⋅ 1 d k S' = S \cdot \frac{1}{\sqrt{d_k}} S′=S⋅dk 1 512×512
④ 归一 P = softmax ( S ′ ) P = \text{softmax}(S') P=softmax(S′)(按行) 512×512
⑤ 加权 O = P V O = PV O=PV 512×64

2. 为什么用 FP16

  • 动态范围: ± 6.55 × 10 4 \pm 6.55\times10^4 ±6.55×104,指数 5 位覆盖 attention 分数常见范围。
  • 存储/带宽:相比 FP32 减半,512×512 的 Score 矩阵占用减半,直接影响 BRAM/URAM 用量。
  • DSP 友好:11×11 尾数乘法可映射进单个 DSP48,无需拆分。
  • 代价:尾数仅 10 位,累加误差会累积------这正是第 4、5 节要重点处理的地方。

3. 数值稳定 Softmax 原理

3.1 为什么不能直接算 exp

softmax 定义:

p i = e x i ∑ j e x j p_i = \frac{e^{x_i}}{\sum_j e^{x_j}} pi=∑jexjexi

若 x i x_i xi 较大(例如 x i > 11 x_i > 11 xi>11), e x i e^{x_i} exi 在 FP16 下直接溢出为 Inf。因此必须做平移不变变换

p i = e x i − x max ⁡ ∑ j e x j − x max ⁡ , x max ⁡ = max ⁡ j x j p_i = \frac{e^{x_i - x_{\max}}}{\sum_j e^{x_j - x_{\max}}},\quad x_{\max} = \max_j x_j pi=∑jexj−xmaxexi−xmax,xmax=jmaxxj

减去行最大值后,所有指数参数 ≤ 0 \le 0 ≤0, e ( ⋅ ) ∈ ( 0 , 1 ] e^{(\cdot)} \in (0, 1] e(⋅)∈(0,1],永不溢出。这是硬件 softmax 三阶段的第一阶段存在的根本原因。

3.2 三阶段时序

阶段 运算 硬件行为
阶段 1 求 x max ⁡ x_{\max} xmax 逐元素收数并比较,存入 xbuf
阶段 2 e i = exp ⁡ ( x i − x max ⁡ ) e_i = \exp(x_i - x_{\max}) ei=exp(xi−xmax) 并累加 ∑ e i \sum e_i ∑ei PWL LUT 查表 + fp16_add 累加
阶段 3 p i = e i ⋅ 1 ∑ e i p_i = e_i \cdot \frac{1}{\sum e_i} pi=ei⋅∑ei1 fp16_recip 求倒数 + 逐元素乘

3.3 exp 的分段线性(PWL)近似

硬件不做真正的 e x e^x ex,而用分段线性查找表逼近:

exp ⁡ ( x ) ≈ a k ⋅ x + b k , x ∈ [ seg k , seg k + 1 ) \exp(x) \approx a_k \cdot x + b_k,\quad x \in [\text{seg}k, \text{seg}{k+1}) exp(x)≈ak⋅x+bk,x∈[segk,segk+1)

  • 各段斜率 a k a_k ak、截距 b k b_k bk 由 NumPy 对 e x e^x ex 在目标区间(主要是 x ≤ 0 x \le 0 x≤0)拟合。
  • 误差目标: < 2 − 8 < 2^{-8} <2−8(约 0.4%),对 attention 权重排序影响可忽略。
  • 当前状态 :代码中 exp_pwl 是占位常数,系数需按上式标定后填入 case

3.4 用倒数代替除法

FPGA 上浮点除法代价高。归一化改为:

p i = e i × 1 ∑ e i p_i = e_i \times \frac{1}{\sum e_i} pi=ei×∑ei1

倒数用 magic-seed + 一次 Newton 迭代 求解。设要求 y = 1 / x y = 1/x y=1/x,迭代式:

y n + 1 = y n ( 2 − x   y n ) y_{n+1} = y_n(2 - x\,y_n) yn+1=yn(2−xyn)

  • 初值 y 0 y_0 y0 由位模式 magic 常数近似(fp16_recipMAGIC = 0x7A69)。
  • 因 ∑ e i > 0 \sum e_i > 0 ∑ei>0 恒成立,只需处理正数区间,无需通用倒数逻辑。
  • 一次迭代即可将相对误差压到 FP16 尾数精度量级。

4. 点积的 loop-carried 依赖问题

4.1 问题本质

Q K T QK^T QKT 和 P V PV PV 都是长度 64 的点积累加:

s = ∑ k = 0 63 a k b k s = \sum_{k=0}^{63} a_k b_k s=k=0∑63akbk

朴素写法是单累加器串行相加:acc = acc + a_k*b_k。但 FP16 加法器有 3 拍流水延迟acc 的新值 3 拍后才可用,导致下一次加法必须等待------每 3 拍才能吃一个数,吞吐骤降。这就是 loop-carried dependency(循环携带依赖)

4.2 交织累加 + 树形归约

解法是开 4 路独立累加器 ,按 lane = 0,1,2,3 轮流分配乘积:

acc l = ∑ k   m o d   4 = l a k b k , l = 0 , 1 , 2 , 3 \text{acc}l = \sum_{k \bmod 4 = l} a_k b_k,\quad l = 0,1,2,3 accl=kmod4=l∑akbk,l=0,1,2,3

  • 每路累加器相邻两次加法间隔 4 拍 > 3 拍延迟,依赖被并行度吸收,可每拍喂一个新数
  • 全部 64 个乘积处理完后,做 3 次树形加归约为标量:

s = ( acc 0 + acc 1 ) + ( acc 2 + acc 3 ) s = (\text{acc}0+\text{acc}1) + (\text{acc}2+\text{acc}3) s=(acc0+acc1)+(acc2+acc3)

浮点加法不满足结合律,交织+树形归约的舍入路径与串行不同,但误差量级相当且可预测------这一点需在单元 TB 里与 NumPy 参考值比对确认。

5. FP16 加法器规格化原理(易错核心)

FP16 加法的三个阶段及其数学依据:

阶段 1:对阶(Exponent Alignment)

两操作数指数不同时,小指数尾数右移对齐到大指数:

shift = ∣ e a − e b ∣ \text{shift} = |e_a - e_b| shift=∣ea−eb∣

右移丢出的位要用 guard / round / sticky 三位保留,否则舍入误差偏大。sticky 位是所有移出位的"或"。

阶段 2:有符号加减

同号相加、异号相减;异号时以绝对值大者符号为结果符号。

阶段 3:规格化(Normalization,前几版硬伤所在)

相减可能产生大量前导零(浮点相消),需:

  1. 前导零计数(LZC) :数出最高有效位前的 0 的个数 lz
  2. 在完整位宽上左移 lz 位,同时指数减 lz

e result = e max ⁡ − l z + 1 e_{\text{result}} = e_{\max} - lz + 1 eresult=emax−lz+1

⚠️ 关键教训:LZC 移位必须作用在未截断的完整尾数 上,先规格化再取高 10 位。若先截断到 10 位再左移,有效位会被移出丢失------这是前几版计算错误的根因。同时 LZC 的输入必须与它要规格化的数据同拍,否则错拍。

6. 算法到硬件流水线的映射

算法运算 硬件模块 延迟(拍)
乘法 a × b a \times b a×b fp16_mult 2
加法 a + b a + b a+b fp16_add 3
倒数 1 / x 1/x 1/x fp16_recip(mult→add→mult) 7
点积 ∑ a b \sum ab ∑ab fp16_dot(4 路交织 + 3 级树) 交织无阻塞
逐行 softmax softmax_fp16(3 阶段状态机) 按行

流水线对齐原则:所有控制信号(清零、valid)必须按对应算术延迟打拍,才能保证"最后一个乘积到达累加器时清零信号恰好释放",否则丢数或多累加。这是《设计规范检查SKILL.md》流水线检查项的重点。

7. 数值误差预算

误差源 量级 处理
FP16 尾数舍入 2 − 11 2^{-11} 2−11/次 guard/round/sticky 就近舍入
exp PWL 近似 < 2 − 8 < 2^{-8} <2−8 NumPy 拟合分段系数
倒数 Newton 1 次 ~尾数精度 需实测校准 magic 种子
交织累加结合律差异 与串行相当 单元 TB 对拍确认

端到端建议验收阈值:输出逐元素 FP16 相对误差 < 2 − 8 < 2^{-8} <2−8。

8. 待标定项(诚实清单)

以下算法参数无法凭空写对,必须结合数据实测

  1. exp_pwl 分段系数 --- 用 NumPy 拟合后填入。
  2. fp16_recip magic 种子 --- 按 ∑ e i \sum e_i ∑ei 数值范围校准 Newton 收敛。
  3. 交织累加/softmax 各阶段 valid 计数对齐 --- 波形逐拍确认。

小结:本加速器的算法核心是三件事------用"减最大值"保证 exp 数值稳定、用"交织累加"破除浮点 loop-carried 依赖、用"完整位宽 LZC 规格化"保证加法正确。三者对了,attention 的数值结果才能对齐软件参考。