Vulkan Cooperative Matrix 简明教程

1. 引言

本教程介绍 Vulkan 中的 Cooperative Matrix(协作矩阵)技术,包括工作原理、Vulkan 扩展 API、Shader 函数详解,以及在 NCNN 深度学习框架中的实际应用。

在阅读本文前您需要一定的GPU编程知识,以及Vulkan Compute基本知识,这里推荐白牛大佬的:如何火急火燎地上手 Vulkan Compute - 知乎

本文参考了以下资料:

vulkan tensorcore 笔记 - 知乎

Tencent/ncnn: ncnn is a high-performance neural network inference framework optimized for the mobile platform

为以上内容作者nihui大佬献上赞歌~

本文作者也只是初学者,如有不到之处欢迎批评指正。

本文配套代码:futz12/cm_gemm_example

2. 矩阵分块基础

在深入学习 Cooperative Matrix 之前,理解**矩阵分块(Matrix Tiling/Blocking)**是必要的。矩阵分块是高性能矩阵计算的核心技术,也是 Cooperative Matrix 的工作基础。

2.1 为什么需要矩阵分块

大矩阵计算的挑战

考虑两个大矩阵相乘:\(C = A \times B\),其中 \(A\) 是 \(4096 \times 4096\),\(B\) 是 \(4096 \times 4096\)。

计算量

\\\text{FLOP} = 2 \\times 4096 \\times 4096 \\times 4096 \\approx 137 \\text{ GFLOP} \\

内存访问量(朴素实现):

\\\text{读取 A} = 4096 \\times 4096 \\times 4 \\text{ bytes} \\times 4096 = 256 \\text{ GB} \\

\\\text{读取 B} = 4096 \\times 4096 \\times 4 \\text{ bytes} \\times 4096 = 256 \\text{ GB} \\

\\\text{写入 C} = 4096 \\times 4096 \\times 4 \\text{ bytes} = 64 \\text{ MB} \\

朴素实现中,每个输出元素 \(C_{ij}\) 需要读取 A 的一整行和 B 的一整列,导致大量重复的内存访问。

分块解决方案

将大矩阵划分为小块,每次只处理一个小块:

\A = \\begin{bmatrix} A_{00} \& A_{01} \& \\cdots \& A_{0,k-1} \\\\ A_{10} \& A_{11} \& \\cdots \& A_{1,k-1} \\\\ \\vdots \& \\vdots \& \\ddots \& \\vdots \\\\ A_{m-1,0} \& A_{m-1,1} \& \\cdots \& A_{m-1,k-1} \\end{bmatrix}, \\quad B = \\begin{bmatrix} B_{00} \& B_{01} \& \\cdots \& B_{0,n-1} \\\\ B_{10} \& B_{11} \& \\cdots \& B_{1,n-1} \\\\ \\vdots \& \\vdots \& \\ddots \& \\vdots \\\\ B_{k-1,0} \& B_{k-1,1} \& \\cdots \& B_{k-1,n-1} \\end{bmatrix} \\

分块矩阵乘法

\C_{ij} = \\sum_{l=0}\^{k-1} A_{il} \\times B_{lj} \\

分块示意 (以 \(128 \times 128\) 矩阵,\(16 \times 16\) 分块为例):

\A = \\begin{bmatrix} \\boxed{A_{00}} \& \\boxed{A_{01}} \& \\boxed{A_{02}} \& \\boxed{A_{03}} \\\\ \\boxed{A_{10}} \& \\boxed{A_{11}} \& \\boxed{A_{12}} \& \\boxed{A_{13}} \\\\ \\boxed{A_{20}} \& \\boxed{A_{21}} \& \\boxed{A_{22}} \& \\boxed{A_{23}} \\\\ \\boxed{A_{30}} \& \\boxed{A_{31}} \& \\boxed{A_{32}} \& \\boxed{A_{33}} \\end{bmatrix}_{128 \\times 128} \\

每个 \(\boxed{A_{ij}}\) 是 \(32 \times 32\) 的子矩阵。

2.2 分块的优势

数据复用

分块矩阵乘法的核心优势是数据复用:在计算单个输出块时,A 的行数据和 B 的列数据被多次复用。

以计算 \(C_{00}\) 为例 (假设 \(C_{00}\) 是 \(32 \times 32\) 的块,共 1024 个输出元素):

\C_{00} = A_{00} \\times B_{00} + A_{01} \\times B_{10} + A_{02} \\times B_{20} + A_{03} \\times B_{30} \\

数据需求分析

计算 \(32 \times 32 = 1024\) 个输出元素需要:

\\\begin{array}{c\|c\|c\|c} \\text{数据} \& \\text{范围} \& \\text{元素数量} \& \\text{说明} \\\\ \\hline \\text{A 的行} \& \\text{第 0 到 31 行} \& 32 \\times 128 = 4096 \& \\text{每行 128 个元素} \\\\ \\text{B 的列} \& \\text{第 0 到 31 列} \& 128 \\times 32 = 4096 \& \\text{每列 128 个元素} \\end{array} \\

关键观察

  • A 的每个元素 \(a_{ik}\) 会参与该行所有 32 个输出元素的计算:\(c_{i0}, c_{i1}, \ldots, c_{i31}\)
  • B 的每个元素 \(b_{kj}\) 会参与该列所有 32 个输出元素的计算:\(c_{0j}, c_{1j}, \ldots, c_{31j}\)
  • 即每个 A 元素被复用 32 次,每个 B 元素也被复用 32 次

内存访问对比(计算 1024 个输出元素):

\\\begin{array}{c\|c\|c\|c\|c} \\text{方式} \& \\text{A 的访问次数} \& \\text{B 的访问次数} \& \\text{总访问次数} \& \\text{说明} \\\\ \\hline \\text{不分块} \& 1024 \\times 128 = 131072 \& 1024 \\times 128 = 131072 \& 262144 \& \\text{每个输出元素独立读取 A 行和 B 列} \\\\ \\text{分块} \& 4096 \& 4096 \& 8192 \& \\text{A 和 B 的每个元素只加载 1 次到 Shared Memory} \\end{array} \\

复用倍数

\\\text{内存访问减少倍数} = \\frac{262144}{8192} = 32 \\text{ 倍} \\

分块计算通过将 A 和 B 的数据缓存到 Shared Memory,使得每个数据元素只需加载一次,就能被 32 个输出元素复用,内存访问量减少 32 倍

缓存友好

内存层次结构

\\\begin{array}{c\|c\|c} \\text{层次} \& \\text{容量} \& \\text{延迟} \\\\ \\hline \\text{寄存器} \& \\sim 256 \\text{ KB} \& 1 \\text{ cycle} \\\\ \\text{L1 缓存} \& \\sim 32 \\text{ KB} \& \\sim 4 \\text{ cycles} \\\\ \\text{L2 缓存} \& \\sim 2 \\text{ MB} \& \\sim 40 \\text{ cycles} \\\\ \\text{全局内存} \& \\sim 8 \\text{ GB} \& \\sim 400 \\text{ cycles} \\end{array} \\

分块策略:选择块大小使数据能放入高速缓存

\\\text{块大小} \\leq \\text{缓存大小} \\

例如:L1 缓存 32KB,FP32 数据:

\\\text{最大块元素数} = \\frac{32 \\times 1024}{4} = 8192 \\text{ 元素} \\

\\\text{可选块大小} = 64 \\times 64 = 4096 \\text{ 元素(安全)} \\

2.3 分块矩阵乘法示例

具体数值示例

计算 \(C = A \times B\),其中 \(A\) 是 \(4 \times 4\),\(B\) 是 \(4 \times 4\),使用 \(2 \times 2\) 分块:

矩阵 A 分块

\A = \\begin{bmatrix} \\boxed{\\begin{matrix} 1 \& 2 \\\\ 3 \& 4 \\end{matrix}} \& \\boxed{\\begin{matrix} 5 \& 6 \\\\ 7 \& 8 \\end{matrix}} \\\\ \\boxed{\\begin{matrix} 9 \& 10 \\\\ 11 \& 12 \\end{matrix}} \& \\boxed{\\begin{matrix} 13 \& 14 \\\\ 15 \& 16 \\end{matrix}} \\end{bmatrix} = \\begin{bmatrix} A_{00} \& A_{01} \\\\ A_{10} \& A_{11} \\end{bmatrix} \\

矩阵 B 分块

\B = \\begin{bmatrix} \\boxed{\\begin{matrix} 1 \& 0 \\\\ 0 \& 1 \\end{matrix}} \& \\boxed{\\begin{matrix} 2 \& 0 \\\\ 0 \& 2 \\end{matrix}} \\\\ \\boxed{\\begin{matrix} 3 \& 0 \\\\ 0 \& 3 \\end{matrix}} \& \\boxed{\\begin{matrix} 4 \& 0 \\\\ 0 \& 4 \\end{matrix}} \\end{bmatrix} = \\begin{bmatrix} B_{00} \& B_{01} \\\\ B_{10} \& B_{11} \\end{bmatrix} \\

计算 \(C_{00}\)

\C_{00} = A_{00} \\times B_{00} + A_{01} \\times B_{10} \\

\= \\begin{bmatrix} 1 \& 2 \\\\ 3 \& 4 \\end{bmatrix} \\times \\begin{bmatrix} 1 \& 0 \\\\ 0 \& 1 \\end{bmatrix} + \\begin{bmatrix} 5 \& 6 \\\\ 7 \& 8 \\end{bmatrix} \\times \\begin{bmatrix} 3 \& 0 \\\\ 0 \& 3 \\end{bmatrix} \\

\= \\begin{bmatrix} 1 \& 2 \\\\ 3 \& 4 \\end{bmatrix} + \\begin{bmatrix} 15 \& 18 \\\\ 21 \& 24 \\end{bmatrix} = \\begin{bmatrix} 16 \& 20 \\\\ 24 \& 28 \\end{bmatrix} \\

分块计算流程图

\\\begin{array}{c} \\text{计算 } C_{00} \\\\ \\downarrow \\\\ \\begin{array}{\|c\|c\|} \\hline \\text{Step 1} \& \\text{加载 } A_{00}, B_{00} \\\\ \\hline \\text{Step 2} \& C_{00}\^{(1)} = A_{00} \\times B_{00} \\\\ \\hline \\text{Step 3} \& \\text{加载 } A_{01}, B_{10} \\\\ \\hline \\text{Step 4} \& C_{00}\^{(2)} = A_{01} \\times B_{10} \\\\ \\hline \\text{Step 5} \& C_{00} = C_{00}\^{(1)} + C_{00}\^{(2)} \\\\ \\hline \\end{array} \\end{array} \\

