Cube MatMul:为什么矩阵乘法选了 Cube 而不是 Vector

本文基于昇腾CANN和昇腾NPU,围绕 Cube MatMul 矩阵乘法技术展开。

想象你在一个巨大的停车场里搬箱子。方案 A:一次搬一个箱子,走 100 趟------这是 Vector 的做法。方案 B:用叉车一次叉起 16×16 个箱子,一趟搞定------这是 Cube 的做法。

AI 计算的核心是矩阵乘法------95% 以上的浮点运算都花在 GEMM 上。一个 Attention 层里的 Q×K^T、Attention×V、FFN 的两次 MatMul------全是 GEMM。硬件优化的第一优先级就是让 GEMM 跑得尽可能快。这就是为什么达芬奇 Core 里专门塞了一个 Cube Unit------它生来只干一件事:16×16×16 的 FP16 乘累加。


Cube Unit 怎么算 GEMM

Cube Unit 每 cycle 做一次 C16,16 += A16,16 × B16,16 的矩阵乘累加。用 FP16 计算,FP32 累加------输入低精度省内存和带宽,中间累积用高精度保留数值稳定性。

一次 Attention 的 Q×K^T 计算:Q 是 B, H, S, D,K^T 是 B, H, D, S。Cube Unit 把 M=S、K=D、N=S 的 GEMM 切成一堆 16×16 的小块,一块一块算。每个小块独立在 Cube Unit 上完成,结果累积到 C 矩阵。

复制代码
GEMM C[M,N] = A[M,K] × B[K,N] 在 Cube Unit 上的分解:

┌───────────────┐     ┌───────────┐     ┌───────────────┐
│   A (M×K)     │     │  B (K×N)  │     │   C (M×N)     │
│               │  ×  │           │  =  │               │
│ 切成 M_TILE×  │     │ 切成 K_TILE│     │ 切成 M_TILE×  │
│ K_TILE 小块   │     │ ×N_TILE   │     │ N_TILE 小块   │
└───────────────┘     └───────────┘     └───────────────┘

每个 M_TILE×N_TILE 的小块 C:
  for k_step in range(K / K_TILE):
    C_tile += A_tile[k_step] @ B_tile[k_step]  // 16×16×16 的 MAC

Tile 分块为什么是核心

Cube Unit 一次只算 16×16。一个 4096×4096 的 GEMM 要分解成 (4096/16)² = 65536 个小块------但不需要全存在 L1 上同时算。

Tiling 的关键在 K 维度上的循环累积。A 和 B 沿着 K 维度切段,每次载入一段到 L1、Cube 算完这段的乘累加、结果加到 C 上------然后载入下一段。C 一直在 L1 上不动,A 和 B 轮流上场。

这样 L1 只需要装下:一段 A M_TILE × K_TILE、一段 B K_TILE × N_TILE、和 C M_TILE × N_TILE。三个块加起来几十到一百多 KB------刚好塞进 192KB 的 L1。


昇腾NPU的 GEMM 融合

单纯的 MatMul 后面往往跟着 Bias Add、Activation、Residual Add。CANN 的算子库把这些操作融进 MatMul------不是"先算 GEMM 再调 Activation",而是在 GEMM 的最后一个 K-step 还没结束时,Vector Unit 就已经开始处理前面累积好的 C 元素做 Activation。

复制代码
融合 MatMul + Bias + GELU 的执行流:

Cube Unit:    [GEMM K-step 0] [GEMM K-step 1] ... [GEMM 最后 K-step]
Vector Unit:                                    [Bias Add] [GELU]
Scalar Unit:  [地址计算] [循环控制] ... [地址计算] [循环控制]

三个单元在不同 K-step 上并行------融合的 GEMM 比"先 GEMM 再 Activation"
快约 30%,因为省掉了 GEMM 输出写回 DDR 和 Activation 从 DDR 读入。
cpp 复制代码
// Ascend C 里的融合 MatMul 模板------简化版

class FusedMatMulGELU : public AscendC::Kernel {
    __aicore__ void Process() override {
        // Tiling 参数
        constexpr int M_TILE = 128, N_TILE = 256, K_TILE = 32;

        LocalTensor<fp16> a_tile, b_tile, c_tile;
        LocalAlloc(a_tile, M_TILE * K_TILE);
        LocalAlloc(b_tile, K_TILE * N_TILE);
        LocalAlloc(c_tile, M_TILE * N_TILE);

        for (int m = 0; m < M; m += M_TILE) {
         for (int n = 0; n < N; n += N_TILE) {
            SetZero(c_tile);
            for (int k = 0; k < K; k += K_TILE) {
                DataCopy(a_tile, gm_a[m][k], M_TILE * K_TILE);
                DataCopy(b_tile, gm_b[k][n], K_TILE * N_TILE);
                MatMul(c_tile, a_tile, b_tile, M_TILE, N_TILE, K_TILE);
                // 累加:C += A @ B(Cube Unit 原生支持 FP32 累加)
            }
            // GEMM 完成后,Vector Unit 直接对 C 做 GELU------不写回 DDR
            GELU(c_tile, c_tile, M_TILE * N_TILE);
            DataCopy(gm_c[m][n], c_tile, M_TILE * N_TILE);
         }}
    }
};

参考仓库

ops-blas 高性能 GEMM

ops-nn 神经网络算子

catlass 算子模板库

CANN 学习中心

相关推荐
极光代码工作室6 小时前
基于深度学习的手写数字识别系统
人工智能·python·深度学习·神经网络·机器学习
geovindu6 小时前
python: speech to text offline
开发语言·python·语音识别
AI创界者6 小时前
告别云端限制!Sulphur 2 本地文生视频/图生视频整合包,本地部署,解压即用,保姆级部署与工作流实战
人工智能·python·aigc·音视频
tsfy20037 小时前
Python批量调整Excel格式,并排版导出PDF
python·pdf·excel
木囧7 小时前
PyCharm手动创建虚拟环境
ide·python·pycharm
李可以量化7 小时前
QMT 量化实践:两种方式获取个股上市日期(内置 Python + 原生 Python 完整可运行代码)
python
是多巴胺不是尼古丁7 小时前
期末java复习--string
java·开发语言·python
garmin Chen7 小时前
从 Transformer 到 Agent:大模型技术全景解析
java·人工智能·python·深度学习·transformer
没有钱的钱仔7 小时前
pytorch_cuda安装
人工智能·pytorch·python
Full Stack Developme8 小时前
Apache Tika 教程
java·开发语言·python·apache