Cube MatMul:为什么矩阵乘法选了 Cube 而不是 Vector

本文基于昇腾CANN和昇腾NPU,围绕 Cube MatMul 矩阵乘法技术展开。

想象你在一个巨大的停车场里搬箱子。方案 A:一次搬一个箱子,走 100 趟------这是 Vector 的做法。方案 B:用叉车一次叉起 16×16 个箱子,一趟搞定------这是 Cube 的做法。

AI 计算的核心是矩阵乘法------95% 以上的浮点运算都花在 GEMM 上。一个 Attention 层里的 Q×K^T、Attention×V、FFN 的两次 MatMul------全是 GEMM。硬件优化的第一优先级就是让 GEMM 跑得尽可能快。这就是为什么达芬奇 Core 里专门塞了一个 Cube Unit------它生来只干一件事:16×16×16 的 FP16 乘累加。


Cube Unit 怎么算 GEMM

Cube Unit 每 cycle 做一次 C[16,16] += A[16,16] × B[16,16] 的矩阵乘累加。用 FP16 计算,FP32 累加------输入低精度省内存和带宽,中间累积用高精度保留数值稳定性。

一次 Attention 的 Q×K^T 计算:Q 是 [B, H, S, D],K^T 是 [B, H, D, S]。Cube Unit 把 M=S、K=D、N=S 的 GEMM 切成一堆 16×16 的小块,一块一块算。每个小块独立在 Cube Unit 上完成,结果累积到 C 矩阵。

复制代码
GEMM C[M,N] = A[M,K] × B[K,N] 在 Cube Unit 上的分解:

┌───────────────┐     ┌───────────┐     ┌───────────────┐
│   A (M×K)     │     │  B (K×N)  │     │   C (M×N)     │
│               │  ×  │           │  =  │               │
│ 切成 M_TILE×  │     │ 切成 K_TILE│     │ 切成 M_TILE×  │
│ K_TILE 小块   │     │ ×N_TILE   │     │ N_TILE 小块   │
└───────────────┘     └───────────┘     └───────────────┘

每个 M_TILE×N_TILE 的小块 C:
  for k_step in range(K / K_TILE):
    C_tile += A_tile[k_step] @ B_tile[k_step]  // 16×16×16 的 MAC

Tile 分块为什么是核心

Cube Unit 一次只算 16×16。一个 4096×4096 的 GEMM 要分解成 (4096/16)² = 65536 个小块------但不需要全存在 L1 上同时算。

Tiling 的关键在 K 维度上的循环累积。A 和 B 沿着 K 维度切段,每次载入一段到 L1、Cube 算完这段的乘累加、结果加到 C 上------然后载入下一段。C 一直在 L1 上不动,A 和 B 轮流上场。

这样 L1 只需要装下:一段 A [M_TILE × K_TILE]、一段 B [K_TILE × N_TILE]、和 C [M_TILE × N_TILE]。三个块加起来几十到一百多 KB------刚好塞进 192KB 的 L1。


昇腾NPU的 GEMM 融合

单纯的 MatMul 后面往往跟着 Bias Add、Activation、Residual Add。CANN 的算子库把这些操作融进 MatMul------不是"先算 GEMM 再调 Activation",而是在 GEMM 的最后一个 K-step 还没结束时,Vector Unit 就已经开始处理前面累积好的 C 元素做 Activation。

复制代码
融合 MatMul + Bias + GELU 的执行流:

Cube Unit:    [GEMM K-step 0] [GEMM K-step 1] ... [GEMM 最后 K-step]
Vector Unit:                                    [Bias Add] [GELU]
Scalar Unit:  [地址计算] [循环控制] ... [地址计算] [循环控制]

三个单元在不同 K-step 上并行------融合的 GEMM 比"先 GEMM 再 Activation"
快约 30%,因为省掉了 GEMM 输出写回 DDR 和 Activation 从 DDR 读入。
cpp 复制代码
// Ascend C 里的融合 MatMul 模板------简化版

class FusedMatMulGELU : public AscendC::Kernel {
    __aicore__ void Process() override {
        // Tiling 参数
        constexpr int M_TILE = 128, N_TILE = 256, K_TILE = 32;

        LocalTensor<fp16> a_tile, b_tile, c_tile;
        LocalAlloc(a_tile, M_TILE * K_TILE);
        LocalAlloc(b_tile, K_TILE * N_TILE);
        LocalAlloc(c_tile, M_TILE * N_TILE);

        for (int m = 0; m < M; m += M_TILE) {
         for (int n = 0; n < N; n += N_TILE) {
            SetZero(c_tile);
            for (int k = 0; k < K; k += K_TILE) {
                DataCopy(a_tile, gm_a[m][k], M_TILE * K_TILE);
                DataCopy(b_tile, gm_b[k][n], K_TILE * N_TILE);
                MatMul(c_tile, a_tile, b_tile, M_TILE, N_TILE, K_TILE);
                // 累加:C += A @ B(Cube Unit 原生支持 FP32 累加)
            }
            // GEMM 完成后,Vector Unit 直接对 C 做 GELU------不写回 DDR
            GELU(c_tile, c_tile, M_TILE * N_TILE);
            DataCopy(gm_c[m][n], c_tile, M_TILE * N_TILE);
         }}
    }
};

参考仓库

ops-blas 高性能 GEMM

ops-nn 神经网络算子

catlass 算子模板库

CANN 学习中心

相关推荐
weixin_448946637 小时前
安裝Hermes
python
都在酒里7 小时前
STM32矩阵按键详解——4×4行列扫描与非阻塞消抖(硬件总结六)
stm32·嵌入式硬件·矩阵
hef2887 小时前
SQL和Python怎么选?数据分析工具实战指南
python·sql·数据分析
徐安安ye7 小时前
FlashAttention长程依赖建模:局部+全局的Hybrid Spiral结构设计
python·深度学习·机器学习
IT策士8 小时前
Django 从 0 到 1 打造完整电商平台:商品排序与浏览量统计
后端·python·django
godspeed_lucip8 小时前
LLM和Agent——专题3: Agentic Workflow 入门(4)
人工智能·python
godspeed_lucip8 小时前
LLM和Agent——专题3: Agentic Workflow 入门(2)
网络·人工智能·python
mingshili8 小时前
[Python] Python中自带模块级的单例模式-不需要定义单例类
python·单例模式
alphaTao8 小时前
LeetCode 每日一题 2026/5/18-2026/5/24
python·leetcode