2.4 GPU 上的分块实现

线程块映射

在 GPU 上,每个线程块负责计算一个或多个输出块:

\\\begin{array}{c\|c} \\text{GPU 概念} \& \\text{矩阵分块对应} \\\\ \\hline \\text{Workgroup / Thread Block} \& \\text{计算一个或多个输出块 } C_{ij} \\\\ \\text{Subgroup / Warp} \& \\text{协作计算一个子块} \\\\ \\text{Thread} \& \\text{持有子块中的部分元素} \\end{array} \\

映射示意 (\(4 \times 4\) 输出块,\(2 \times 2\) 子块):

\C = \\begin{bmatrix} \\color{red}{\\boxed{C_{00}}} \& \\color{blue}{\\boxed{C_{01}}} \\\\ \\color{green}{\\boxed{C_{10}}} \& \\color{orange}{\\boxed{C_{11}}} \\end{bmatrix} \\

  • 线程块 0(红色):计算 \(C_{00}\)
  • 线程块 1(蓝色):计算 \(C_{01}\)
  • 线程块 2(绿色):计算 \(C_{10}\)
  • 线程块 3(橙色):计算 \(C_{11}\)

Shared Memory 使用

每个线程块使用 Shared Memory 缓存当前计算所需的数据块:

glsl 复制代码
// Shared Memory 声明
shared float A_tile[TILE_M][TILE_K];
shared float B_tile[TILE_K][TILE_N];

// 分块计算
for (int k = 0; k < K; k += TILE_K) {
    // 协作加载 A 和 B 的当前块到 Shared Memory
    A_tile[ty][tx] = A[row * K + k + tx];
    B_tile[ty][tx] = B[(k + ty) * N + col];
    
    barrier();  // 同步
    
    // 计算当前块的贡献
    for (int i = 0; i < TILE_K; i++) {
        sum += A_tile[ty][i] * B_tile[i][tx];
    }
    
    barrier();  // 同步
}

2.5 分块大小与硬件的关系

分块大小的选择不是任意的,而是由硬件特性决定的。

Tensor Core 的固定尺寸

现代 GPU 配备了专门的矩阵计算单元(NVIDIA 称为 Tensor Core,Intel 称为 XMX),这些硬件单元在单周期内完成固定尺寸的矩阵乘加运算:

\\\text{矩阵计算单元操作} = A_{M \\times K} \\times B_{K \\times N} + C_{M \\times N} \\

常见矩阵计算单元尺寸

\\\begin{array}{c\|c\|c\|c\|c} M \\times N \\times K \& \\text{A 尺寸} \& \\text{B 尺寸} \& \\text{C/D 尺寸} \& \\text{硬件} \\\\ \\hline 16 \\times 8 \\times 16 \& 16\\times16 \& 16\\times8 \& 16\\times8 \& \\text{NVIDIA Volta/Turing/Ampere} \\\\ 16 \\times 16 \\times 16 \& 16\\times16 \& 16\\times16 \& 16\\times16 \& \\text{NVIDIA Ampere} \\\\ 8 \\times 8 \\times 16 \& 8\\times16 \& 16\\times8 \& 8\\times8 \& \\text{Intel XMX} \\end{array} \\

分块大小的约束

由于矩阵计算单元的尺寸是固定的,分块大小必须满足:

\\\text{分块大小} = n \\times \\text{硬件单元尺寸} \\

例如,如果硬件支持 16×8×16 的 Tensor Core 操作,那么分块大小应该是 16、8、16 的整数倍。

分块大小选择的核心原则:匹配硬件矩阵计算单元的固定尺寸。这为后续 Cooperative Matrix 的设计奠定了基础------它封装了这些硬件细节,让开发者无需手动处理。

2.6 从分块到 Cooperative Matrix

理解矩阵分块后,Cooperative Matrix 的概念就自然清晰了:

\\\begin{array}{c\|c\|c} \\text{概念} \& \\text{传统分块} \& \\text{Cooperative Matrix} \\\\ \\hline \\text{分块单位} \& \\text{手动管理} \& \\text{固定尺寸 (16×8 等)} \\\\ \\text{数据存储} \& \\text{Shared Memory} \& \\text{线程寄存器(分布式)} \\\\ \\text{计算单元} \& \\text{CUDA Core} \& \\text{Tensor Core} \\\\ \\text{线程协作} \& \\text{手动同步} \& \\text{Subgroup 隐式同步} \\\\ \\text{编程模型} \& \\text{显式管理} \& \\text{声明式 API} \\end{array} \\

Cooperative Matrix 本质上是将矩阵分块技术与 Tensor Core 硬件结合,通过标准化的 API 让开发者无需关心底层细节。

3. Cooperative Matrix 原理

3.1 什么是 Cooperative Matrix

Cooperative Matrix(协作矩阵) 是 Vulkan 提供的一种特殊矩阵类型,其存储和计算分布在某个作用域(通常是 Subgroup)内的所有调用之间。这些调用协同工作,高效地执行矩阵乘法运算。

Cooperative Matrix 的核心特点:

  • 分布式存储:矩阵数据分布在 Subgroup 的所有线程中,每个线程只持有部分数据
  • 协同计算:所有线程共同参与矩阵运算,充分利用 Tensor Core 硬件加速
  • 中等尺寸:矩阵尺寸通常为 8×8、16×8、16×16 等,适合 Tensor Core 的计算单元
  • 透明优化:矩阵乘加操作由硬件(Tensor Core)执行,开发者无需手动管理线程间数据分配

3.1.1 Cooperative Matrix 的意义

问题背景:传统 GPU 矩阵乘法的困境

在传统 GPU 编程中,矩阵乘法的实现面临以下挑战:

  1. 内存访问效率低:每个线程独立计算输出元素,导致大量重复的内存访问
  2. 无法利用专用硬件:现代 GPU 配备了 Tensor Core 等矩阵计算专用单元,但传统编程模型无法直接使用
  3. 优化复杂度高:需要精心设计 tiling、共享内存管理、bank conflict 避免等,代码复杂且难以移植

Cooperative Matrix 的解决方案

Cooperative Matrix 通过以下方式解决这些问题:

\\\begin{array}{c\|c\|c} \\text{问题} \& \\text{传统方法} \& \\text{Cooperative Matrix} \\\\ \\hline \\text{内存访问} \& \\text{每线程独立访问,重复读取} \& \\text{协作加载,数据复用} \\\\ \\text{硬件利用} \& \\text{仅使用 CUDA/Shader Core} \& \\text{直接使用 Tensor Core} \\\\ \\text{编程复杂度} \& \\text{手动 tiling、共享内存} \& \\text{声明式编程,驱动优化} \\\\ \\text{可移植性} \& \\text{需针对不同 GPU 优化} \& \\text{驱动处理硬件差异} \\end{array} \\

性能提升:相比传统实现,Cooperative Matrix 可获得显著性能提升。根据 NVIDIA 的分析,Tensor Core 矩阵乘法可在 1 个周期内完成,而传统 FFMA 需要多个周期,具体加速比取决于矩阵尺寸和内存访问模式。

3.1.2 矩阵乘法的数学定义

设矩阵乘法运算 \(D = A \times B + C\),其中:

\A \\in \\mathbb{R}\^{M \\times K}, \\quad B \\in \\mathbb{R}\^{K \\times N}, \\quad C, D \\in \\mathbb{R}\^{M \\times N} \\

元素级计算:

\D_{ij} = \\sum_{k=0}\^{K-1} A_{ik} \\cdot B_{kj} + C_{ij}, \\quad \\forall i \\in \[0, M), j \\in \[0, N) \\

计算复杂度分析

\\\text{乘法次数} = M \\times N \\times K \\

\\\text{加法次数} = M \\times N \\times K \\

\\\text{总 FLOP} = 2 \\times M \\times N \\times K \\

3.2 Tensor Core 加速原理

现代 GPU(如 NVIDIA RTX 系列、Intel Arc 系列)配备了专门的矩阵计算单元,称为 Tensor Core (NVIDIA)或 XMX(Intel)。这些硬件单元能够在单个时钟周期内完成小矩阵的乘加运算。

3.2.1 Tensor Core 的工作方式

Tensor Core 在单周期内完成小矩阵乘加运算。典型的 16×8×16 运算:

\\\begin{bmatrix} d_{0,0} \& d_{0,1} \& \\cdots \& d_{0,7} \\\\ d_{1,0} \& d_{1,1} \& \\cdots \& d_{1,7} \\\\ \\vdots \& \\vdots \& \\ddots \& \\vdots \\\\ d_{15,0} \& d_{15,1} \& \\cdots \& d_{15,7} \\end{bmatrix} = \\begin{bmatrix} a_{0,0} \& a_{0,1} \& \\cdots \& a_{0,15} \\\\ a_{1,0} \& a_{1,1} \& \\cdots \& a_{1,15} \\\\ \\vdots \& \\vdots \& \\ddots \& \\vdots \\\\ a_{15,0} \& a_{15,1} \& \\cdots \& a_{15,15} \\end{bmatrix} \\times \\begin{bmatrix} b_{0,0} \& b_{0,1} \& \\cdots \& b_{0,7} \\\\ b_{1,0} \& b_{1,1} \& \\cdots \& b_{1,7} \\\\ \\vdots \& \\vdots \& \\ddots \& \\vdots \\\\ b_{15,0} \& b_{15,1} \& \\cdots \& b_{15,7} \\end{bmatrix} + \\begin{bmatrix} c_{0,0} \& c_{0,1} \& \\cdots \& c_{0,7} \\\\ c_{1,0} \& c_{1,1} \& \\cdots \& c_{1,7} \\\\ \\vdots \& \\vdots \& \\ddots \& \\vdots \\\\ c_{15,0} \& c_{15,1} \& \\cdots \& c_{15,7} \\end{bmatrix} \\

计算复杂度分析

  • 乘法次数:\(M \times N \times K = 16 \times 8 \times 16 = 2048\) 次
  • 加法次数:\(M \times N \times K = 2048\) 次
  • 吞吐量:4096 FLOP/周期(Ampere 架构数据)

