你以为 GE 做完融合决策,交给 Runtime 执行就行了?其实它们是一个协同系统------GE 决定"融什么",Runtime 决定"怎么跑",但 GE 的融合决策必须考虑 Runtime 的调度约束,Runtime 的调度策略也必须参考 GE 的融合结果。
这一篇把 GE 和 Runtime 的协同工作机制拆开来,说说四个被误解的设计决策。
GE 不是"图优化器",是"融合决策引擎"
很多人以为 GE(Graph Engine)就是做图优化的------算子融合、内存优化、计算图重写。这些都没错,但不是 GE 的核心。
GE 的核心是融合决策------根据算子的 shape、dtype、tiling 参数,决定哪些算子可以融合、以什么顺序融合、融合后的算子怎么调度。这个决策过程不是固定的,是可学习的------你可以读 GE 的融合规则,甚至写自定义的融合 pass。
python
# GE 的融合决策过程(简化版)
# 来源:ge/frontend/fusion_pass/flash_attention_fusion_pass.cc
# 决策1:输入 dtype 检查(必须是 float16)
# if (input_dtype != DT_FLOAT16) return false;
# 决策2:seq_len 检查(必须是 2 的幂次方)
# int seq_len = input_shape[2];
# if ((seq_len & (seq_len - 1)) != 0) return false;
# 决策3:Q、K、V 的 seq_len 必须相同
# if (q_shape[2] != k_shape[2] || q_shape[2] != v_shape[2]) return false;
# 决策4:必须开启 causal mask(训练场景)
# if (!is_causal) return false;
# 验证:查看 GE 的融合决策日志
import os
os.environ["ASCEND_GLOBAL_LOG_LEVEL"] = "3"
os.environ["GE_LOG_TO_STDOUT"] = "1"
import torch
Q = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
K = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
V = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
# 在日志输出中搜索 "FlashAttentionFusionPass"
# 如果看到 "FlashAttentionFusionPass: success",说明 GE 的融合决策成功了
误解 :GE 是图优化器,做的事情就是算子融合、内存优化。
纠正:GE 的核心是融合决策------根据算子的 shape/dtype/tiling,决定哪些算子可以融合、以什么顺序融合。这个决策过程是可学习的。
Runtime 不是"任务调度器",是"overlap 协调器"
很多人以为 Runtime 就是调度算子执行的------哪个算子先执行、哪个后执行、哪些可以并行。这些都没错,但不是 Runtime 的核心。
Runtime 的核心是 overlap------让数据搬运和计算重叠起来。具体来说:当前 tile 的计算在进行的时候,Runtime 已经把下一个 tile 的数据从 HBM 搬到 UB 上了。这样计算单元就不会停下来等数据。
python
# Runtime 的 overlap 机制(简化版)
# 来源:runtime/core/mem_manager/overlap_manager.cc
# overlap 的核心逻辑:
# 1. 把算子按 tile 切分
# 2. 当前 tile 在计算的时候,预取下一个 tile 的数据
# 3. 计算完当前 tile,立刻开始计算下一个 tile(数据已经就位)
# 验证:用 Profiler 抓 trace,看计算 kernel 和数据搬运 kernel 的时间轴
from torch_npu.profiler import profile, ProfilerActivity
Q = torch.randn(4, 32, 4096, 64, dtype=torch.float16).npu()
K = torch.randn(4, 32, 4096, 64, dtype=torch.float16).npu()
V = torch.randn(4, 32, 4096, 64, dtype=torch.float16).npu()
with profile(activities=[ProfilerActivity.NPU], export_name="runtime_overlap.json"):
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
# 分析 runtime_overlap.json:
# - 如果计算 kernel(FlashAttentionKernel)和数据搬运 kernel(MemcpyH2D)有重叠
# → Runtime 的 overlap 生效了 ✅
# - 如果计算 kernel 和数据搬运 kernel 完全串行
# → Runtime 的 overlap 未生效 ❌
# 对比 overlap 开启/关闭的性能
os.environ["ASCEND_OVERLAP_DISABLE"] = "1" # 关闭 overlap
torch.npu.synchronize()
start = time.time()
for _ in range(50):
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
end = time.time()
print(f"overlap 关闭后 50 次耗时: {end-start:.2f}s")
os.environ["ASCEND_OVERLAP_DISABLE"] = "0" # 开启 overlap
torch.npu.synchronize()
start = time.time()
for _ in range(50):
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
end = time.time()
print(f"overlap 开启后 50 次耗时: {end-start:.2f}s")
误解 :Runtime 是任务调度器,决定算子执行顺序。
纠正:Runtime 的核心是 overlap------让数据搬运和计算并行,减少计算单元等待数据的时间。
GE 和 Runtime 不是"上下游",是"协同决策"
很多人以为 GE 做完融合决策,生成融合后的算子,交给 Runtime 执行就行了。这个理解太浅了。
GE 的融合决策必须考虑 Runtime 的调度约束。比如:如果一个算子融合后太大(tile 数太多),Runtime 可能无法有效地做 overlap(因为内存不够存两个 tile 的数据)。GE 在决策融合的时候,必须参考 Runtime 的 overlap 可行性。
反过来,Runtime 的调度策略也必须参考 GE 的融合结果。比如:融合后的算子更适合 tile 级 pipeline(因为一个融合算子内部可以切分 tile),Runtime 会根据融合算子的特性调整调度策略。
python
# GE 和 Runtime 的协同决策(简化版)
# 场景:GE 在决策是否为 FlashAttention 做融合时,会参考 Runtime 的 overlap 可行性
# GE 的考虑:
# 1. 融合后的 FlashAttentionKernel 有多大?(tile 数 × 每个 tile 的 UB 占用)
# 2. Runtime 能不能有效地做 overlap?(UB 够不够存两个 tile 的数据)
# 3. 如果 UB 不够,GE 可能不会触发融合,或者选择一个更小的 tile 大小
# Runtime 的考虑:
# 1. 这个算子是不是融合算子?(融合算子更适合 tile 级 pipeline)
# 2. 融合算子的 tile 大小是多少?(决定预取策略)
# 3. 根据融合算子的特性,调整调度策略(比如更多的 pipeline 级)
# 验证:对比不同 tile 大小下,GE 是否触发融合 + Runtime 的 overlap 效率
import torch
import time
for tile_size in [64, 128, 256, 512]:
Q = torch.randn(4, 32, 4096, tile_size, dtype=torch.float16).npu()
K = torch.randn(4, 32, 4096, tile_size, dtype=torch.float16).npu()
V = torch.randn(4, 32, 4096, tile_size, dtype=torch.float16).npu()
# 查看 GE 日志,看这个 tile_size 下是否触发了 FlashAttentionFusion
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
# 计时
start = time.time()
for _ in range(100):
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
end = time.time()
print(f"tile_size={tile_size}, 100次耗时: {end-start:.2f}s")
# 用 npu-smi 查看 Compute Cube 利用率
# 如果利用率 > 80%,说明 Runtime 的 overlap 做得好(计算单元没怎么等数据)
# 如果利用率 < 50%,说明计算单元经常在等数据(overlap 没生效或 tile 大小不合适)
误解 :GE 和 Runtime 是上下游关系------GE 做完决策,Runtime 执行就行。
纠正:GE 和 Runtime 是协同决策关系------GE 的融合决策要考虑 Runtime 的调度约束,Runtime 的调度策略要参考 GE 的融合结果。
ops-transformer 不是"被动适应",是"主动配合"
很多人以为 ops-transformer 的算子只要实现功能就行,GE 和 Runtime 会自动优化。这个理解是错的。
ops-transformer 的算子设计必须主动配合 GE 的融合规则和 Runtime 的调度策略。具体来说:
- 暴露 tiling 参数:让 GE 在决策融合的时候,知道这个算子支持什么样的 tile 大小
- 支持 causal mask:让 GE 在匹配 FlashAttentionFusionPass 的时候,知道这个算子支持 causal mask
- 优化 UB 使用:让 Runtime 在做 overlap 的时候,有足够的内存预取下一个 tile 的数据
python
# ops-transformer 的算子设计(主动配合 GE 和 Runtime)
# 来源:ops-transformer/src/ops_transformer/flash_attention/flash_attention_kernel.cpp
# 主动配合1:暴露 tiling 参数(让 GE 知道这个算子支持什么 tile 大小)
# void FlashAttentionKernel(..., int tiling) { ... }
# 主动配合2:支持 causal mask(让 GE 匹配 FlashAttentionFusionPass)
# if (causal) { ... // 在 softmax 之前把 mask 位置设为 -inf ... }
# 主动配合3:优化 UB 使用(让 Runtime 有足够的空间做 overlap)
# - 每个 tile 的 UB 占用尽量小
# - 预留足够的 UB 空间给下一个 tile 的数据预取
# 验证:读 ops-transformer 的源码,看它是怎么主动配合 GE 和 Runtime 的
import subprocess
# 查看 tiling 参数的定义
result = subprocess.run(
["grep", "-r", "tiling", "ops-transformer/src/ops_transformer/flash_attention/"],
capture_output=True,
text=True
)
print("tiling 参数定义:")
print(result.stdout)
# 查看 causal mask 的实现
result = subprocess.run(
["grep", "-r", "causal", "ops-transformer/src/ops_transformer/flash_attention/"],
capture_output=True,
text=True
)
print("causal mask 实现:")
print(result.stdout)
# 查看 UB 使用的优化
result = subprocess.run(
["grep", "-r", "UB", "ops-transformer/src/ops_transformer/flash_attention/"],
capture_output=True,
text=True
)
print("UB 使用优化:")
print(result.stdout)
误解 :ops-transformer 的算子只要实现功能就行,GE 和 Runtime 会自动优化。
纠正:ops-transformer 的算子设计必须主动配合 GE 和 Runtime------暴露 tiling 参数、支持 causal mask、优化 UB 使用。
相关仓库: