ops-softmax:Transformer 推理中的概率归一化引擎

Transformer 里 Attention 的核心是 Softmax------它把注意力分数变成概率分布。没有 Softmax,注意力分数就只是一个数值没有归一化的矩阵,无法作为权重来聚合 Value。

CANN 的 ops-softmax 仓库专门管理 Softmax 及其变体的实现。Softmax 的计算量不大------就是 exp → sum → div 三步------但它的数据访问模式决定了它是 Memory Bound 算子,在昇腾NPU 上需要针对大序列长度做专门的优化。


Softmax 为什么是 Transformer 核心

Attention 的计算公式:Attention(Q,K,V) = Softmax(Q × K^T / √d) × V

Q × K^T 输出的注意力分数矩阵 S[n, n] 的矩阵。矩阵中的每个元素 S_ij 表示第 i 个 Token 对第 j 个 Token 的注意力强度。但这些分数是未归一化的------可能很大也可能很小。Softmax 把它们归一化成概率分布,让 sum(S_ij over j) = 1

Softmax 的步骤:

  1. exp(x_i)------指数化,把分数转为正数
  2. sum(exp(x_i))------求所有指数值的和
  3. exp(x_i) / sum------每个指数值除以总和,归一化为概率

Softmax 为什么会成为性能瓶颈

Softmax 的计算量很小------每个元素一次指数运算、一次除法。但它的数据访问模式很差:

  • 输入读取 [n, n] 矩阵的全部元素(从 DDR 搬到 L1)
  • 对所有元素做指数运算(Vector Unit 执行)
  • 在行方向做 sum(归约操作,需要对整行扫描)
  • 再读取一次,每个元素除以 sum(从 DDR 搬到 L1)

对于 n=4096 的序列,Score 矩阵 32MB。整个流程需要搬运约 64MB------两次读 S、一次写 S_softmax。计算/搬运比很低。

FlashAttention 中的 Softmax 优化

FlashAttention 对 Softmax 的优化是让它原地完成------Score 矩阵不落地 DDR。具体做法:Score 矩阵被切成 block×block 的子块,每次只搬运一个子块到 L1。在 L1 上做完 Softmax 后立即跟 Value 做矩阵乘,Softmax 的结果不需要写回 DDR。

这个过程需要 Online Softmax 算法------在不知道全局最大值的情况下分块计算:

复制代码
初始化:max_val = -inf, sum_val = 0
循环每个 K/V 块:
  当前块的最大值 local_max = max(S_ij)
  更新 max_val = max(max_val, local_max)
  缩放旧的 sum_val:sum_val *= exp(max_val - local_max)
  当前块的 exp 和:local_sum = sum(exp(S_ij - max_val))
  累积:sum_val += local_sum

Online Softmax 的计算精度跟标准 Softmax 完全一致,但避免了 Score 矩阵的整体搬运。在长序列场景中,Softmax 不再是性能瓶颈。

Online Softmax 的数值稳定性

Softmax 的朴素实现:exp(x_i)x_i 很大时(如 Attention Score 的值可能超过 30)会导致 float16 溢出。标准做法是减去最大值:exp(x_i - max(x)) / sum(exp(x_j - max(x)))

FlashAttention 的 Online Softmax 在分块计算时也保持了数值稳定性------每个分块独立减去自己的局部最大值,跨分块时用 running max 修正。这个修正的数值误差在 10^-5 级别------不影响推理精度。

ops-softmax 在 Vector Unit 上的实现

ops-softmax 在 Vector Unit 上的实现不是直接写一条 softmax 指令------Vector Unit 只有基本的数学指令。Softmax 被拆解为:

  1. vec_max(x) --- SIMD 找最大值
  2. vec_sub(x, max) --- 每个元素减最大值
  3. vec_exp(x) --- SIMD 指数运算(使用多项式近似)
  4. vec_sum(exp_x) --- SIMD 求和
  5. vec_div(exp_x, sum) --- 每个元素除以总和

这 5 条 Vector 指令在 L1 上执行,不需要写 DDR。对于 4096 个元素的 Softmax,Vector Unit 的执行时间约 1-2μs。

大序列长度(n > 4096)时,Score 矩阵 [n, n] 超出了一次 Kernel 可以处理的 L1 容量。ops-softmax 把 Score 矩阵按行分成多块------每块在 L1 上做完完整的 Softmax 后再写回 DDR。

参考仓库

ops-softmax 仓库

FlashAttention 融合优化

相关推荐
哦哦~921几秒前
AI赋能生物医学:从临床数据到药物分子性质预测实战培
人工智能·生物医学·药物分子
GIS数据转换器3 分钟前
城市排水生命线安全运行监测平台深度解析
java·运维·人工智能·python·安全·数据挖掘·无人机
虫无涯6 分钟前
本地离线大模型实战:Ollama + Llama 3.1 8B 全流程部署(适配VSCode Continue代码助手)
人工智能
Rocky Ding*21 分钟前
Latent Consistency Models:一篇读懂扩散模型的少步生成核心基础知识
人工智能·深度学习·机器学习·ai作画·stable diffusion·aigc·ai-native
大山佬23 分钟前
AI 边缘部署:MCU 上的轻量级目标检测,从 YOLO 到 TFLite Micro 的全链路优化
人工智能
数睿数据无代码开发24 分钟前
深度解析smardaten数据大屏:六大核心功能重塑可视化开发
人工智能·信息可视化
陈猪的杰咪24 分钟前
GitHub Copilot 2026计费新规:AI Credits消耗解析与节省策略
人工智能·ai·架构·github·copilot
学术头条33 分钟前
清华团队开源SCAIL-2:角色动画告别骨骼依赖,端到端还原视频中动作细节
人工智能·科技·机器学习·ai·开源·音视频·agi
لا معنى له33 分钟前
世界模型的功能分类法——Renderers, Simulators, Planners, and the Loop That Connects Them
人工智能
华如锦41 分钟前
面了很多 Java转AI Agent方向,一些面试题总结
java·开发语言·人工智能·python·ai