注意:Tensor Core 的"1周期"指的是吞吐量而非延迟。实际执行中,Tensor Core 操作有流水线延迟(通常 4-16 周期),但由于流水线设计,可以每个周期发射新的 Tensor Core 指令。

3.2.2 支持的矩阵尺寸

\\\begin{array}{c\|c\|c\|c\|c} M \\times N \\times K \& \\text{A 类型} \& \\text{B 类型} \& \\text{C/D 类型} \& \\text{说明} \\\\ \\hline 16 \\times 8 \\times 16 \& \\text{FP16} \& \\text{FP16} \& \\text{FP16/FP32} \& \\text{最常用,NVIDIA GPU} \\\\ 16 \\times 8 \\times 8 \& \\text{FP16} \& \\text{FP16} \& \\text{FP16/FP32} \& \\text{较小尺寸} \\\\ 8 \\times 8 \\times 16 \& \\text{FP16} \& \\text{FP16} \& \\text{FP16/FP32} \& \\text{通用支持} \\\\ 16 \\times 16 \\times 16 \& \\text{FP16} \& \\text{FP16} \& \\text{FP16/FP32} \& \\text{较大尺寸,高性能} \\\\ 16 \\times 8 \\times 8 \& \\text{BF16} \& \\text{BF16} \& \\text{FP32} \& \\text{NVIDIA Ampere+} \\end{array} \\

3.2.3 Tensor Core 性能数据

以下是 NVIDIA Ampere 架构的典型延迟数据(来源:NVIDIA 官方分析):

\\\begin{array}{c\|c\|c} \\text{操作类型} \& \\text{周期数} \& \\text{说明} \\\\ \\hline \\text{全局内存访问} \& \\sim 380 \\text{ 周期} \& \\text{访问 HBM/GDDR 显存} \\\\ \\text{L2 缓存访问} \& \\sim 200 \\text{ 周期} \& \\text{片上二级缓存} \\\\ \\text{L1 缓存/共享内存} \& \\sim 34 \\text{ 周期} \& \\text{每个 SM 的本地存储} \\\\ \\text{FFMA(标量乘加)} \& 4 \\text{ 周期} \& \\text{CUDA Core 执行} \\\\ \\text{Tensor Core 操作} \& 1 \\text{ 周期吞吐} \& \\text{流水线延迟 4-16 周期} \\end{array} \\

性能对比示例(32×32 矩阵乘法):

\\\begin{array}{c\|c\|c} \\text{实现方式} \& \\text{周期数} \& \\text{计算方式} \\\\ \\hline \\text{传统 CUDA Core} \& \\sim 504 \& \\text{全局内存 + 8次共享内存 + 8次FFMA} \\\\ \\text{Tensor Core} \& \\sim 235 \& \\text{全局内存 + 1次共享内存 + 1次Tensor Core} \\end{array} \\

\\\text{加速比} = \\frac{504}{235} \\approx 2.1\\times \\

实际加速比取决于矩阵尺寸、内存访问模式和 GPU 架构。大矩阵(如 GEMM)可获得更高加速比,因为 Tensor Core 的优势在计算密集型操作中更明显。

3.3 与传统实现的对比

3.3.1 传统标量实现

每个线程独立计算一个或多个输出元素:

\D_{ij} = \\sum_{k=0}\^{K-1} A_{ik} \\cdot B_{kj} + C_{ij} \\

线程 \(t\) 计算的输出元素集合:\(\{(i, j) \mid t = f(i, j)\}\)

传统线程计算示意(以 4×4 输出矩阵,2×2 微块为例):

\D = \\begin{bmatrix} \\color{red}{d_{0,0}} \& \\color{red}{d_{0,1}} \& \\color{blue}{d_{0,2}} \& \\color{blue}{d_{0,3}} \\\\ \\color{red}{d_{1,0}} \& \\color{red}{d_{1,1}} \& \\color{blue}{d_{1,2}} \& \\color{blue}{d_{1,3}} \\\\ \\color{green}{d_{2,0}} \& \\color{green}{d_{2,1}} \& \\color{orange}{d_{2,2}} \& \\color{orange}{d_{2,3}} \\\\ \\color{green}{d_{3,0}} \& \\color{green}{d_{3,1}} \& \\color{orange}{d_{3,2}} \& \\color{orange}{d_{3,3}} \\end{bmatrix} \\

不同颜色代表不同线程计算的 2×2 微块,每个线程独立完成:

  • 线程 0(红色):\(d_{0,0}, d_{0,1}, d_{1,0}, d_{1,1}\)(左上角 2×2 块)
  • 线程 1(蓝色):\(d_{0,2}, d_{0,3}, d_{1,2}, d_{1,3}\)(右上角 2×2 块)
  • 线程 2(绿色):\(d_{2,0}, d_{2,1}, d_{3,0}, d_{3,1}\)(左下角 2×2 块)
  • 线程 3(橙色):\(d_{2,2}, d_{2,3}, d_{3,2}, d_{3,3}\)(右下角 2×2 块)

传统实现的问题

\\\begin{array}{c\|c} \\text{问题} \& \\text{具体表现} \\\\ \\hline \\text{内存访问效率低} \& \\text{每个线程独立读取 A 的行和 B 的列,大量重复访问} \\\\ \\text{无法利用 Tensor Core} \& \\text{CUDA Core 一次只能执行一次乘加} \\\\ \\text{计算密度低} \& \\text{内存带宽成为瓶颈,计算单元利用率低} \\\\ \\text{优化复杂} \& \\text{需要手动实现 tiling、共享内存、避免 bank conflict} \\end{array} \\

3.3.2 Cooperative Matrix 实现

Subgroup 内所有线程协作完成矩阵块计算:

\\\mathbf{D}_{M \\times N} = \\mathbf{A}_{M \\times K} \\times \\mathbf{B}_{K \\times N} + \\mathbf{C}_{M \\times N} \\

每个线程持有矩阵的部分元素,协同完成计算。

Cooperative Matrix 协作计算示意(以 16×8 输出矩阵为例):

\D = \\begin{bmatrix} d_{0,0} \& d_{0,1} \& d_{0,2} \& d_{0,3} \& d_{0,4} \& d_{0,5} \& d_{0,6} \& d_{0,7} \\\\ d_{1,0} \& d_{1,1} \& d_{1,2} \& d_{1,3} \& d_{1,4} \& d_{1,5} \& d_{1,6} \& d_{1,7} \\\\ \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \\\\ d_{15,0} \& d_{15,1} \& d_{15,2} \& d_{15,3} \& d_{15,4} \& d_{15,5} \& d_{15,6} \& d_{15,7} \\end{bmatrix}_{16 \\times 8} \\

所有 32 线程协作计算此矩阵块

数据分布(每个线程持有 4 个元素):

\\\begin{array}{c\|c} \\text{线程} \& \\text{持有元素} \\\\ \\hline T_0 \& d_{0,0}, d_{0,1}, d_{0,2}, d_{0,3} \\\\ T_1 \& d_{0,4}, d_{0,5}, d_{0,6}, d_{0,7} \\\\ \\vdots \& \\vdots \\\\ T_{31} \& d_{12,4}, d_{12,5}, d_{12,6}, d_{12,7} \\end{array} \\

Cooperative Matrix 的优势

\\\begin{array}{c\|c} \\text{优势} \& \\text{说明} \\\\ \\hline \\text{内存合并访问} \& \\text{所有线程协作加载,内存访问模式优化} \\\\ \\text{Tensor Core 加速} \& \\text{直接利用专用矩阵计算单元} \\\\ \\text{高计算密度} \& \\text{单次操作完成大量计算} \\\\ \\text{代码简洁} \& \\text{声明式编程,无需手动优化} \\end{array} \\

3.3.3 性能对比

\\\begin{array}{c\|c\|c} \\text{特性} \& \\text{传统实现} \& \\text{Cooperative Matrix} \\\\ \\hline \\text{计算单元} \& \\text{CUDA Core / Shader Core} \& \\text{Tensor Core / XMX} \\\\ \\text{线程模型} \& \\text{每个线程独立计算输出元素} \& \\text{线程协作计算矩阵块} \\\\ \\text{内存访问} \& \\text{需精心设计避免 bank conflict} \& \\text{数据布局由硬件处理,但仍需合理设计} \\\\ \\text{代码复杂度} \& \\text{高,需手动 tiling 和共享内存管理} \& \\text{低,声明式编程} \\\\ \\text{性能可移植性} \& \\text{低,需针对不同 GPU 优化} \& \\text{高,驱动处理硬件差异} \\\\ \\text{计算吞吐量} \& \\text{基准} \& \\text{显著提升(取决于矩阵尺寸)} \\end{array} \\

3.4 分布式存储模型

3.4.1 数据分布原理

Cooperative Matrix 的核心特性是数据分布在 Subgroup 的所有线程中

设 Subgroup 大小为 \(S\),协作矩阵尺寸为 \(M \times N\)。每个线程 \(t \in [0, S)\) 持有的元素数量:

\\\text{elements}_t = \\frac{M \\times N}{S} \\

典型配置

\\\begin{array}{c\|c\|c\|c\|c} \\text{矩阵尺寸} \& \\text{元素总数} \& \\text{Subgroup 大小} \& \\text{每线程元素数} \& \\text{典型 GPU} \\\\ \\hline 16 \\times 8 \& 128 \& 32 \& 4 \& \\text{NVIDIA} \\\\ 16 \\times 16 \& 256 \& 32 \& 8 \& \\text{NVIDIA} \\\\ 8 \\times 8 \& 64 \& 16 \& 4 \& \\text{Intel} \\\\ 8 \\times 16 \& 128 \& 64 \& 2 \& \\text{AMD} \\end{array} \\

2.4.2 数据分布示意

以 16×8 矩阵、32 线程 Subgroup 为例:

\\\text{总元素数} = 16 \\times 8 = 128 \\

\\\text{每线程元素数} = \\frac{128}{32} = 4 \\

矩阵元素分布示意

