catlass:昇腾NPU上的算子模板库

CANN 的矩阵计算能力来自三层结构:最顶层是 ops-blas 这样的标准算子库,中间是 GEMM 的优化实现层,最底层是 catlass------一套用模板元编程生成的算子 Kernel 工厂。

很多人第一次看到 catlass 会觉得它就是 CUTLASS 的昇腾移植版。这个印象对了一半。catlass 确实借鉴了 CUTLASS 的模板化思想,但针对昇腾达芬奇架构做了大量不同于 GPU 的设计取舍。


catlass 为什么存在

AI 计算的核心是矩阵乘法。一张计算图里 70-80% 的 FLOPs 来自各种形状的 GEMM------大的如 FFN 的 [B, 4096] × [4096, 11008],小的如 Attention 里的 [B, 40, n, 128] × [B, 40, 128, n]

这些 GEMM 的形状千差万别,但底层都是同一个操作:M×K 乘以 K×N。如果为每种形状手写一个 Kernel,工作量巨大且难以维护。

catlass 的答案是用 C++ 模板来生成 Kernel。开发者描述"我要一个 M=4096、K=4096、N=11008 的 GEMM"------catlass 的模板引擎在编译期展开成对应的 Tile 循环、搬运指令和计算指令。不同形状的 GEMM 共享同一套模板代码,只在模板参数上做区分。


算子模板的核心思想

catlass 把一个 GEMM Kernel 拆解成几个可组合的模板层次:

Tile 描述层。 定义每个分块的大小------M 维度分多少、N 维度分多少、K 维度做多少次循环。这个层的配置直接影响片上 L1 Buffer 的利用率和搬运次数。

数据搬运层。 定义数据从 GM(DDR)到片上 L1 的搬运策略------要不要做转置、数据是 FP16 还是 INT8、地址对齐方式。catlass 把搬运描述为 Iterator 模板,不同的 GEMM 形状对应不同的 Iterator 特化版本。

计算指令层。 定义 AI Core 上的计算指令序列------Cube Unit 做矩阵乘、Vector Unit 做偏置加和激活函数。这一层直接映射到达芬奇架构的硬件指令。

这三层模板在编译期通过模板参数组合,展开为特定形状的完整 Kernel 代码。开发者不需要手写任何循环或搬运指令。

cpp 复制代码
// catlass 模板组合示例(伪代码)
using GemmTraits = catlass::GemmTraits<
    M=4096, N=11008, K=4096,
    LayoutA=RowMajor, LayoutB=ColMajor,
    TypeA=float16, TypeB=float16, TypeC=float16,
    TileShape=<128, 128, 64>,
    IteratorA=IteratorContiguous,
    IteratorB=IteratorPacked
>;

// 编译期展开为完整 Kernel
catlass::Kernel<GemmTraits>::launch(ptrA, ptrB, ptrC, stream);

catlass 与 CUTLASS 的区别

硬件抽象不同。 CUTLASS 抽象的是 GPU 的 warp/thread 层次------开发者需要配置每个线程计算哪些元素、如何做 shared memory 的 bank conflict 避免。catlass 抽象的是昇腾的 Cube Unit 和 Vector Unit------开发者配置的是分块大小和数据搬运策略,不涉及线程模型。

编译期 vs 运行时。 CUTLASS 的模板展开走 CUDA 的 JIT 编译路径,运行时可能存在首次编译延迟。catlass 的模板在 ATC 编译模型时就全部展开,Runtime 加载 Kernel 时已经是编译好的二进制。推理时零编译开销。

融合能力不同。 CUTLASS 的模板化停在 GEMM Kernel 边界------GEMM 做完后的 bias add 和 activation 需要独立的 Kernel 调用。catlass 的模板允许在 Tile 循环内部嵌入 Vector 计算指令,比如 GEMM 的每个分块计算完后立即做 ReLU,结果不落 DDR。

cpp 复制代码
// catlass 的 GEMM + ReLU 融合模板
using GemmReluTraits = catlass::GemmTraits<
    M=4096, N=11008, K=4096,
    Epilogue=Relu<TypeC>  // 分块计算完后立即 ReLU
>;

这个融合能力在 Transformer 推理中非常关键------FFN 层的 GEMM → BiasAdd → ReLU → GEMM 可以融合成两个带 Epilogue 的 catlass Kernel,省掉中间 Tensor 的两次 DDR 读写。


CANN 如何调用 catlass

CANN 的算子调用链路中,catlass 不直接暴露给应用层。

复制代码
AscendCL 推理 → GE 图优化 → ops-blas(高层算子库)→ catlass(模板 Kernel 生成)

应用层调用的是 AscendCL 或 PyTorch 的标准推理接口。GE 在图优化阶段把计算图中的 GEMM 算子替换成 ops-blas 的优化实现。ops-blas 内部在初始化时通过 catlass 的模板引擎生成特定形状的 Kernel,注册到 Runtime 的算子表中。推理时 Runtime 直接查表加载对应 Kernel。

