引言
cutlass是nvidia官方开源的一套用于通用矩阵乘法(GEMM)的C++模板库。相比于cuBlas、cuTensor等其他功能近似的cuda库,cutlass具有以下优点:
- 开源:方便进行自定义和客制化的开发;
- 模板化:API设计灵活,开发更高效;
- 高性能:通常能达到cuBlas 90+%的性能。
cutlass底层依赖tensor core和warp矩阵乘累加(Warp Matrix Multiply Accumulate, WMMA) API。在2017年Nvidia发布了Tesla V100 GPU中采用了Volta架构,引入了第一代tensor core。并在同一时期发布的cuda 9.0中新增了与tensor core配套的wmma API。本文将结合Volta架构对cutlass底层所使用的wmma API原理进行简单介绍。
cutlass GEMM hierarchy
CUDA中一次GEMM操作可以分为两个阶段,如图1所示:
-
main loop: m×k的矩阵A与k×n的矩阵B做矩阵乘,需要在k的维度进行分块,然后在k维度进行遍历,这个过程称之为main loop。在此阶段,cutlass支持batched gemm、splitK等优化。
-
epilogue: 做完矩阵乘后可以进行element wise操作,如add bias、activation等。在此阶段cutlass可以进行kernel fuse。
图1 Gemm计算流程图
cutlass的main loop阶段的层次结构如图2所示,执行过程如下:
- 申请一个二维的grid,网格中每个block负责一个小的矩阵块(tile),将数据从global memory读取到shared memory,这一步与使用cuda core的GEMM操作类似;
- 每个block中的每个warp负责从shared memory中读取一个矩阵片段(fragment)的数据到寄存器(register file),在register file上使用wmma指令调动cuda core执行矩阵乘法。
图2 cutlass main loop层次结构
在最内层warp级别使用tensor core进行矩阵乘累加计算,需要依赖cuda 9.0之后的wmma API,下面我们将借助Volta架构中的第一代tensor core结构,对底层的mma指令进行了解。
tensor core & wmma API
与cuda core相比,tensor core是一种SM级别的硬件结构,可编程的粒度为warp level ,在开发上不如cuda core的thread level灵活。Volta架构的第一代tensor core可以在一个时钟周期实现4(m)×4(n)×4(k)的矩阵乘累加,其中输入矩阵A(m×k: 4×4)、B(k×n: 4×4)数据类型为FP16,输出数据类型可以为FP16或FP32:
图3 tensor core在一个时钟周期内实现4(m)×4(n)×4(k)的矩阵乘累加
而在cuda core中做一次同样4×4×4的场景FP32的矩阵乘法,一个时钟周期只能实现1x4的运算,完成4x4x4矩阵乘法需要16个时钟周期:
图4 cuda core(Pascal) vs tensor core(Volta)
图4展示的X12倍是指tensor core相比于cuda core每秒浮点数运算理论峰值速度的差异,并不是指左边的cuda core完成1次运算完花费了16个时钟周期,右边的tensor core就要完成16次4×4×4算,毕竟cuda core和tensor core单个时钟周期所花费的时间也是不一样的。
cuda中使用tensor core来加速Gemm要借助wmma API来实现,wmma API位于cuda头文件mma.h中,有以下几个功能,结合使用可实现最底层的warp level gemm运算:
C++
// 定义warp所负责的矩阵片段(fragment)的数据布局
template<typename Use, int m, int n, int k, typename T, typename Layout = void>
class fragment;
template<> class fragment<matrix_a, 16, 16, 16, __half, row_major>
template<> class fragment<matrix_a, 16, 16, 16, __half, col_major>
template<> class fragment<matrix_b, 16, 16, 16, __half, row_major>
template<> class fragment<matrix_b, 16, 16, 16, __half, col_major>
template<> class fragment<accumulator, 16, 16, 16, __half>
template<> class fragment<accumulator, 16, 16, 16, float>
// 从内存中加载数据到warp负责的fragment
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t,layout);
// 用定值v填充fragment
void fill_fragment(fragment<...> &a, const T& v);
// 执行mma运算
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...>
&b, const fragment<...> &c, bool satf = false);
// 将warp负责fragment数据储存到内存中
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t
layout);
wmma API属于指令级别的API,操作的数据是fragment级别的,运算发生在register file上。单个warp每次从shared memory中拿出一个fragment,在寄存器上实现整个fragment的运算,fragment层级的运算是由一个warp内32个线程借助tensor core共同完成的,Volta中一个fragment大小为16x16,即m、n、k尺寸都是16,下面将介绍Volta架构中为什么是16x16的fragment,以及单个warp上是怎么利用tensor core执行mma的。
Volta架构mma执行原理
根据Modeling Deep Learning Accelerator Enabled GPUs (arXiv:1811.08309)相关研究,在对16×16的FP16矩阵A、B进行加载时,一个warp内4个连续的thread分成一个组(thread group),32个thread共分为8组,编号0~7,每个group中的4个thread可以负责4行(A row major)或4列(B colum major的4x16子块(sub fragment)加载,具体加载方式如图5所示。
图5 Volta架构中每个warp加载16×16 fragment A B C的方式
图5左边展示了fragment A和B的加载方式,①中一个颜色块对应的4×16 sub fragment被2个thread group同时加载,A和B中子fragment与thread group映射关系如④所示(感觉论文影印版这张图有问题,后面会进行说明)。一行16个FP16的数据共需要256位,针对不同的layout,thread group 在对4×16 sub fragment加载方式如②和③所示。对于A(row major)和B(column major)的layout,1个线程会使用两个合并的128位的加载指令共256位加载16个FP16,对于A(column major)和B(row major),1个线程会使用4个间隔的64位加载指令完成。
图5右边展示了fragment C的加载方式,C中1个thread group中的4个线程加载1个4×8 sub fragment,thread group中线程与sub fragment中数据映射关系和布局无关,只与数据类型有关。
图6展示了threadGroup 0和threadGroup 4执行mma指令计算fragment C中1个4×8 sub fragment的全过程 :一个mma指令分为4个组(set 1 ~ 4),如图6(a)所示。其中每个set按照数据类型混精、FP16又分为4或2个step ,如图6(b)和6(c)所示。如果是混精运算,(a)中set 1将分为(b)中的4步;如果是FP16,将会分为(c)中的2步。因此完成一个4×8 sub fragment的计算共需4x4或4x2步,每1步都是一个4x4的结构,便可借助tensor core在1个时钟周期内完成。
图6 Volta架构中每个warp执行16×16 fragment运算方式
接下来我们回到图5中,说说为什么感觉影印版中的图有问题 ,回到图5(b),其实threadGroup 0和threadGroup 4作为一个group对,除了负责C中左上角的4×8的子片段[0:3,0:7],还负责靠下位置的4×8的子片段[4:7,0:7],共负责C中左上角[0:7,0:7]的1/4块。这也是为什么每个4×16的 sub fragment会被2个thread group同时加载的原因 。这样的2个thread group做为一组,称之为一个octet。表1展示了octet的组对方式和fragment A和B被加载方式。如果按照影印版原图(图5)会和表1对不上,对图5中matrixB两处进行对调后可以对上。
表1 octet加载A和B
在理清octet关系后,Volta架构中整个warp负责的16×16 fragment C = AxB+C计算过程如图7所示:每个octet负责一个4×8的A、8x4的B、4x4的C小片段做乘累加。threadGroup-0和threadGroup-4所组成的octet-0计算过程如图7(1-b)和7(2)表格所示。
图7 Volta架构中的octet与混精计算详细过程
至此应该可以理解在Volta架构中一个warp是怎么通过32个线程,利用tensor core恰好对16×16的fragment,完成一次mma运算。论文中还介绍了Turing架构的wmma执行原理,感兴趣的读者可以去原论文中进行了解。
小结
cutlass底层依靠tensor core和WMMA API,按照matrix(grid) => tile(block) => warp(fragment)的层次结构进行GEMM操作,了解在warp level如何利用tensor core对fragment执行mma运算的原理,将有助于我们进一步熟悉cutlass源码实现,从而更得心应手的使用cutlass进行GEMM开发。