\\\begin{bmatrix} \\color{red}{e_0} \& \\color{red}{e_1} \& \\color{red}{e_2} \& \\color{red}{e_3} \& \\color{blue}{e_4} \& \\color{blue}{e_5} \& \\color{blue}{e_6} \& \\color{blue}{e_7} \\\\ \\color{green}{e_8} \& \\color{green}{e_9} \& \\color{green}{e_{10}} \& \\color{green}{e_{11}} \& \\color{orange}{e_{12}} \& \\color{orange}{e_{13}} \& \\color{orange}{e_{14}} \& \\color{orange}{e_{15}} \\\\ \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \& \\vdots \\\\ \\color{purple}{e_{120}} \& \\color{purple}{e_{121}} \& \\color{purple}{e_{122}} \& \\color{purple}{e_{123}} \& \\color{cyan}{e_{124}} \& \\color{cyan}{e_{125}} \& \\color{cyan}{e_{126}} \& \\color{cyan}{e_{127}} \\end{bmatrix} \\

线程与元素对应关系

\\\begin{array}{c\|c\|c} \\text{线程 ID} \& \\text{颜色} \& \\text{持有元素} \\\\ \\hline T_0 \& \\color{red}{\\text{红}} \& e_0, e_1, e_2, e_3 \\\\ T_1 \& \\color{blue}{\\text{蓝}} \& e_4, e_5, e_6, e_7 \\\\ T_2 \& \\color{green}{\\text{绿}} \& e_8, e_9, e_{10}, e_{11} \\\\ T_3 \& \\color{orange}{\\text{橙}} \& e_{12}, e_{13}, e_{14}, e_{15} \\\\ \\vdots \& \\vdots \& \\vdots \\\\ T_{30} \& \\color{purple}{\\text{紫}} \& e_{120}, e_{121}, e_{122}, e_{123} \\\\ T_{31} \& \\color{cyan}{\\text{青}} \& e_{124}, e_{125}, e_{126}, e_{127} \\end{array} \\

元素分配公式

\\\begin{aligned} \\text{Thread 0} \&\\rightarrow \\text{elem}\[0, 1, 2, 3 \\ \text{Thread 1} &\rightarrow \text{elem}4, 5, 6, 7 \\ &\vdots \\ \text{Thread 31} &\rightarrow \text{elem}124, 125, 126, 127 \end{aligned} \]

Vulkan 规范没有规定具体的元素分配方式,这是由硬件实现决定的。程序员不需要关心具体分配,硬件会自动处理。

2.4.3 尺寸匹配要求

矩阵尺寸必须与 Subgroup 大小匹配:\((M \times N) \mod S = 0\)

硬件会验证矩阵元素总数能被 Subgroup 大小整除。不匹配的组合会导致运行时错误或未定义行为。

3.5 计算过程详解

3.5.1 矩阵乘加的执行流程

Cooperative Matrix 执行 \(D = A \times B + C\) 的过程:

步骤 1:数据加载

所有线程调用 coopMatLoad,从全局内存加载矩阵数据:

\\\text{coopMatLoad}(M, \\text{buffer}, \\text{offset}, \\text{stride}, \\text{layout}) \\

硬件根据 Tensor Core 的数据布局要求将数据分配到各线程的寄存器中。

数据加载示意

\\\begin{array}{c} \\text{全局内存中的大矩阵 A (M × K)} \\\\\[5pt \left\\begin{array}{ccccc} \\cdot \& \\cdot \& \\cdot \& \\cdot \& \\cdot \\\\ \\cdot \& {\\boxed{\\begin{array}{ccc} a_{0,0} \& a_{0,1} \& \\cdots \\\\ a_{1,0} \& a_{1,1} \& \\cdots \\\\ \\vdots \& \\vdots \& \\ddots \\end{array}}} \& \\cdot \\\\ \\cdot \& \\cdot \& \\cdot \& \\cdot \& \\cdot \\end{array}\\right \\10pt \big\downarrow \\ext{coopMatLoad(offset, stride)} \\5pt \big\downarrow \\ \underbrace{\begin{bmatrix} a_{0,0} & a_{0,1} & \cdots \\ a_{1,0} & a_{1,1} & \cdots \\ \vdots & \vdots & \ddots \end{bmatrix}}{\text{加载的子矩阵块}} \\10pt \big\downarrow \quad \text{硬件按 Tensor Core 布局分配} \\5pt \underbrace{\begin{array}{|c|c|c|c|} \hline T_0 & T_1 & \cdots & T{31} \\ \hline e_0..e_3 & e_4..e_7 & \cdots & e_{124}..e_{127} \\ \hline \end{array}}_{\text{数据分布到各线程寄存器}} \end{array} \]

步骤 2:矩阵乘加

所有线程调用 coopMatMulAdd,Tensor Core 执行计算:

\D = \\text{coopMatMulAdd}(A, B, C) \\

等价于:

\d_{ij} = \\sum_{k=0}\^{K-1} a_{ik} \\cdot b_{kj} + c_{ij} \\

矩阵乘加示意

\\\begin{array}{ccc} \\text{线程寄存器中的数据} \& \\rightarrow \& \\text{Tensor Core 计算} \\\\ \\begin{array}{\|c\|} \\hline T_0: a_0..a_3, b_0..b_3, c_0..c_3 \\\\ T_1: a_4..a_7, b_4..b_7, c_4..c_7 \\\\ \\vdots \\\\ T_{31}: a_{124}..a_{127}, b_{124}..b_{127}, c_{124}..c_{127} \\\\ \\hline \\end{array} \& \\rightarrow \& \\begin{array}{\|c\|} \\hline \\text{Tensor Core} \\\\ D = A \\times B + C \\\\ \\text{流水线执行} \\\\ \\hline \\end{array} \\end{array} \\

步骤 3:结果存储

所有线程调用 coopMatStore,将结果写回全局内存:

\\\text{coopMatStore}(D, \\text{buffer}, \\text{offset}, \\text{stride}, \\text{layout}) \\

结果存储示意

\\\begin{array}{c} \\underbrace{\\begin{array}{\|c\|c\|c\|c\|} \\hline T_0 \& T_1 \& \\cdots \& T_{31} \\\\ \\hline d_0..d_3 \& d_4..d_7 \& \\cdots \& d_{124}..d_{127} \\\\ \\hline \\end{array}}_{\\text{各线程持有结果的一部分}} \\\\\[10pt \big\downarrow \\ \text{coopMatStore(offset, stride)} \\ \big\downarrow \\ \underbrace{\begin{bmatrix} d_{0,0} & d_{0,1} & \cdots \\ d_{1,0} & d_{1,1} & \cdots \\ \vdots & \vdots & \ddots \end{bmatrix}}_{\text{写入的子矩阵块}} \\10pt \big\downarrow \\5pt \text{全局内存中的大矩阵 D (M × N)} \\5pt \left\\begin{array}{ccccc} \\cdot \& \\cdot \& \\cdot \& \\cdot \& \\cdot \\\\ \\cdot \& {\\boxed{\\begin{array}{ccc} d_{0,0} \& d_{0,1} \& \\cdots \\\\ d_{1,0} \& d_{1,1} \& \\cdots \\\\ \\vdots \& \\vdots \& \\ddots \\end{array}}} \& \\cdot \\\\ \\cdot \& \\cdot \& \\cdot \& \\cdot \& \\cdot \\end{array}\\right \end{array} \]

3.5.2 Uniform Control Flow 要求

Subgroup 内的所有线程必须同时调用 CM 函数,否则行为未定义!

\\\forall t \\in \[0, S): \\text{thread}_t \\text{ 调用 } \\texttt{coopMatMulAdd}(A, B, C) \\

为什么需要 Uniform Control Flow?

  1. 矩阵存储分布在所有线程:数据分散在各线程中,需要所有线程参与
  2. 协作计算:Tensor Core 需要所有线程协同工作
  3. 硬件同步:这些函数在硬件层面隐式执行同步操作

正确示例

glsl 复制代码
// ✅ 正确:所有线程都调用
coopmat<float16_t, gl_ScopeSubgroup, 16, 8, gl_MatrixUseA> matA;
coopMatLoad(matA, buffer, offset, stride, layout);  // 所有 32 线程都调用

错误示例

glsl 复制代码
// ❌ 错误:只有部分线程调用
if (gl_SubgroupInvocationID < 16) {
    coopMatLoad(matA, buffer, offset, stride, layout);  // 未定义行为!
}

3.5.3 与普通变量赋值的对比

\\\begin{array}{c\|c\|c} \\text{操作类型} \& \\text{普通变量} \& \\text{Cooperative Matrix} \\\\ \\hline \\text{声明} \& \\texttt{float x;} \& \\texttt{coopmat\<\> m;} \\\\ \& \\text{每个线程有独立的 x} \& \\text{所有线程共享一个逻辑矩阵 m} \\\\ \\text{赋值} \& \\texttt{x = 1.0;} \& \\texttt{m = coopmat\<\>(1.0);} \\\\ \& \\text{每个线程的 x 都变成 1.0} \& \\text{矩阵所有元素都变成 1.0} \\\\ \\text{加载} \& \\texttt{x = buffer\[i;} & \texttt{coopMatLoad(m, buffer, ...);} \\ & \text{每个线程加载一个值} & \text{所有线程协作加载整个矩阵} \\ \text{计算} & \texttt{y = a * b + c;} & \texttt{m = coopMatMulAdd(A, B, C);} \\ & \text{每个线程独立计算} & \text{所有线程协作计算(使用 Tensor Core)} \end{array} \]

4. Vulkan 扩展介绍

4.1 VK_KHR_cooperative_matrix

VK_KHR_cooperative_matrix 是 Khronos 标准化的协作矩阵扩展,提供了跨厂商的标准接口。

扩展启用方式:

glsl 复制代码
// 在 GLSL shader 中启用扩展
#extension GL_KHR_cooperative_matrix : require

扩展特性:

  • 支持查询设备支持的矩阵尺寸和类型组合
  • 支持 Subgroup 和 Workgroup 作用域
  • 支持多种数据类型:FP16、FP32、BF16、INT8、INT16、INT32
  • 提供 Robust Buffer Access 选项

4.2 VK_NV_cooperative_matrix

VK_NV_cooperative_matrix 是 NVIDIA 的厂商扩展,是 KHR 扩展的前身。