开发者不需要直接接触 catlass,但理解它的模板化思想对于手工优化 GEMM 形状很关键------当你发现某个 GEMM 形状的推理性能异常时,大概率是 catlass 为该形状生成的 Tile 参数没有匹配到最优配置。


Transformer 中的 GEMM 优化

以 LLaMA-7B 为例,模型运行时涉及的 GEMM 主要是三种形状:

FFN 前向 GEMM: [B, 4096] × [4096, 11008]。这是一个典型的 M 小、K 和 N 大的矩形 GEMM。catlass 为这种形状生成的 Tile 策略是 N 维度大分块(减少 Tile 循环次数)、M 维度小分块(匹配 Batch 的动态变化)。

Attention 投影 GEMM: Q = X @ W_Q,形状 [B, n, 4096] × [4096, 4096]。三个投影(Q/K/V)形状相同。catlass 会让三个投影共享同一套 Tile 模板,在 Stream 上流水线执行。

Decoder Block 的残差 GEMM: 形状小且不规则。catlass 为这类小 GEMM 专门设计了 Tile 模板,避免在小矩阵上跑大 Tile 导致的搬运浪费。

实测中,catlass 的模板化 GEMM 比手写固定 Kernel 在各种形状上的平均性能差距在 5% 以内,但在开发效率和代码维护上有数量级的优势。

CANN catlass 算子模板仓库

ops-blas 线性代数算子库


Tile 配置的艺术

catlass 的模板化核心在 Tile 参数的配置。不同的 GEMM 形状对 Tile 的要求完全不同。

复制代码
大 GEMM(M=4096, K=4096, N=11008):
  Tile(M=128, N=128, K=64) --- 让 Cube Unit 保持满载
  循环次数:M=32, N=86, K=64 → 总启动 176,128 次 Tile 计算

小 GEMM(M=1, K=4096, N=4096):
  Tile(M=1, N=128, K=4096) --- M 维度不切,N 平铺
  循环次数:M=1, N=32, K=1 → 32 次

不规则 GEMM(M=7, K=1024, N=256):
  Tile(M=7, N=64, K=1024) --- 非 2 的幂次,K 全量
  循环次数:M=1, N=4, K=1 → 4 次

catlass 的模板引擎内置了一套启发式 Tile 选择逻辑:根据 M、N、K 的比值和金丝雀测试数据,自动选择理论最优的 Tile 配置。如果自动选择不是最优,开发者可以通过模板参数覆盖。


Kernel 生成的编译期流程

catlass 的模板展开不是运行时发生的,而是在 ATC 编译模型时通过编译器完成:

复制代码
模型 ONNX → ATC 解析 GEMM 形状
  → catlass 模板实例化(编译期)
  → 模板参数推导 Tile 配置
  → 生成 /tmp/.catlass_kernels/<shape_hash>.o
  → 链接进 OM 模型
  → Runtime 加载时已经是编译好的二进制

每个 GEMM 形状的 Kernel 编译一次后缓存到文件系统。下次遇到相同形状的 GEMM 直接复用。这个缓存机制在模型部署时非常关键------模型加载时不需要重新编译 Kernel,OM 文件内已经包含了所有 GEMM Kernel 的二进制。

继续学习

catlass 解决的是"GEMM Kernel 怎么写"的问题。上一层 ops-blas 解决的是"GEMM 形状太多怎么管理"的问题。理解了 catlass 的模板化思想后,顺着 ops-blas 的仓库可以看到 CANN 是如何为数百种 GEMM 形状提供统一优化接口的。


模板化 vs 手写 Kernel 的取舍

模板化的代价是灵活性受限。手写 Kernel 可以为特定形状做极致优化------比如在 M=1 的场景省略掉 M 维度的所有循环。catlass 的通用模板无法覆盖所有极端形状的特殊优化,但是覆盖了 95% 的常用场景。剩余 5% 的极端形状可以通过手写 Ascend C Kernel 来补充,catlass 提供了与手写 Kernel 的互操作接口。

参考仓库

catlass 算子模板库
ops-blas 线性代数算子库

相关推荐
桜吹雪5 小时前
所有智能体架构(2):ReAct(推理 + 行动)
人工智能
埃菲尔铁塔_CV算法5 小时前
YOLO11 与传统纹理特征融合目标检测 完整实现教程
人工智能·神经网络·yolo·计算机视觉
快乐的哈士奇5 小时前
LangFuse 自托管实战:选型理由、Docker 部署与常用配置全解析
运维·人工智能·docker·容器
数智化管理手记5 小时前
精益生产3步实操,让现场从混乱变标杆
大数据·运维·网络·人工智能·精益工程
百度Geek说5 小时前
PRD → Goal → After-Goal:AI 主导全流程研发实践
人工智能
山西茄子5 小时前
DeepStream9.0 在DeepStream中使用VLM
人工智能
小小测试开发5 小时前
AI 水印攻防战:OpenAI 引入 SynthID 认证,GitHub 同步出现去水印工具
人工智能·github
larance5 小时前
[菜鸟教程] 机器学习教程第六课-机器学习基础术语
人工智能·机器学习
多年小白5 小时前
2026年5月半导体板块深度分析
大数据·人工智能·科技·区块链