本文基于昇腾CANN和昇腾NPU,围绕 Cube MatMul 矩阵乘法技术展开。
想象你在一个巨大的停车场里搬箱子。方案 A:一次搬一个箱子,走 100 趟------这是 Vector 的做法。方案 B:用叉车一次叉起 16×16 个箱子,一趟搞定------这是 Cube 的做法。
AI 计算的核心是矩阵乘法------95% 以上的浮点运算都花在 GEMM 上。一个 Attention 层里的 Q×K^T、Attention×V、FFN 的两次 MatMul------全是 GEMM。硬件优化的第一优先级就是让 GEMM 跑得尽可能快。这就是为什么达芬奇 Core 里专门塞了一个 Cube Unit------它生来只干一件事:16×16×16 的 FP16 乘累加。
Cube Unit 怎么算 GEMM
Cube Unit 每 cycle 做一次 C[16,16] += A[16,16] × B[16,16] 的矩阵乘累加。用 FP16 计算,FP32 累加------输入低精度省内存和带宽,中间累积用高精度保留数值稳定性。
一次 Attention 的 Q×K^T 计算:Q 是 [B, H, S, D],K^T 是 [B, H, D, S]。Cube Unit 把 M=S、K=D、N=S 的 GEMM 切成一堆 16×16 的小块,一块一块算。每个小块独立在 Cube Unit 上完成,结果累积到 C 矩阵。
GEMM C[M,N] = A[M,K] × B[K,N] 在 Cube Unit 上的分解:
┌───────────────┐ ┌───────────┐ ┌───────────────┐
│ A (M×K) │ │ B (K×N) │ │ C (M×N) │
│ │ × │ │ = │ │
│ 切成 M_TILE× │ │ 切成 K_TILE│ │ 切成 M_TILE× │
│ K_TILE 小块 │ │ ×N_TILE │ │ N_TILE 小块 │
└───────────────┘ └───────────┘ └───────────────┘
每个 M_TILE×N_TILE 的小块 C:
for k_step in range(K / K_TILE):
C_tile += A_tile[k_step] @ B_tile[k_step] // 16×16×16 的 MAC
Tile 分块为什么是核心
Cube Unit 一次只算 16×16。一个 4096×4096 的 GEMM 要分解成 (4096/16)² = 65536 个小块------但不需要全存在 L1 上同时算。
Tiling 的关键在 K 维度上的循环累积。A 和 B 沿着 K 维度切段,每次载入一段到 L1、Cube 算完这段的乘累加、结果加到 C 上------然后载入下一段。C 一直在 L1 上不动,A 和 B 轮流上场。
这样 L1 只需要装下:一段 A [M_TILE × K_TILE]、一段 B [K_TILE × N_TILE]、和 C [M_TILE × N_TILE]。三个块加起来几十到一百多 KB------刚好塞进 192KB 的 L1。
昇腾NPU的 GEMM 融合
单纯的 MatMul 后面往往跟着 Bias Add、Activation、Residual Add。CANN 的算子库把这些操作融进 MatMul------不是"先算 GEMM 再调 Activation",而是在 GEMM 的最后一个 K-step 还没结束时,Vector Unit 就已经开始处理前面累积好的 C 元素做 Activation。
融合 MatMul + Bias + GELU 的执行流:
Cube Unit: [GEMM K-step 0] [GEMM K-step 1] ... [GEMM 最后 K-step]
Vector Unit: [Bias Add] [GELU]
Scalar Unit: [地址计算] [循环控制] ... [地址计算] [循环控制]
三个单元在不同 K-step 上并行------融合的 GEMM 比"先 GEMM 再 Activation"
快约 30%,因为省掉了 GEMM 输出写回 DDR 和 Activation 从 DDR 读入。
cpp
// Ascend C 里的融合 MatMul 模板------简化版
class FusedMatMulGELU : public AscendC::Kernel {
__aicore__ void Process() override {
// Tiling 参数
constexpr int M_TILE = 128, N_TILE = 256, K_TILE = 32;
LocalTensor<fp16> a_tile, b_tile, c_tile;
LocalAlloc(a_tile, M_TILE * K_TILE);
LocalAlloc(b_tile, K_TILE * N_TILE);
LocalAlloc(c_tile, M_TILE * N_TILE);
for (int m = 0; m < M; m += M_TILE) {
for (int n = 0; n < N; n += N_TILE) {
SetZero(c_tile);
for (int k = 0; k < K; k += K_TILE) {
DataCopy(a_tile, gm_a[m][k], M_TILE * K_TILE);
DataCopy(b_tile, gm_b[k][n], K_TILE * N_TILE);
MatMul(c_tile, a_tile, b_tile, M_TILE, N_TILE, K_TILE);
// 累加:C += A @ B(Cube Unit 原生支持 FP32 累加)
}
// GEMM 完成后,Vector Unit 直接对 C 做 GELU------不写回 DDR
GELU(c_tile, c_tile, M_TILE * N_TILE);
DataCopy(gm_c[m][n], c_tile, M_TILE * N_TILE);
}}
}
};