\\\begin{array}{c\|c\|c} \\text{特性} \& \\text{VK\\_KHR\\_cooperative\\_matrix} \& \\text{VK\\_NV\\_cooperative\\_matrix} \\\\ \\hline \\text{状态} \& \\text{Khronos 标准} \& \\text{NVIDIA 厂商扩展} \\\\ \\text{矩阵类型} \& \\texttt{coopmat\<\>} \& \\texttt{fcoopmatNV\<\>} \\\\ \\text{Use 参数} \& \\text{显式指定(gl\\_MatrixUseA/B/Accumulator)} \& \\text{隐式(通过尺寸推断)} \\\\ \\text{加载函数} \& \\texttt{coopMatLoad()} \& \\texttt{coopMatLoadNV()} \\\\ \\text{存储函数} \& \\texttt{coopMatStore()} \& \\texttt{coopMatStoreNV()} \\\\ \\text{乘加函数} \& \\texttt{coopMatMulAdd()} \& \\texttt{coopMatMulAddNV()} \\end{array} \\

推荐:优先使用 VK_KHR_cooperative_matrix,除非需要兼容旧版驱动。

4.3 API 查询方法

在使用 Cooperative Matrix 之前,必须查询设备支持的矩阵尺寸和类型组合:

cpp 复制代码
// 查询支持的 cooperative matrix 属性数量
uint32_t propertyCount = 0;
vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, &propertyCount, nullptr);

// 分配空间并获取属性
std::vector<VkCooperativeMatrixPropertiesKHR> properties(propertyCount);
vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, &propertyCount, properties.data());

// 遍历支持的矩阵配置
for (const auto& prop : properties) {
    // prop.MSize, prop.NSize, prop.KSize - 矩阵尺寸
    // prop.AType, prop.BType, prop.CType, prop.ResultType - 数据类型
    // prop.scope - 作用域(Subgroup 或 Workgroup)
}

VkCooperativeMatrixPropertiesKHR 结构体

cpp 复制代码
typedef struct VkCooperativeMatrixPropertiesKHR {
    VkStructureType    sType;           // 结构体类型
    void*              pNext;           // 扩展指针
    uint32_t           MSize;           // 矩阵 A 的行数,矩阵 C 的行数
    uint32_t           NSize;           // 矩阵 B 的列数,矩阵 C 的列数
    uint32_t           KSize;           // 矩阵 A 的列数,矩阵 B 的行数
    VkComponentTypeKHR AType;           // 矩阵 A 的元素类型
    VkComponentTypeKHR BType;           // 矩阵 B 的元素类型
    VkComponentTypeKHR CType;           // 矩阵 C 的元素类型
    VkComponentTypeKHR ResultType;      // 结果矩阵的元素类型
    VkBool32           saturatingAccumulation; // 是否饱和累加
    VkScopeKHR         scope;           // 作用域
} VkCooperativeMatrixPropertiesKHR;

VkComponentTypeKHR 支持的类型

\\\begin{array}{c\|c\|c} \\text{枚举值} \& \\text{对应类型} \& \\text{说明} \\\\ \\hline \\texttt{VK\\_COMPONENT\\_TYPE\\_FLOAT16\\_KHR} \& \\text{FP16} \& \\text{半精度浮点数} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_FLOAT32\\_KHR} \& \\text{FP32} \& \\text{单精度浮点数} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_BFLOAT16\\_KHR} \& \\text{BF16} \& \\text{Brain Float16} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_SINT8\\_KHR} \& \\text{int8} \& \\text{8位有符号整数} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_UINT8\\_KHR} \& \\text{uint8} \& \\text{8位无符号整数} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_SINT16\\_KHR} \& \\text{int16} \& \\text{16位有符号整数} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_UINT16\\_KHR} \& \\text{uint16} \& \\text{16位无符号整数} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_SINT32\\_KHR} \& \\text{int32} \& \\text{32位有符号整数} \\\\ \\texttt{VK\\_COMPONENT\\_TYPE\\_UINT32\\_KHR} \& \\text{uint32} \& \\text{32位无符号整数} \\end{array} \\

VkScopeKHR 作用域

\\\begin{array}{c\|c\|c} \\text{枚举值} \& \\text{说明} \& \\text{使用场景} \\\\ \\hline \\texttt{VK\\_SCOPE\\_SUBGROUP\\_KHR} \& \\text{Subgroup 作用域} \& \\text{最常用,矩阵分布在单个 Subgroup 内} \\\\ \\texttt{VK\\_SCOPE\\_WORKGROUP\\_KHR} \& \\text{Workgroup 作用域} \& \\text{较大矩阵,需要跨 Subgroup 协作} \\\\ \\texttt{VK\\_SCOPE\\_DEVICE\\_KHR} \& \\text{设备作用域} \& \\text{跨 Workgroup(较少使用)} \\\\ \\texttt{VK\\_SCOPE\\_QUEUE\\_FAMILY\\_KHR} \& \\text{队列族作用域} \& \\text{跨队列(较少使用)} \\end{array} \\

5. Shader 函数详解

5.1 矩阵类型声明

KHR 扩展矩阵类型

glsl 复制代码
// KHR 扩展的矩阵类型声明
coopmat<ElementType, Scope, Rows, Cols, Use> matrixName;

// 示例:声明一个 16×8 的矩阵 A(用于乘法左操作数)
coopmat<float16_t, gl_ScopeSubgroup, 16, 8, gl_MatrixUseA> matA;

// 示例:声明一个 8×8 的矩阵 B(用于乘法右操作数)
coopmat<float16_t, gl_ScopeSubgroup, 8, 8, gl_MatrixUseB> matB;

// 示例:声明一个 16×8 的累加器矩阵 C
coopmat<float, gl_ScopeSubgroup, 16, 8, gl_MatrixUseAccumulator> matC;

模板参数说明:

\\\begin{array}{c\|c\|c} \\text{参数} \& \\text{类型} \& \\text{说明} \\\\ \\hline \\text{ElementType} \& \\text{类型} \& \\text{矩阵元素类型:float16\\_t, float, bfloat16\\_t, int8\\_t, uint8\\_t, int16\\_t, uint16\\_t, int32\\_t, uint32\\_t} \\\\ \\text{Scope} \& \\text{常量} \& \\text{作用域:gl\\_ScopeSubgroup, gl\\_ScopeWorkgroup} \\\\ \\text{Rows} \& \\text{常量整数} \& \\text{矩阵行数(必须是设备支持的尺寸)} \\\\ \\text{Cols} \& \\text{常量整数} \& \\text{矩阵列数(必须是设备支持的尺寸)} \\\\ \\text{Use} \& \\text{常量} \& \\text{矩阵用途:gl\\_MatrixUseA, gl\\_MatrixUseB, gl\\_MatrixUseAccumulator} \\end{array} \\

NV 扩展矩阵类型

glsl 复制代码
// NV 扩展的矩阵类型声明
fcoopmatNV<BitWidth, Scope, Rows, Cols> matrixName;

// 示例:FP16 矩阵(16位精度)
fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> matA;

// 示例:FP32 矩阵(32位精度,用于累加器)
fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> matC;

BitWidth 精度说明

  • 16 - FP16(半精度浮点)
  • 32 - FP32(单精度浮点)

5.2 Matrix Use 类型

Matrix Use 参数决定了矩阵在乘加运算中的角色:

\\\underbrace{A_{M \\times K}}_{\\text{MatrixA}} \\times \\underbrace{B_{K \\times N}}_{\\text{MatrixB}} + \\underbrace{C_{M \\times N}}_{\\text{Accumulator}} = \\underbrace{D_{M \\times N}}_{\\text{Accumulator}} \\

矩阵乘加示意 (以 \(M=16, K=16, N=8\) 为例):

\\\underbrace{\\begin{bmatrix} a_{0,0} \& \\cdots \& a_{0,15} \\\\ \\vdots \& \\ddots \& \\vdots \\\\ a_{15,0} \& \\cdots \& a_{15,15} \\end{bmatrix}}_{A: 16 \\times 16} \\times \\underbrace{\\begin{bmatrix} b_{0,0} \& \\cdots \& b_{0,7} \\\\ \\vdots \& \\ddots \& \\vdots \\\\ b_{15,0} \& \\cdots \& b_{15,7} \\end{bmatrix}}_{B: 16 \\times 8} + \\underbrace{\\begin{bmatrix} c_{0,0} \& \\cdots \& c_{0,7} \\\\ \\vdots \& \\ddots \& \\vdots \\\\ c_{15,0} \& \\cdots \& c_{15,7} \\end{bmatrix}}_{C: 16 \\times 8} = \\underbrace{\\begin{bmatrix} d_{0,0} \& \\cdots \& d_{0,7} \\\\ \\vdots \& \\ddots \& \\vdots \\\\ d_{15,0} \& \\cdots \& d_{15,7} \\end{bmatrix}}_{D: 16 \\times 8} \\

单个元素计算示意

\d_{ij} = \\sum_{k=0}\^{15} a_{ik} \\cdot b_{kj} + c_{ij} \\

例如 \(d_{0,0}\) 的计算:

\d_{0,0} = a_{0,0} \\cdot b_{0,0} + a_{0,1} \\cdot b_{1,0} + \\cdots + a_{0,15} \\cdot b_{15,0} + c_{0,0} \\

\\\begin{array}{c\|c\|c\|c} \\text{Use 类型} \& \\text{角色} \& \\text{尺寸约束} \& \\text{典型数据类型} \\\\ \\hline \\texttt{gl\\_MatrixUseA} \& \\text{乘法左操作数} \& M \\times K \& \\text{FP16, BF16, INT8} \\\\ \\texttt{gl\\_MatrixUseB} \& \\text{乘法右操作数} \& K \\times N \& \\text{FP16, BF16, INT8} \\\\ \\texttt{gl\\_MatrixUseAccumulator} \& \\text{累加器(加法操作数和结果)} \& M \\times N \& \\text{FP16, FP32} \\end{array} \\

5.3 Scope 参数

Scope 参数定义了矩阵数据分布的范围:

gl_ScopeSubgroup:矩阵分布在单个 Subgroup 内(最常用)

gl_ScopeWorkgroup:矩阵分布在多个 Subgroup 间

5.4 加载和存储函数

coopMatLoad / coopMatLoadNV

从内存加载数据到 Cooperative Matrix:

glsl 复制代码
// KHR 扩展语法
void coopMatLoad(
    out coopmat<ElementType, Scope, M, N, Use> matrix,  // 输出矩阵
    in uint[] buffer,                                   // 源缓冲区
    uint offset,                                        // 起始偏移(元素数)
    uint stride,                                        // 行/列步长
    uint layout                                         // 布局方式
);

