发表博客之:gemm/threadblock/threadblock_swizzle.h 文件夹讲解,cutlass深入讲解

文章目录

发表博客之:gemm/threadblock/threadblock_swizzle.h 文件夹讲解,cutlass深入讲解

  • 我们知道,cuda 处理问题都是将一个很大规模的问题分成很多个小问题,每个小问题由一个ThreadBlock来处理,而ThreadblockSwizzle就是负责将逻辑上的小问题映射到cuda上的ThreadBlock上。
  • 或者直接引用这个文件上的注释吧!
  • Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems.

先来看一下最简单的struct GemmIdentityThreadblockSwizzle结构体

  • 这个结构体有一个默认参数是1。
cpp 复制代码
template <int N = 1>
struct GemmIdentityThreadblockSwizzle {

  CUTLASS_HOST_DEVICE
  GemmIdentityThreadblockSwizzle() { }

  /// Returns the shape of the problem in units of logical tiles
  /// *Gemm* problem size: gemm(M, N, K)
  /// 这个函数的作用是简单的。
  /// 就是以tile_size为逻辑单元,整个问题的逻辑shape!
  CUTLASS_HOST_DEVICE
  static GemmCoord get_tiled_shape(
    GemmCoord problem_size,
    GemmCoord tile_size,
    int split_k_slices) {

    return GemmCoord(
      (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
      (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
      split_k_slices);
  }

  /// Returns the shape of the problem in units of logical tiles
  /// *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC)
  CUTLASS_HOST_DEVICE
  static GemmCoord get_tiled_shape(
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv2dProblemSize const &problem_size,
    GemmCoord tile_size,
    int split_k_slices) {

    gemm::GemmCoord implicit_gemm_problem_size = 
    cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size);

    return get_tiled_shape(
      implicit_gemm_problem_size, tile_size, split_k_slices);
  }

  /// Returns the shape of the problem in units of logical tiles
  /// *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC)
  CUTLASS_HOST_DEVICE
  static GemmCoord get_tiled_shape(
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv3dProblemSize const &problem_size,
    GemmCoord tile_size,
    int split_k_slices) {

    gemm::GemmCoord implicit_gemm_problem_size = 
    cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size);

    return get_tiled_shape(
      implicit_gemm_problem_size, tile_size, split_k_slices);
  }

  /// 这个函数是获得物理shape!也就是三对三对<<<>>>下的grid_shape!
  /// Computes CUDA grid dimensions given a size in units of logical tiles
  CUTLASS_HOST_DEVICE
  static dim3 get_grid_shape(GemmCoord tiled_shape) {
    int tile = 1 << get_log_tile(tiled_shape);
    return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k());
  }
  • 下面的这个函数来获得最好的get_log_tile!
cpp 复制代码
  /// 这个是防止函数是防止逻辑shape上的n过大,导致grid的第2维过大!
  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  static int get_log_tile(GemmCoord tiled_shape) {
    auto n = tiled_shape.n();
    // Thresholds picked so that it doesn't cause too many no-op CTAs
    if (N >= 8 && n >= 6)
      return 3;
    else if (N >= 4 && n >= 3)
      return 2;
    else if (N >= 2 && n >= 2)
      return 1;
    else
      return 0;
  }
  • 下面两个函数是同一个名字,get_tile_offset,但是参数不同。
    • 他们的共同作用根据物理id是获取 逻辑上Tile的偏移量!
  • 但是第二个函数好像很少用到的样子!
cpp 复制代码
  /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
  CUTLASS_DEVICE
  static GemmCoord get_tile_offset(int log_tile) {
    int block_idx_x = RematerializeBlockIdxX();
    int block_idx_y = RematerializeBlockIdxY();
    int block_idx_z = RematerializeBlockIdxZ();

    return GemmCoord{(block_idx_x >> log_tile),  //
                     (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)),
                     block_idx_z};
  }
  
  /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
  CUTLASS_DEVICE
  static GemmCoord get_tile_offset(GemmCoord tiled_shape) {

    int const kTile = N;
    int block_idx_x = RematerializeBlockIdxX();
    int block_idx_y = RematerializeBlockIdxY();

    if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
      return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()};

    return GemmCoord{
      (block_idx_x / kTile),
      (block_idx_y * kTile) + (block_idx_x % kTile),
      RematerializeBlockIdxZ()
    };
  }
};
  • 举个例子,假设N=1,并且 C C C输出矩阵被分成下面这样的逻辑shape,
  • 那么三对<<<>>>发射的grid就是(4,4,1)!
  • 那么每个Tile被映射到的ThreadBlock id如下图所示。
  • 如果 N = 2 N=2 N=2,

  • 那么三对<<<>>>发射的grid就是(8,2,1)!

  • 那么每个Tile被映射到的ThreadBlock id如下图所示。

相关推荐
m0_603888715 小时前
FineInstructions Scaling Synthetic Instructions to Pre-Training Scale
人工智能·深度学习·机器学习·ai·论文速览
爬台阶的蚂蚁6 小时前
RAG概念和使用
ai·rag
undsky_6 小时前
【RuoYi-SpringBoot3-Pro】:将 AI 编程融入传统 java 开发
java·人工智能·spring boot·ai·ai编程
AI应用开发实战派6 小时前
AI人工智能中Bard的智能电子商务优化
人工智能·ai·bard
AI原生应用开发6 小时前
AIGC领域Bard在通信领域的内容创作
ai·aigc·bard
唐诺6 小时前
深入了解AI
人工智能·ai
ZEGO即构开发者7 小时前
如何用一句话让AI集成 ZEGO 产品
ai·实时互动·实时音视频·rtc
阿杰学AI7 小时前
AI核心知识76——大语言模型之RAG 2.0(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·rag·检索增强生成·rag2.0
GuoDongOrange7 小时前
智能体来了从 0 到 1:工作流在智能体系统中的真实作用
ai·智能体·从0到1·智能体来了·智能体来了从0到1
爱吃涮肉8 小时前
# 第二章:ClaudeCode核心功能(详细版)
ai