CANN 的矩阵计算能力来自三层结构:最顶层是 ops-blas 这样的标准算子库,中间是 GEMM 的优化实现层,最底层是 catlass------一套用模板元编程生成的算子 Kernel 工厂。
很多人第一次看到 catlass 会觉得它就是 CUTLASS 的昇腾移植版。这个印象对了一半。catlass 确实借鉴了 CUTLASS 的模板化思想,但针对昇腾达芬奇架构做了大量不同于 GPU 的设计取舍。
catlass 为什么存在
AI 计算的核心是矩阵乘法。一张计算图里 70-80% 的 FLOPs 来自各种形状的 GEMM------大的如 FFN 的 [B, 4096] × [4096, 11008],小的如 Attention 里的 [B, 40, n, 128] × [B, 40, 128, n]。
这些 GEMM 的形状千差万别,但底层都是同一个操作:M×K 乘以 K×N。如果为每种形状手写一个 Kernel,工作量巨大且难以维护。
catlass 的答案是用 C++ 模板来生成 Kernel。开发者描述"我要一个 M=4096、K=4096、N=11008 的 GEMM"------catlass 的模板引擎在编译期展开成对应的 Tile 循环、搬运指令和计算指令。不同形状的 GEMM 共享同一套模板代码,只在模板参数上做区分。
算子模板的核心思想
catlass 把一个 GEMM Kernel 拆解成几个可组合的模板层次:
Tile 描述层。 定义每个分块的大小------M 维度分多少、N 维度分多少、K 维度做多少次循环。这个层的配置直接影响片上 L1 Buffer 的利用率和搬运次数。
数据搬运层。 定义数据从 GM(DDR)到片上 L1 的搬运策略------要不要做转置、数据是 FP16 还是 INT8、地址对齐方式。catlass 把搬运描述为 Iterator 模板,不同的 GEMM 形状对应不同的 Iterator 特化版本。
计算指令层。 定义 AI Core 上的计算指令序列------Cube Unit 做矩阵乘、Vector Unit 做偏置加和激活函数。这一层直接映射到达芬奇架构的硬件指令。
这三层模板在编译期通过模板参数组合,展开为特定形状的完整 Kernel 代码。开发者不需要手写任何循环或搬运指令。
cpp
// catlass 模板组合示例(伪代码)
using GemmTraits = catlass::GemmTraits<
M=4096, N=11008, K=4096,
LayoutA=RowMajor, LayoutB=ColMajor,
TypeA=float16, TypeB=float16, TypeC=float16,
TileShape=<128, 128, 64>,
IteratorA=IteratorContiguous,
IteratorB=IteratorPacked
>;
// 编译期展开为完整 Kernel
catlass::Kernel<GemmTraits>::launch(ptrA, ptrB, ptrC, stream);
catlass 与 CUTLASS 的区别
硬件抽象不同。 CUTLASS 抽象的是 GPU 的 warp/thread 层次------开发者需要配置每个线程计算哪些元素、如何做 shared memory 的 bank conflict 避免。catlass 抽象的是昇腾的 Cube Unit 和 Vector Unit------开发者配置的是分块大小和数据搬运策略,不涉及线程模型。
编译期 vs 运行时。 CUTLASS 的模板展开走 CUDA 的 JIT 编译路径,运行时可能存在首次编译延迟。catlass 的模板在 ATC 编译模型时就全部展开,Runtime 加载 Kernel 时已经是编译好的二进制。推理时零编译开销。
融合能力不同。 CUTLASS 的模板化停在 GEMM Kernel 边界------GEMM 做完后的 bias add 和 activation 需要独立的 Kernel 调用。catlass 的模板允许在 Tile 循环内部嵌入 Vector 计算指令,比如 GEMM 的每个分块计算完后立即做 ReLU,结果不落 DDR。
cpp
// catlass 的 GEMM + ReLU 融合模板
using GemmReluTraits = catlass::GemmTraits<
M=4096, N=11008, K=4096,
Epilogue=Relu<TypeC> // 分块计算完后立即 ReLU
>;
这个融合能力在 Transformer 推理中非常关键------FFN 层的 GEMM → BiasAdd → ReLU → GEMM 可以融合成两个带 Epilogue 的 catlass Kernel,省掉中间 Tensor 的两次 DDR 读写。
CANN 如何调用 catlass
CANN 的算子调用链路中,catlass 不直接暴露给应用层。
AscendCL 推理 → GE 图优化 → ops-blas(高层算子库)→ catlass(模板 Kernel 生成)
应用层调用的是 AscendCL 或 PyTorch 的标准推理接口。GE 在图优化阶段把计算图中的 GEMM 算子替换成 ops-blas 的优化实现。ops-blas 内部在初始化时通过 catlass 的模板引擎生成特定形状的 Kernel,注册到 Runtime 的算子表中。推理时 Runtime 直接查表加载对应 Kernel。
开发者不需要直接接触 catlass,但理解它的模板化思想对于手工优化 GEMM 形状很关键------当你发现某个 GEMM 形状的推理性能异常时,大概率是 catlass 为该形状生成的 Tile 参数没有匹配到最优配置。
Transformer 中的 GEMM 优化
以 LLaMA-7B 为例,模型运行时涉及的 GEMM 主要是三种形状:
FFN 前向 GEMM: [B, 4096] × [4096, 11008]。这是一个典型的 M 小、K 和 N 大的矩形 GEMM。catlass 为这种形状生成的 Tile 策略是 N 维度大分块(减少 Tile 循环次数)、M 维度小分块(匹配 Batch 的动态变化)。
Attention 投影 GEMM: Q = X @ W_Q,形状 [B, n, 4096] × [4096, 4096]。三个投影(Q/K/V)形状相同。catlass 会让三个投影共享同一套 Tile 模板,在 Stream 上流水线执行。
Decoder Block 的残差 GEMM: 形状小且不规则。catlass 为这类小 GEMM 专门设计了 Tile 模板,避免在小矩阵上跑大 Tile 导致的搬运浪费。
实测中,catlass 的模板化 GEMM 比手写固定 Kernel 在各种形状上的平均性能差距在 5% 以内,但在开发效率和代码维护上有数量级的优势。
Tile 配置的艺术
catlass 的模板化核心在 Tile 参数的配置。不同的 GEMM 形状对 Tile 的要求完全不同。
大 GEMM(M=4096, K=4096, N=11008):
Tile(M=128, N=128, K=64) --- 让 Cube Unit 保持满载
循环次数:M=32, N=86, K=64 → 总启动 176,128 次 Tile 计算
小 GEMM(M=1, K=4096, N=4096):
Tile(M=1, N=128, K=4096) --- M 维度不切,N 平铺
循环次数:M=1, N=32, K=1 → 32 次
不规则 GEMM(M=7, K=1024, N=256):
Tile(M=7, N=64, K=1024) --- 非 2 的幂次,K 全量
循环次数:M=1, N=4, K=1 → 4 次
catlass 的模板引擎内置了一套启发式 Tile 选择逻辑:根据 M、N、K 的比值和金丝雀测试数据,自动选择理论最优的 Tile 配置。如果自动选择不是最优,开发者可以通过模板参数覆盖。
Kernel 生成的编译期流程
catlass 的模板展开不是运行时发生的,而是在 ATC 编译模型时通过编译器完成:
模型 ONNX → ATC 解析 GEMM 形状
→ catlass 模板实例化(编译期)
→ 模板参数推导 Tile 配置
→ 生成 /tmp/.catlass_kernels/<shape_hash>.o
→ 链接进 OM 模型
→ Runtime 加载时已经是编译好的二进制
每个 GEMM 形状的 Kernel 编译一次后缓存到文件系统。下次遇到相同形状的 GEMM 直接复用。这个缓存机制在模型部署时非常关键------模型加载时不需要重新编译 Kernel,OM 文件内已经包含了所有 GEMM Kernel 的二进制。
继续学习
catlass 解决的是"GEMM Kernel 怎么写"的问题。上一层 ops-blas 解决的是"GEMM 形状太多怎么管理"的问题。理解了 catlass 的模板化思想后,顺着 ops-blas 的仓库可以看到 CANN 是如何为数百种 GEMM 形状提供统一优化接口的。
模板化 vs 手写 Kernel 的取舍
模板化的代价是灵活性受限。手写 Kernel 可以为特定形状做极致优化------比如在 M=1 的场景省略掉 M 维度的所有循环。catlass 的通用模板无法覆盖所有极端形状的特殊优化,但是覆盖了 95% 的常用场景。剩余 5% 的极端形状可以通过手写 Ascend C Kernel 来补充,catlass 提供了与手写 Kernel 的互操作接口。