// 示例:从 buffer 加载矩阵
coopmat<float16_t, gl_ScopeSubgroup, 16, 8, gl_MatrixUseA> matA;
coopMatLoad(matA, inputData, 0, K, gl_CooperativeMatrixLayoutRowMajor);

coopMatStore / coopMatStoreNV

将 Cooperative Matrix 数据存储到内存:

glsl 复制代码
// KHR 扩展语法
void coopMatStore(
    in coopmat<ElementType, Scope, M, N, Use> matrix,  // 输入矩阵
    out uint[] buffer,                                 // 目标缓冲区
    uint offset,                                       // 起始偏移
    uint stride,                                       // 行/列步长
    uint layout                                        // 布局方式
);

// 示例:存储结果矩阵
coopmat<float, gl_ScopeSubgroup, 16, 8, gl_MatrixUseAccumulator> result;
coopMatStore(result, outputData, 0, N, gl_CooperativeMatrixLayoutRowMajor);

5.5 矩阵乘加函数

执行矩阵乘加运算:\(D = A \times B + C\)

glsl 复制代码
// KHR 扩展语法
coopmat<ResultType, Scope, M, N, gl_MatrixUseAccumulator> coopMatMulAdd(
    coopmat<AType, Scope, M, K, gl_MatrixUseA> A,
    coopmat<BType, Scope, K, N, gl_MatrixUseB> B,
    coopmat<CType, Scope, M, N, gl_MatrixUseAccumulator> C
);

// 示例:完整的矩阵乘加
coopmat<float16_t, gl_ScopeSubgroup, 16, 8, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeSubgroup, 8, 8, gl_MatrixUseB> matB;
coopmat<float, gl_ScopeSubgroup, 16, 8, gl_MatrixUseAccumulator> matC;
coopmat<float, gl_ScopeSubgroup, 16, 8, gl_MatrixUseAccumulator> result;

// 初始化累加器
matC = coopmat<float, gl_ScopeSubgroup, 16, 8, gl_MatrixUseAccumulator>(0.0f);

// 执行乘加
result = coopMatMulAdd(matA, matB, matC);

5.6 布局参数

内存布局

行主序(gl_CooperativeMatrixLayoutRowMajor

\\\text{addr}(i, j) = \\text{base} + i \\times \\text{stride} + j \\

内存排列示意(以 3×4 矩阵为例):

\\\begin{bmatrix} a_{0,0} \& a_{0,1} \& a_{0,2} \& a_{0,3} \\\\ a_{1,0} \& a_{1,1} \& a_{1,2} \& a_{1,3} \\\\ a_{2,0} \& a_{2,1} \& a_{2,2} \& a_{2,3} \\end{bmatrix} \\

内存中的存储顺序

\\[a_{0,0}, a_{0,1}, a_{0,2}, a_{0,3}, \\; a_{1,0}, a_{1,1}, a_{1,2}, a_{1,3}, \\; a_{2,0}, a_{2,1}, a_{2,2}, a_{2,3} \]

即按行依次存储:第 0 行 → 第 1 行 → 第 2 行

列主序(gl_CooperativeMatrixLayoutColumnMajor

\\\text{addr}(i, j) = \\text{base} + j \\times \\text{stride} + i \\

内存排列示意(以 3×4 矩阵为例):

\\\begin{bmatrix} a_{0,0} \& a_{0,1} \& a_{0,2} \& a_{0,3} \\\\ a_{1,0} \& a_{1,1} \& a_{1,2} \& a_{1,3} \\\\ a_{2,0} \& a_{2,1} \& a_{2,2} \& a_{2,3} \\end{bmatrix} \\

内存中的存储顺序

\\[a_{0,0}, a_{1,0}, a_{2,0}, \\; a_{0,1}, a_{1,1}, a_{2,1}, \\; a_{0,2}, a_{1,2}, a_{2,2}, \\; a_{0,3}, a_{1,3}, a_{2,3} \]

即按列依次存储:第 0 列 → 第 1 列 → 第 2 列 → 第 3 列

5.7 其他相关函数

矩阵初始化

glsl 复制代码
// 使用标量初始化所有元素
coopmat<float, gl_ScopeSubgroup, 16, 8, gl_MatrixUseAccumulator> mat(0.0f);

// 拷贝构造
coopmat<float, gl_ScopeSubgroup, 16, 8, gl_MatrixUseAccumulator> mat2 = mat1;

矩阵的行数 M 和列数 N 在类型声明时已确定,编译时已知,无需运行时查询。

6. GEMM 完整实现示例

本章通过一个完整的 GEMM 示例项目 cm_gemm_example,讲解 Cooperative Matrix 的实际应用。该示例包含主机端代码和 Shader 代码,实现了带有 Double Buffer 优化的矩阵乘法。

6.1 项目结构

复制代码
cm_gemm_example/
├── main.cpp          # 主机端代码
└── CMakeLists.txt    # 构建配置

主要功能模块

\\\begin{array}{c\|c\|c} \\text{模块} \& \\text{文件} \& \\text{功能} \\\\ \\hline \\text{设备初始化} \& \\text{main.cpp} \& \\text{创建 Instance、Device、Queue} \\\\ \\text{CM 属性查询} \& \\text{main.cpp} \& \\text{查询支持的矩阵尺寸} \\\\ \\text{权重重排} \& \\text{main.cpp} \& \\text{预处理 B 矩阵} \\\\ \\text{Pipeline 创建} \& \\text{main.cpp} \& \\text{编译 Shader、创建 Pipeline} \\\\ \\text{GEMM 计算} \& \\text{Shader} \& \\text{执行矩阵乘法} \\end{array} \\

6.2 主机端:查询 Cooperative Matrix 属性

在使用 Cooperative Matrix 之前,必须查询设备支持的矩阵尺寸和类型组合:

cpp 复制代码
void queryCooperativeMatrixProperties() {
    PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = 
        (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)
        vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
    
    uint32_t propertyCount = 0;
    VK_CHECK(vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, &propertyCount, nullptr));
    
    std::vector<VkCooperativeMatrixPropertiesKHR> properties(propertyCount);
    for (auto& prop : properties) {
        prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
        prop.pNext = nullptr;
    }
    
    VK_CHECK(vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, &propertyCount, properties.data()));
    
    for (const auto& prop : properties) {
        if (prop.scope == VK_SCOPE_SUBGROUP_KHR) {
            std::cout << "M=" << prop.MSize << ", N=" << prop.NSize << ", K=" << prop.KSize
                      << " (A: " << prop.AType << ", B: " << prop.BType << ", C: " << prop.CType << ")" << std::endl;
            
            if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
                prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
                prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
                selectedCM.M = prop.MSize;
                selectedCM.N = prop.NSize;
                selectedCM.K = prop.KSize;
            }
        }
    }
}

输出示例

复制代码
Supported Cooperative Matrix configurations:
----------------------------------------
M=16, N=8, K=16 (A: FP16, B: FP16, C: FP16)
M=16, N=8, K=16 (A: FP16, B: FP16, C: FP32)
M=16, N=16, K=16 (A: FP16, B: FP16, C: FP16)
M=16, N=16, K=16 (A: FP16, B: FP16, C: FP32)

Selected configuration: M=16, N=16, K=16 (FP16 input, FP32 accumulate)

6.3 主机端:Tile 尺寸计算

根据矩阵尺寸和 CM 尺寸,计算最优的 Tile 参数:

cpp 复制代码
void calculateTileSizes() {
    TILE_M = std::min((M + selectedCM.M - 1) / selectedCM.M, 2u);
    TILE_N = std::min((N + selectedCM.N - 1) / selectedCM.N, 2u);
    TILE_K = std::min((K + selectedCM.K - 1) / selectedCM.K, 2u);
    
    std::cout << "Tile sizes (for double buffer optimization):" << std::endl;
    std::cout << "  TILE_M = " << TILE_M << std::endl;
    std::cout << "  TILE_N = " << TILE_N << std::endl;
    std::cout << "  TILE_K = " << TILE_K << std::endl;
}

Tile 参数含义

\\\begin{aligned} \\text{TILE\\_M} \&= \\min\\left(\\left\\lceil \\frac{M}{M_{cm}} \\right\\rceil, 2\\right) \\\\ \\text{TILE\\_N} \&= \\min\\left(\\left\\lceil \\frac{N}{N_{cm}} \\right\\rceil, 2\\right) \\\\ \\text{TILE\\_K} \&= \\min\\left(\\left\\lceil \\frac{K}{K_{cm}} \\right\\rceil, 2\\right) \\end{aligned} \\

每个 Subgroup 处理 \(\text{TILE\M} \times \text{TILE\N}\) 个 CM 块,每个 CM 块尺寸为 \(M{cm} \times N{cm}\)。

6.4 主机端:B 矩阵权重重排

为了优化内存访问,B 矩阵需要预先重排:

cpp 复制代码
void reorderBWeights() {
    uint32_t numWG_N = (N + TILE_N * selectedCM.N - 1) / (TILE_N * selectedCM.N);
    uint32_t numKTiles = (K + TILE_K * selectedCM.K - 1) / (TILE_K * selectedCM.K);
    uint32_t tileSizeB = TILE_K * selectedCM.K * TILE_N * selectedCM.N;
    
    hostB_reordered.resize(numWG_N * numKTiles * tileSizeB, 0.0f);
    
    for (uint32_t wgCol = 0; wgCol < numWG_N; wgCol++) {
        for (uint32_t kt = 0; kt < numKTiles; kt++) {
            uint32_t tileOffset = (wgCol * numKTiles + kt) * tileSizeB;
            
            for (uint32_t localK = 0; localK < TILE_K * selectedCM.K; localK++) {
                for (uint32_t localN = 0; localN < TILE_N * selectedCM.N; localN++) {
                    uint32_t globalK = kt * TILE_K * selectedCM.K + localK;
                    uint32_t globalN = wgCol * TILE_N * selectedCM.N + localN;
                    
                    uint32_t reorderedIdx = tileOffset + localK * (TILE_N * selectedCM.N) + localN;
                    
                    if (globalK < K && globalN < N) {
                        hostB_reordered[reorderedIdx] = hostB[globalK * N + globalN];
                    }
                }
            }
        }
    }
}

