BASE_TRT: NVIDIA Tensor Core

NVIDIA Tensor Core 探索


📖 目录

  1. [Tensor Core 硬件架构](#Tensor Core 硬件架构)
  2. [Step 1: cuBLAS 快速上手](#Step 1: cuBLAS 快速上手)
  3. [Step 2: WMMA 基础 - 单个 16×16×16](#Step 2: WMMA 基础 - 单个 16×16×16)
  4. [Step 3: WMMA 扩展 - Tiled 多 Warp 协作](#Step 3: WMMA 扩展 - Tiled 多 Warp 协作)
  5. [Step 4: WMMA 优化 - Shared Memory](#Step 4: WMMA 优化 - Shared Memory)

1. Tensor Core 硬件架构

1.1 硬件层级结构

复制代码
GPU
└── 多个 SM (Streaming Multiprocessor)
    ├── CUDA Core (你熟悉的标量计算单元)
    └── Tensor Core (新增的矩阵计算单元) ★
        └── 每个时钟周期执行 4×4×4 矩阵乘加
           等效于 64 次标量乘加操作

1.2 Tensor Core 计算粒度

基本操作单元 : D = A × B + C

  • 输入: A (16×16), B (16×16), C (16×16)

  • 输出: D (16×16)

  • 实现: 将 16×16×16 分解为多个 4×4×4 小块,在多个时钟周期内完成

    ┌─────────────────────────────────────────────────────┐
    │ 16×16 结果矩阵 (D) │
    │ ┌────┬────┬────┬────┐ │
    │ │4×4 │4×4 │4×4 │4×4 │ ← 每个 4×4 小块 │
    │ ├────┼────┼────┼────┤ 需要 K/4 = 4 次 │
    │ │4×4 │... │... │... │ 时钟周期累加 │
    │ ├────┼────┼────┼────┤ │
    │ │... │... │... │... │ 总计: 4×4×4 = 64 时钟周期 │
    │ └────┴────┴────┴────┘ │
    └─────────────────────────────────────────────────────┘

关键点:

  • 最小粒度: 16×16×16 (不可分割)
  • 矩阵尺寸必须是 16 的倍数,否则需要 padding
  • 一个 Warp (32 线程) 协作完成一个 16×16×16 的计算

1.3 架构支持表

GPU 架构 Compute Capability 支持精度
Volta (V100) sm_70 FP16
Turing (T4, RTX 2080) sm_75 FP16, INT8
Ampere (A100) sm_80 FP16, BF16, TF32, FP64
Ada Lovelace (RTX 4090) sm_89 FP16, FP8
Hopper (H100) sm_90 FP16, FP8, FP64

2. Step 1: cuBLAS 快速上手

2.1 代码位置

src/cublas.cu - 使用 cuBLAS 库调用 Tensor Core

2.2 核心代码

cpp 复制代码
// 1. 创建 cuBLAS handle
cublasHandle_t handle;
cublasCreate(&handle);

// 2. ★ 启用 Tensor Core 模式
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);

// 3. 执行矩阵乘法 (混合精度: FP16 输入, FP32 输出)
cublasGemmEx(
    handle,
    CUBLAS_OP_N, CUBLAS_OP_N,
    N, M, K,
    &alpha,
    d_B, CUDA_R_16F, N,  // ← 注意: B 在前
    d_A, CUDA_R_16F, K,  // ← A 在后
    &beta,
    d_C, CUDA_R_32F, N,
    CUBLAS_COMPUTE_32F,
    CUBLAS_GEMM_DEFAULT_TENSOR_OP  // ★ Tensor Core 算法
);

2.3 列主序陷阱与解决方案

问题: cuBLAS 默认列主序 (Fortran 风格), C 语言是行主序

数学等价变换:

复制代码
行主序: C = A × B
等价于
列主序: C^T = B^T × A^T

实现技巧:

  • 把 A 和 B 的传递顺序对调
  • leading dimension 对应原矩阵的维度
  • 最终结果正确 (两次转置相互抵消)

详细推导:

复制代码
步骤1: 行主序内存 → 列主序视角 = 自动转置
  A_行主序 (M×K) → A^T_列主序 (K×M)
  B_行主序 (K×N) → B^T_列主序 (N×K)

步骤2: cuBLAS 列主序计算
  C^T = B^T (N×K) × A^T (K×M) = (A×B)^T

步骤3: 输出列主序 → 行主序读回 = 再次转置
  C^T_列主序 → C_行主序 = (C^T)^T = C ✓

2.4 优缺点

优点:

  • ✅ 简单易用,一行代码搞定
  • ✅ 自动优化,性能接近硬件峰值
  • ✅ 支持各种精度组合

缺点:

  • ❌ 黑盒操作,无法深度定制
  • ❌ 依赖 cuBLAS 库
  • ❌ 学不到底层原理

3. Step 2: WMMA 基础 - 单个 16×16×16

3.1 代码位置

src/wmma_simple.cu - 使用 WMMA API 手写 Tensor Core 代码

3.2 WMMA API 核心概念

WMMA = Warp-level Matrix Multiply-Accumulate

cpp 复制代码
namespace wmma = nvcuda::wmma;  // 命名空间别名 (重要!)

// 1️⃣ 声明 fragment (矩阵片段)
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;

// 2️⃣ 初始化累加器
wmma::fill_fragment(acc, 0.0f);

// 3️⃣ 从全局内存加载
wmma::load_matrix_sync(a_frag, A_ptr, 16);  // ld = 16
wmma::load_matrix_sync(b_frag, B_ptr, 16);

// 4️⃣ 执行矩阵乘加 (触发 Tensor Core)
wmma::mma_sync(acc, a_frag, b_frag, acc);  // acc = a×b + acc

// 5️⃣ 写回全局内存
wmma::store_matrix_sync(C_ptr, acc, 16, wmma::mem_row_major);

3.3 Fragment 深入理解

Fragment 是什么?

  • 不是真实的矩阵,而是"矩阵在 Warp 32 个线程间的分布式表示"
  • 每个线程持有 fragment 的一部分数据
  • 只能通过 WMMA API 操作,无法直接访问元素

模板参数解析:

cpp 复制代码
fragment<
    wmma::matrix_a,    // 角色: matrix_a / matrix_b / accumulator
    16, 16, 16,        // M, N, K (必须是 16)
    __half,            // 数据类型: __half / float
    wmma::row_major    // 内存布局 (accumulator 无此参数)
>

3.4 Leading Dimension (ld) 详解

cpp 复制代码
// 对于行主序矩阵 A[M][K]:
// A[i][j] 在内存中的位置 = A + i*ld + j
// ld = K (每行有 K 个元素)

wmma::load_matrix_sync(
    a_frag,
    A_ptr,   // 指向 A[i][j] 的指针
    ld       // = 矩阵的列数 (行主序)
);

常见错误:

cpp 复制代码
// ❌ 错误: ld 应该是矩阵列数,而非 tile 大小
wmma::load_matrix_sync(a_frag, A, 16);  // 如果 A 是 256×256

// ✅ 正确
wmma::load_matrix_sync(a_frag, A, 256);

3.5 线程配置要求

cpp 复制代码
// ★ 必须恰好 32 个线程 (1 个 warp)
wmma_kernel<<<1, 32>>>(A, B, C);

// ❌ 错误的配置
wmma_kernel<<<1, 64>>>();   // 多了
wmma_kernel<<<1, 16>>>();   // 少了
wmma_kernel<<<1, 31>>>();   // 不是 32 的倍数

4. Step 3: WMMA 扩展 - Tiled 多 Warp 协作

4.1 代码位置

src/wmma_tiled.cu - 处理大矩阵,多个 Warp 协作

4.2 Tiling 策略

问题: 一个 Warp 只能处理 16×16,如何计算 256×256?

解决: 把大矩阵切成多个 16×16 块,分配给不同 Warp

复制代码
C[256×256] 切分示意:
┌─────┬─────┬─────┬─────┐
│Warp │Warp │Warp │Warp │  ← 第 1 行 Block (4 个 Warp)
│ 0   │ 1   │ 2   │ 3   │
├─────┼─────┼─────┼─────┤
│Warp │Warp │Warp │Warp │  ← 第 2 行 Block
│ 4   │ 5   │ 6   │ 7   │
├─────┼─────┼─────┼─────┤
│     │ ... │     │     │
└─────┴─────┴─────┴─────┘

每个小方块 = 16×16 (1 个 Warp 负责)
每个 Block = 64×64 (16 个 Warp = 512 线程)

4.3 配置参数设计

cpp 复制代码
#define TC_M  16           // Tensor Core 固定 tile
#define TC_N  16
#define TC_K  16

#define WARPS_ROW  4       // Block 在行方向: 4 个 warp
#define WARPS_COL  4       // Block 在列方向: 4 个 warp

#define BLOCK_M  (WARPS_ROW * TC_M)   // = 64 行
#define BLOCK_N  (WARPS_COL * TC_N)   // = 64 列
#define BLOCK_THREADS  (WARPS_ROW * WARPS_COL * 32)  // = 512 线程

4.4 Warp 定位逻辑

cpp 复制代码
__global__ void wmma_tiled_kernel(...) {
    // 步骤1: 计算当前线程所属的 Warp
    int warp_id = threadIdx.x / 32;         // 0~15 (Block 内 16 个 warp)
    
    // 步骤2: Warp 在 Block 内的 2D 位置
    int warp_row = warp_id / WARPS_COL;     // 0~3
    int warp_col = warp_id % WARPS_COL;     // 0~3
    
    // 步骤3: 计算负责的 C 矩阵区域全局坐标
    int c_row = blockIdx.y * BLOCK_M + warp_row * TC_M;  // 第几行 (起始)
    int c_col = blockIdx.x * BLOCK_N + warp_col * TC_N;  // 第几列 (起始)
    
    // 步骤4: K 方向循环累加
    for (int k = 0; k < K; k += TC_K) {
        const __half *a_ptr = A + c_row * K + k;      // A[c_row][k]
        const __half *b_ptr = B + k * N + c_col;      // B[k][c_col]
        
        wmma::load_matrix_sync(a_frag, a_ptr, K);     // ld = K
        wmma::load_matrix_sync(b_frag, b_ptr, N);     // ld = N
        wmma::mma_sync(acc, a_frag, b_frag, acc);
    }
    
    // 步骤5: 写回结果
    float *c_ptr = C + c_row * N + c_col;
    wmma::store_matrix_sync(c_ptr, acc, N, wmma::mem_row_major);
}

4.5 Grid/Block 配置

cpp 复制代码
dim3 block(BLOCK_THREADS);  // 512 线程 = 16 warp
dim3 grid(
    (N + BLOCK_N - 1) / BLOCK_N,  // 列方向 Block 数 = ceil(N/64)
    (M + BLOCK_M - 1) / BLOCK_M   // 行方向 Block 数 = ceil(M/64)
);

wmma_tiled_kernel<<<grid, block>>>(A, B, C, M, N, K);

示例: 256×256 矩阵

  • Grid = (4, 4) → 16 个 Block
  • 每个 Block = 512 线程 = 16 Warp
  • 总计 16×16 = 256 个 Warp = 256 个 16×16 tile ✓

5. Step 4: WMMA 优化 - Shared Memory

5.1 代码位置

src/wmma_smem.cu - 使用 Shared Memory 减少全局内存访问

5.2 性能瓶颈分析

Step 3 的问题:

cpp 复制代码
for (int k = 0; k < K; k += 16) {
    // 每次循环都从全局内存读取
    wmma::load_matrix_sync(a_frag, A + ..., K);  // 慢! (400+ cycles)
    wmma::load_matrix_sync(b_frag, B + ..., N);  // 慢!
    wmma::mma_sync(...);  // 快 (< 10 cycles)
}

带宽瓶颈: Tensor Core 计算太快,全局内存跟不上

5.3 Shared Memory 优化策略

核心思想: Block 内所有线程协作,先把数据搬到 Shared Memory (更快)

复制代码
优化前:
  每个 Warp 独立从 Global Memory 读取
  Global Memory ← Warp 0
  Global Memory ← Warp 1
  ...

优化后:
  Block 内所有线程协作加载到 Shared Memory
  Global Memory → Shared Memory (1次, 多线程合作)
  Shared Memory → Warp 0
  Shared Memory → Warp 1
  ...

5.4 Shared Memory 布局设计

cpp 复制代码
__shared__ __half smem_A[BLOCK_M][SMEM_K];  // [64][16] = 2KB
__shared__ __half smem_B[SMEM_K][BLOCK_N];  // [16][64] = 2KB
// 总共 4KB per Block

为什么是 [64][16] 和 [16][64]?

  • BLOCK_M = 64: Block 负责 C 的 64 行
  • BLOCK_N = 64: Block 负责 C 的 64 列
  • SMEM_K = 16: K 方向每次只搬 16 列 (一个 TC tile 的宽度)

5.5 协作加载算法

cpp 复制代码
// 阶段1: 512 个线程协作加载 smem_A[64][16] = 1024 个元素
int total = BLOCK_M * SMEM_K;  // 1024
for (int idx = threadIdx.x; idx < total; idx += BLOCK_THREADS) {
    int r = idx / SMEM_K;   // 行号 (0~63)
    int c = idx % SMEM_K;   // 列号 (0~15)
    
    int global_row = blockIdx.y * BLOCK_M + r;
    int global_col = k + c;
    
    smem_A[r][c] = (global_row < M && global_col < K)
        ? A[global_row * K + global_col]
        : __float2half(0.0f);
}

// 阶段2: 同样协作加载 smem_B[16][64] = 1024 个元素
// ... (类似逻辑)

// ★ 关键: 同步屏障
__syncthreads();  // 确保所有线程加载完成

// 阶段3: 各 Warp 从 Shared Memory 读取
wmma::load_matrix_sync(
    a_frag,
    &smem_A[warp_row * TC_M][0],  // Shared Memory 地址
    SMEM_K                         // ld = 16
);

5.6 双重同步的重要性

cpp 复制代码
for (int k = 0; k < K; k += SMEM_K) {
    // 加载数据到 smem
    // ...
    
    __syncthreads();  // ★ 同步1: 确保加载完成后再读取
    
    // Warp 从 smem 读取并计算
    // ...
    
    __syncthreads();  // ★ 同步2: 确保所有 Warp 用完后再覆写
}

为什么需要两次同步?

  1. 同步1: 防止某些线程还没加载完,其他线程就开始读
  2. 同步2: 防止某些 Warp 还在用旧数据,新数据就覆盖进来

5.7 性能提升预期

优化阶段 内存访问 相对性能
Step 3 (Global only) 每个 Warp 独立读 Global
Step 4 (+ Shared Memory) Block 协作, 读 Shared 2-3×
极致优化 (Double Buffer) 计算和加载重叠 5-10×

6. 性能分析与验证

6.1 验证 Tensor Core 是否启用

bash 复制代码
# 使用 ncu 工具
sudo $(which ncu) \
  --metrics \
sm__inst_executed_pipe_tensor.avg,\
sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed,\
sm__throughput.avg.pct_of_peak_sustained_elapsed \
  --print-summary per-kernel \
  ./wmma_tiled

输出结果:

复制代码
wmma_tiled_kernel(const __half *, const __half *, float *, int, int, int)
  (4, 4, 1)x(512, 1, 1), Device 0, CC 8.6, Invocations 1
  
  Section: Command line profiler metrics
  --------------------------------------------------------------- ----------- -------
  Metric Name                                                     Metric Unit Average
  --------------------------------------------------------------- ----------- -------
  sm__inst_executed_pipe_tensor.avg                                   inst    120.47
  sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed        %      2.21
  sm__throughput.avg.pct_of_peak_sustained_elapsed                       %      6.86
  --------------------------------------------------------------- ----------- -------

6.2 Tensor Core 指令数分析

关键指标 : sm__inst_executed_pipe_tensor.avg = 120.47 inst

理论计算:

  • 矩阵尺寸: 256×256×256
  • 每个 Block 负责: 64×64 输出 (4×4 = 16 个 Warp)
  • 每个 Warp 负责: 16×16 输出
  • K 方向循环: 256/16 = 16 次
  • 每个 Warp 执行: 16 次 Tensor Core 指令
  • 每个 Block 执行: 16 Warp × 16 次 = 256 次

实际结果: 120.47 inst (平均每个 SM)

差异分析:

复制代码
理论: 256 次/Block
实际: 120.47 次/SM (平均)

原因:
1. Grid = (4, 4) = 16 个 Block
2. GPU 有多个 SM,Block 被分配到不同 SM 上
3. 负载不均: 有些 SM 处理多个 Block,有些处理得少
4. 平均值被"空闲"的 SM 拉低

实际有效 SM 数 ≈ (16 Block × 256 inst) / 120.47 = ~34 个 SM 在工作

关键洞察:

  • ✅ Tensor Core 已激活(指令数 > 0)
  • ⚠️ SM 占用率不理想(部分 SM 空闲)
  • 💡 优化方向:增大问题规模或调整 Block 配置

7. 关键要点速查

7.1 硬件限制

  • 最小粒度: 16×16×16
  • Warp 需求: 必须 32 线程整数倍
  • 架构要求: sm_70+ (Volta 及以上)
  • 数据类型: FP16 (Volta), FP8/BF16/TF32 (Ampere+)

7.2 API 规则

  • Fragment: 不可直接访问,只能通过 WMMA API
  • Leading Dim: 行主序 = 列数, 列主序 = 行数
  • 同步 : Global→Shared 和 Shared→Fragment 都要 __syncthreads()
  • 命名空间 : 用 namespace wmma = nvcuda::wmma; 避免问题

8. 参考资料


相关推荐
Rabitebla1 小时前
从零实现 C++ List:带头循环双向链表的每一个细节
数据结构·c++·算法·leetcode·链表·list
Byte不洛1 小时前
深入理解C++智能指针:从RAII到shared_ptr
c++·智能指针·raii·unique_ptr·shared_ptr·auto_ptr
云深麋鹿1 小时前
C++ | map&set的使用
开发语言·c++
allnlei1 小时前
gRPC C++ Callback API(Reactor 模式)介绍
开发语言·c++
菜_小_白1 小时前
高性能线程池
linux·c++·设计模式
minji...2 小时前
Linux 网络基础(三)HTTP的请求方法(GET/POST),HTTP表单、临时和永久重定向状态码、Cookie、查询参数、Web根目录
linux·运维·服务器·网络·c++·http
鱼很腾apoc2 小时前
【学习篇】第18期 C++模板
c语言·c++
郝学胜-神的一滴2 小时前
跨平台 C++ 静态库编译实战:Linux/Windows/macOS 三端统一实现
linux·开发语言·c++·windows·软件构建