重排原理示意 (以 \(K=4, N=8\),Tile 尺寸 \(2 \times 4\) 为例):

\\\begin{array}{ccc} \\text{原始布局} \& \& \\text{重排后布局} \\\\ \\begin{bmatrix} \\color{red}{\\boxed{b_{0,0} \\cdots b_{0,3}}} \& \\color{blue}{\\boxed{b_{0,4} \\cdots b_{0,7}}} \\\\ \\color{red}{\\boxed{b_{1,0} \\cdots b_{1,3}}} \& \\color{blue}{\\boxed{b_{1,4} \\cdots b_{1,7}}} \\\\ \\color{green}{\\boxed{b_{2,0} \\cdots b_{2,3}}} \& \\color{orange}{\\boxed{b_{2,4} \\cdots b_{2,7}}} \\\\ \\color{green}{\\boxed{b_{3,0} \\cdots b_{3,3}}} \& \\color{orange}{\\boxed{b_{3,4} \\cdots b_{3,7}}} \\end{bmatrix} \& \\rightarrow \& \\underbrace{\\begin{bmatrix} \\boxed{\\color{red}{\\text{Tile}_{0,0}}} \\\\ \\boxed{\\color{green}{\\text{Tile}_{1,0}}} \\\\ \\boxed{\\color{blue}{\\text{Tile}_{0,1}}} \\\\ \\boxed{\\color{orange}{\\text{Tile}_{1,1}}} \\end{bmatrix}}_{\\text{按 K 方向连续存储}} \\end{array} \\

重排过程

  1. 将原始矩阵划分为 \(K_{cm} \times N_{cm}\) 的 tile
  2. 每个 tile 内的数据保持原有顺序
  3. 按 (WG, K方向) 重新排列 tile 的存储顺序

重排的优势

\\\begin{array}{c\|c\|c} \\text{特性} \& \\text{原始布局} \& \\text{重排后布局} \\\\ \\hline \\text{内存访问模式} \& \\text{跨行跳跃访问} \& \\text{连续块访问} \\\\ \\text{缓存利用率} \& \\text{低} \& \\text{高} \\\\ \\text{coopMatLoad 效率} \& \\text{需要分散读取} \& \\text{单次连续读取} \\end{array} \\

6.5 主机端:Pipeline 创建与 Specialization Constants

使用 Specialization Constants 将运行时参数传递给 Shader:

cpp 复制代码
void createPipeline() {
    std::string shaderSource = loadShaderSource();
    std::vector<uint32_t> spirv = compileShader(shaderSource);
    
    uint32_t numKTiles = (K + TILE_K * selectedCM.K - 1) / (TILE_K * selectedCM.K);
    uint32_t specData[8] = {
        selectedCM.M, selectedCM.N, selectedCM.K,
        subgroupSize,
        TILE_M, TILE_N, TILE_K,
        numKTiles
    };
    
    VkSpecializationMapEntry mapEntries[8] = {};
    for (int i = 0; i < 8; i++) {
        mapEntries[i].constantID = i;
        mapEntries[i].offset = i * sizeof(uint32_t);
        mapEntries[i].size = sizeof(uint32_t);
    }
    
    VkSpecializationInfo specInfo = {};
    specInfo.mapEntryCount = 8;
    specInfo.pMapEntries = mapEntries;
    specInfo.dataSize = sizeof(specData);
    specInfo.pData = specData;
    
    pipelineInfo.stage.pSpecializationInfo = &specInfo;
    pipelineInfo.stage.flags = VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT;
    
    VK_CHECK(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipelineInfo, nullptr, &pipeline));
}

Specialization Constants 映射

\\\begin{array}{c\|c\|c} \\text{ID} \& \\text{Shader 常量} \& \\text{主机端值} \\\\ \\hline 0 \& \\texttt{CM\\_M} \& \\texttt{selectedCM.M} \\\\ 1 \& \\texttt{CM\\_N} \& \\texttt{selectedCM.N} \\\\ 2 \& \\texttt{CM\\_K} \& \\texttt{selectedCM.K} \\\\ 3 \& \\texttt{subgroup\\_size} \& \\texttt{subgroupSize} \\\\ 4 \& \\texttt{TILE\\_M} \& \\texttt{TILE\\_M} \\\\ 5 \& \\texttt{TILE\\_N} \& \\texttt{TILE\\_N} \\\\ 6 \& \\texttt{TILE\\_K} \& \\texttt{TILE\\_K} \\\\ 7 \& \\texttt{numKTiles} \& \\texttt{numKTiles} \\end{array} \\

6.6 Shader:常量定义与 Shared Memory

glsl 复制代码
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_cooperative_matrix : require

layout(constant_id = 0) const uint CM_M = 16;
layout(constant_id = 1) const uint CM_N = 16;
layout(constant_id = 2) const uint CM_K = 16;
layout(constant_id = 3) const uint subgroup_size = 32;
layout(constant_id = 4) const uint TILE_M = 2;
layout(constant_id = 5) const uint TILE_N = 2;
layout(constant_id = 6) const uint TILE_K = 2;
layout(constant_id = 7) const uint numKTiles = 1;

layout(local_size_x_id = 3) in;

layout(binding = 0) readonly buffer A_buffer { float16_t A_data[]; };
layout(binding = 1) readonly buffer B_buffer { float16_t B_data[]; };
layout(binding = 2) writeonly buffer C_buffer { float C_data[]; };

layout(push_constant) uniform PushConstants { uint M; uint N; uint K; } pc;

shared float16_t sharedA[2][TILE_M * CM_M * TILE_K * CM_K];
shared float16_t sharedB[2][TILE_K * CM_K * TILE_N * CM_N];
shared float sharedC[TILE_M * CM_M * TILE_N * CM_N];

Shared Memory 结构

\\\begin{array}{c\|c\|c} \\text{变量} \& \\text{尺寸} \& \\text{用途} \\\\ \\hline \\text{sharedA\[2} & 2 \times \text{TILE\M} \times M{cm} \times \text{TILE\K} \times K{cm} & \text{双缓冲 A 矩阵} \\ \text{sharedB2} & 2 \times \text{TILE\K} \times K{cm} \times \text{TILE\N} \times N{cm} & \text{双缓冲 B 矩阵} \\ \text{sharedC} & \text{TILE\M} \times M{cm} \times \text{TILE\N} \times N{cm} & \text{结果矩阵} \end{array} \]

6.7 Shader:数据加载函数

glsl 复制代码
void loadTileA(uint tileIdx, uint wgRow, uint kStart) {
    const uint lane = gl_SubgroupInvocationID;
    const uint totalElements = TILE_M * CM_M * TILE_K * CM_K;
    for (uint i = lane; i < totalElements; i += subgroup_size) {
        uint localM = i / (TILE_K * CM_K);
        uint localK = i % (TILE_K * CM_K);
        uint globalM = wgRow * TILE_M * CM_M + localM;
        uint globalK = kStart + localK;
        if (globalM < pc.M && globalK < pc.K) {
            sharedA[tileIdx][i] = A_data[globalM * pc.K + globalK];
        } else {
            sharedA[tileIdx][i] = float16_t(0.0);
        }
    }
}

void loadTileB(uint tileIdx, uint wgCol, uint kt) {
    const uint lane = gl_SubgroupInvocationID;
    const uint tileSizeB = TILE_K * CM_K * TILE_N * CM_N;
    uint tileOffset = (wgCol * numKTiles + kt) * tileSizeB;
    for (uint i = lane; i < tileSizeB; i += subgroup_size) {
        sharedB[tileIdx][i] = B_data[tileOffset + i];
    }
}

加载策略

  • 所有线程协作加载,每个线程加载 totalElements / subgroupSize 个元素
  • A 矩阵从全局内存按行主序加载
  • B 矩阵从预重排的缓冲区直接复制

6.8 Shader:主计算循环(Double Buffer)

glsl 复制代码
void main() {
    const uint wgRow = gl_WorkGroupID.x;
    const uint wgCol = gl_WorkGroupID.y;
    const uint lane = gl_SubgroupInvocationID;

    if (wgRow * TILE_M * CM_M >= pc.M || wgCol * TILE_N * CM_N >= pc.N) return;

    coopmat<float, gl_ScopeSubgroup, CM_M, CM_N, gl_MatrixUseAccumulator> sum[TILE_M][TILE_N];
    [[unroll]] for (uint tm = 0; tm < TILE_M; tm++) {
        [[unroll]] for (uint tn = 0; tn < TILE_N; tn++) {
            sum[tm][tn] = coopmat<float, gl_ScopeSubgroup, CM_M, CM_N, gl_MatrixUseAccumulator>(0.f);
        }
    }

    const uint KTileSize = TILE_K * CM_K;

    loadTileA(0, wgRow, 0);
    loadTileB(0, wgCol, 0);
    barrier();

    for (uint kt = 0; kt < numKTiles; kt++) {
        uint currentBuf = kt % 2;
        uint nextBuf = (kt + 1) % 2;

        [[unroll]] for (uint tk = 0; tk < TILE_K; tk++) {
            [[unroll]] for (uint tm = 0; tm < TILE_M; tm++) {
                [[unroll]] for (uint tn = 0; tn < TILE_N; tn++) {
                    coopmat<float16_t, gl_ScopeSubgroup, CM_M, CM_K, gl_MatrixUseA> matA;
                    coopmat<float16_t, gl_ScopeSubgroup, CM_K, CM_N, gl_MatrixUseB> matB;

                    uint offsetA = (tm * CM_M) * (TILE_K * CM_K) + tk * CM_K;
                    uint offsetB = (tk * CM_K) * (TILE_N * CM_N) + tn * CM_N;
                    uint strideA = TILE_K * CM_K;
                    uint strideB = TILE_N * CM_N;

                    coopMatLoad(matA, sharedA[currentBuf], offsetA, strideA, gl_CooperativeMatrixLayoutRowMajor);
                    coopMatLoad(matB, sharedB[currentBuf], offsetB, strideB, gl_CooperativeMatrixLayoutRowMajor);

                    sum[tm][tn] = coopMatMulAdd(matA, matB, sum[tm][tn]);
                }
            }
        }

        barrier();

        uint nextKt = kt + 1;
        if (nextKt < numKTiles) {
            uint nextKStart = nextKt * KTileSize;
            loadTileA(nextBuf, wgRow, nextKStart);
            loadTileB(nextBuf, wgCol, nextKt);
        }

        barrier();
    }

    [[unroll]] for (uint tm = 0; tm < TILE_M; tm++) {
        [[unroll]] for (uint tn = 0; tn < TILE_N; tn++) {
            uint offsetC = (tm * CM_M) * (TILE_N * CM_N) + tn * CM_N;
            uint strideC = TILE_N * CM_N;
            coopMatStore(sum[tm][tn], sharedC, offsetC, strideC, gl_CooperativeMatrixLayoutRowMajor);
        }
    }
    barrier();

    const uint totalC = TILE_M * CM_M * TILE_N * CM_N;
    for (uint i = lane; i < totalC; i += subgroup_size) {
        uint localM = i / (TILE_N * CM_N);
        uint localN = i % (TILE_N * CM_N);
        uint globalM = wgRow * TILE_M * CM_M + localM;
        uint globalN = wgCol * TILE_N * CM_N + localN;
        if (globalM < pc.M && globalN < pc.N) {
            C_data[globalM * pc.N + globalN] = sharedC[i];
        }
    }
}

6.9 Double Buffer 优化详解

问题背景

传统实现中,数据加载和计算是串行执行的:

\\\begin{array}{c\|cccccc} \\text{时间} \& t_0 \& t_1 \& t_2 \& t_3 \& t_4 \& t_5 \\\\ \\hline \\text{加载} \& \\boxed{L_0} \& \& \\boxed{L_1} \& \& \\boxed{L_2} \& \\\\ \\text{计算} \& \& \\boxed{C_0} \& \& \\boxed{C_1} \& \& \\boxed{C_2} \\end{array} \\

总时间:

\T_{\\text{serial}} = \\sum_{i=0}\^{n-1} (T_{\\text{load},i} + T_{\\text{compute},i}) \\

Double Buffer 原理

使用两个缓冲区交替工作,实现加载与计算的重叠:

\\\begin{array}{c\|cccccc} \\text{时间} \& t_0 \& t_1 \& t_2 \& t_3 \& t_4 \& t_5 \\\\ \\hline \\text{Buffer 0} \& \\boxed{L_0} \& \\boxed{C_0} \& \\boxed{L_2} \& \\boxed{C_2} \& \& \\\\ \\text{Buffer 1} \& \& \\boxed{L_1} \& \\boxed{C_1} \& \\boxed{L_3} \& \\boxed{C_3} \& \\end{array} \\

加速比

\\\text{Speedup} = \\frac{T_{\\text{serial}}}{T_{\\text{pipelined}}} \\approx \\frac{2(T_L + T_C)}{T_L + T_C} = 2 \\

代码实现要点

glsl 复制代码
for (uint kt = 0; kt < numKTiles; kt++) {
    uint currentBuf = kt % 2;      // 当前计算使用的缓冲区
    uint nextBuf = (kt + 1) % 2;   // 下一块数据加载的缓冲区

    // 1. 从当前缓冲区计算
    coopMatLoad(matA, sharedA[currentBuf], ...);
    coopMatLoad(matB, sharedB[currentBuf], ...);
    sum[tm][tn] = coopMatMulAdd(matA, matB, sum[tm][tn]);

    barrier();

    // 2. 预加载下一块到另一个缓冲区
    if (nextKt < numKTiles) {
        loadTileA(nextBuf, wgRow, nextKStart);
        loadTileB(nextBuf, wgCol, nextKt);
    }

    barrier();
}

关键点

  1. 使用 currentBufnextBuf 交替访问两个缓冲区
  2. barrier() 确保计算完成后再加载下一块
  3. 预加载与当前计算在时间上重叠

6.10 性能分析

计算复杂度

\\\text{FLOP} = 2 \\times M \\times N \\times K \\

内存访问量

\\\text{Bytes} = M \\times K \\times 2 + K \\times N \\times 2 + M \\times N \\times 4 \\

(A 和 B 为 FP16,C 为 FP32)

计算强度

\\\text{Arithmetic Intensity} = \\frac{2 \\times M \\times N \\times K}{M \\times K \\times 2 + K \\times N \\times 2 + M \\times N \\times 4} \\quad \\text{FLOP/Byte} \\

示例 (\(M=N=K=256\)):

\\\text{FLOP} = 2 \\times 256\^3 = 33,554,432 \\approx 33.6 \\text{ MFLOP} \\

\\\text{Bytes} = 256\^2 \\times (2 + 2 + 4) = 524,288 \\approx 0.5 \\text{ MB} \\

\\\text{Intensity} = \\frac{33.6 \\text{ MFLOP}}{0.5 \\text{ MB}} = 64 \\text{ FLOP/Byte} \\

6.11 运行示例

复制代码
============================================
Vulkan Cooperative Matrix GEMM - Optimized
With Double Buffer Optimization
============================================
Validation layers enabled
Vulkan instance created

Available GPUs:
---------------
0: Intel(R) RaptorLake-S Mobile Graphics Controller [Cooperative Matrix: NO]
1: NVIDIA GeForce RTX 4060 Laptop GPU [Cooperative Matrix: YES]
2: Intel(R) RaptorLake-S Mobile Graphics Controller [Cooperative Matrix: NO]

Selected GPU: NVIDIA GeForce RTX 4060 Laptop GPU
Subgroup size: 32
VK_KHR_cooperative_matrix supported
Logical device created

Supported Cooperative Matrix configurations:
----------------------------------------
M=16, N=16, K=16 (A: FP16, B: FP16, C: FP16)
M=16, N=8, K=16 (A: FP16, B: FP16, C: FP16)
M=16, N=8, K=8 (A: FP16, B: FP16, C: FP16)
M=16, N=16, K=16 (A: FP16, B: FP16, C: FP32)
M=16, N=8, K=16 (A: FP16, B: FP16, C: FP32)
M=16, N=8, K=8 (A: FP16, B: FP16, C: FP32)

Selected configuration: M=16, N=16, K=16 (FP16 input, FP32 accumulate)

Tile sizes (for double buffer optimization):
  TILE_M = 2
  TILE_N = 2
  TILE_K = 2

Buffer sizes:
  A: 131072 bytes (256x256 FP16)
  B (reordered): 131072 bytes
  C: 262144 bytes (256x256 FP32)
Buffers created successfully
Pipeline created

=== Initializing Test Data ===

=== B Weight Reordering ===
B weight reordered for optimal memory access pattern
  Original size: 65536 elements
  Reordered size: 65536 elements
  numWG_N: 8, numKTiles: 8
Test data uploaded to GPU

=== Running GPU GEMM ===

Dispatching 8 x 8 workgroups

=== GPU Results ===
Execution time: 0.886 ms
Performance: 37.89 GFLOPS

=== Running CPU GEMM for comparison ===

=== CPU Results ===
Execution time (avg of 5 runs): 7.990 ms
Performance: 4.20 GFLOPS

=== Result Verification ===
Max difference: 1.64e-03
Avg difference: 2.80e-04
Errors (diff > 1.00e-02): 0/65536
Status: PASSED

Done!

7. 总结

核心公式汇总

\\\begin{array}{c\|c} \\text{概念} \& \\text{公式} \\\\ \\hline \\text{矩阵乘法} \& D = A \\times B + C \\\\ \\text{元素计算} \& D_{ij} = \\sum_k A_{ik} B_{kj} + C_{ij} \\\\ \\text{每线程元素数} \& E = \\frac{M \\times N}{S} \\\\ \\text{Workgroup 数量} \& WG = \\lceil M/(M_{cm} Z_M W_M) \\rceil \\times \\lceil N/(N_{cm} Z_N W_N) \\rceil \\\\ \\text{内存大小} \& \\text{size} = B \\times M_{cm} \\times K_{cm} \\times Z \\times W \\times K_K \\\\ \\text{计算复杂度} \& \\text{FLOP} = 2 \\times M \\times N \\times K \\end{array} \\

性能要点

  1. 数据布局优化:预处理权重,使内存访问连续
  2. 计算密度:每个 Subgroup 处理足够大的矩阵块
  3. 内存层次:利用 Local Memory 减少全局内存访问
  4. 流水线:计算与内存传输重叠
  5. Uniform Control Flow:确保所有线程同时调用 CM 函数
相关推荐
FreakStudio1 天前
大话电容传感器和电容SOC芯片,看这一篇就够了
python·单片机·嵌入式·面向对象·并行计算·电子diy·电子计算机
神工坊5 天前
性能测试︱Abaqus 非线性瞬态热固耦合仿真——增材制造求解效率实测
abaqus·hpc·并行计算·cae·结构仿真·非线性·热固耦合
FreakStudio21 天前
硬件版【Cursor】?aily blockly IDE尝鲜封神,实战硬伤尽显
python·单片机·嵌入式·大学生·面向对象·并行计算·电子diy·电子计算机
FreakStudio23 天前
开源分享|用MicroPython 做了个 AI 小鸡,它会长大,还记得我所有的情绪
python·单片机·嵌入式·面向对象·并行计算·电子diy·电子计算机
FreakStudio25 天前
WIZnet-EVB-Pico2开始,用MicroPython玩转以太网开发
python·单片机·嵌入式·大学生·面向对象·技术栈·并行计算·电子diy·电子计算机
FreakStudio1 个月前
工控开发板从开箱到点亮 LED-恩智浦MCXE31B 实测:3 路 CAN + 以太网+自带调试器
python·单片机·嵌入式·大学生·面向对象·技术栈·并行计算·电子diy·电子计算机
FreakStudio1 个月前
MicroPython 内核开发者直接狂喜!这个 Claude 插件市场,把开发全流程做成了「对话式外挂」
python·单片机·嵌入式·面向对象·并行计算·电子diy
FreakStudio1 个月前
亲测可用!可本地部署的 MicroPython 开源仿真器
python·单片机·嵌入式·面向对象·并行计算·电子diy·电子计算机
FreakStudio1 个月前
和做工厂系统的印尼老哥,复刻了一套属于 MicroPython 的包管理系统
python·单片机·嵌入式·大学生·面向对象·并行计算·电子diy·电子计算机
FreakStudio1 个月前
做了个Claude Code CLI 电子宠物:程序员的实体监工代码搭子
python·单片机·嵌入式·面向对象·并行计算·电子diy·电子计算机