[PyTorch] How cumsum is implemented

This post covers only the CUDA side of the cumsum implementation, that is, which CUDA operators it calls under the hood.

void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      self.scalar_type(), "cumsum_cuda",
      [&]() {
        scalar_t init = 0;
        scan_dim<scalar_t>(
            self,
            result,
            dim,
            init,
            std::plus<scalar_t>());
      });
}

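Tracing through the source leads to the code above, which is the key piece that launches the kernel. It goes through a dispatch macro defined by PyTorch (AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2), and the core call is a PyTorch template function named scan_dim, invoked with an initial value of 0 and std::plus<scalar_t> as the binary operator.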
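To make the semantics concrete, here is a minimal CPU sketch of what an inclusive scan with an addition operator computes. It is not PyTorch code: `cumsum_reference` is a hypothetical name, and `std::partial_sum` stands in for the GPU scan purely for illustration.

```cpp
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical CPU reference, for illustration only.
// Computes output[i] = input[0] + input[1] + ... + input[i].
std::vector<int> cumsum_reference(const std::vector<int>& input) {
  std::vector<int> output(input.size());
  // std::partial_sum with std::plus is a sequential inclusive scan.
  std::partial_sum(input.begin(), input.end(), output.begin(), std::plus<int>());
  return output;
}
```

Calling `cumsum_reference({1, 2, 3, 4})` yields `{1, 3, 6, 10}`; each output element includes the input element at the same position, which is what "inclusive" means here.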

scan_dim is defined in aten/src/ATen/native/cuda/ScanUtils.cuh.

The code is as follows:

template<typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,
     int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(result.is_contiguous());

  if (self.numel() == self.size(dim)) {
    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
  }
}

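The most important part of this function is the conditional at the end. If the total number of elements equals the size of the scanned dimension, every other dimension has size 1, so the tensor is effectively one-dimensional and cub's inclusive scan is applied directly. Otherwise there are two cases: the scanned dimension is the innermost (last) one, which goes to scan_innermost_dim, or it is any other dimension, which goes to scan_outer_dim.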
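The three-way branch can be restated as a small host-side function over the tensor shape. This is only a sketch: `ScanPath` and `choose_scan_path` are hypothetical names introduced here, and the function mirrors the conditions of scan_dim without touching any tensor data.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical labels for the three branches of scan_dim.
enum class ScanPath { CubFlat, InnermostDim, OuterDim };

// Mirrors the branch conditions of scan_dim for a given shape and dim.
ScanPath choose_scan_path(const std::vector<int64_t>& sizes, int64_t dim) {
  int64_t numel = 1;
  for (int64_t s : sizes) numel *= s;
  const int64_t ndim = static_cast<int64_t>(sizes.size());

  // All other dimensions have size 1: the data is one flat sequence.
  if (numel == sizes[dim]) return ScanPath::CubFlat;
  // Scanning along the last dimension.
  if (dim == ndim - 1) return ScanPath::InnermostDim;
  // Scanning along any other dimension.
  return ScanPath::OuterDim;
}
```

For example, a shape of `{1, 8, 1}` with `dim = 1` takes the flat cub path even though the tensor is nominally three-dimensional, while `{4, 8}` takes the innermost path for `dim = 1` and the outer path for `dim = 0`.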

template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
                             int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dim_ < dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dim > dimension) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
                        scalar_t init, BinaryFunction binary_op) {
  int64_t ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int64_t row_size = self.size(ndim - 1);
  int64_t num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));

  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),
                                               at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_rows, row_size, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

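As shown, PyTorch provides its own implementations for these two cases, because cub's inclusive_scan operates on a single flat one-dimensional sequence rather than along one dimension of a multi-dimensional tensor.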
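The body of tensor_kernel_scan_outer_dim is not shown in this post, so the following is only a sequential CPU sketch of the indexing that the three sizes num_orows, row_size and num_irows imply for a contiguous tensor. The name `scan_outer_dim_reference` is hypothetical, and addition with an initial value of 0 is assumed.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference: cumulative sum along the middle axis of a
// contiguous buffer viewed as [num_orows, row_size, num_irows].
std::vector<int> scan_outer_dim_reference(const std::vector<int>& src,
                                          int64_t num_orows,
                                          int64_t row_size,
                                          int64_t num_irows) {
  std::vector<int> dst(src.size());
  for (int64_t orow = 0; orow < num_orows; ++orow) {
    for (int64_t irow = 0; irow < num_irows; ++irow) {
      int acc = 0;  // plays the role of init
      for (int64_t col = 0; col < row_size; ++col) {
        // Consecutive elements along the scanned dim are num_irows apart.
        const int64_t idx = (orow * row_size + col) * num_irows + irow;
        acc += src[idx];
        dst[idx] = acc;
      }
    }
  }
  return dst;
}
```

For a 2x3 tensor `[[1, 2, 3], [4, 5, 6]]` scanned along dim 0, the view is `num_orows = 1`, `row_size = 2`, `num_irows = 3`, and the result is `[[1, 2, 3], [5, 7, 9]]`. The stride of `num_irows` between consecutive scanned elements is why a flat one-dimensional scan cannot be applied directly here.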

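Before launching a kernel, the grid and thread-block dimensions have to be defined. scan_innermost_dim uses a fixed block size of 512 threads, so the question is how to split those 512 threads into a two-dimensional block with a suitable ratio between x and y. PyTorch does it as follows: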

template <typename integer>
constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {
  integer log_num_threads_x = 0;
  integer log_num_threads_y = 0;
  while (((integer)1 << log_num_threads_x) < row_size) {
    ++log_num_threads_x;
  }
  while (((integer)1 << log_num_threads_y) < num_rows) {
    ++log_num_threads_y;
  }
  // we want to keep the ratio between the x-threads and y-threads about the same as
  // the ratio between the row_size and num_rows, but the total number of threads in
  // a block should be about 512
  integer diff = log_num_threads_x - log_num_threads_y;
  // 9 is from log2(512)
  log_num_threads_x = ((integer)9 + diff) / (integer)2;
  // I found that in having larger log_num_threads_x can give significant speed up in some cases,
  // but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it
  // similar to the previous implementation
  // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block.
  log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);
  return log_num_threads_x;
}

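Working with base-2 logarithms keeps the number of threads along x a power of two, so it always divides 512 evenly. The key line is the one that balances the two dimensions. When the data is larger than 512 elements, a single block cannot cover all of it, so the shape of the block should follow the shape of the data. Let $x$ and $y$ be the logarithms for the two dimensions, with corresponding counts $X = 2^x$ and $Y = 2^y$. Their ratio is $X / Y = 2^{x - y}$, which is simply $2^{diff}$. If $x$ is then set to $(9 + diff) / 2$, $y$ becomes $9 - x = (9 - diff) / 2$. The two sum to 9, so the block holds $2^9 = 512$ threads, and their difference is still $diff$, so the ratio is preserved. Because the division is an integer division, the ratio is only approximate when $9 + diff$ is odd, and the final clamp of $x$ to the range $[4, 9]$ can change it further.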
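The effect of the balancing step is easier to see with concrete numbers. The sketch below re-derives the block shape on the host. `BlockShape`, `ceil_log2` and `inner_scan_block_shape` are hypothetical names, and the sketch uses signed 64-bit integers, whereas the source instantiates the helper with uint32_t, so the two may disagree when num_rows is far larger than row_size.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper type: threads along x and y in one block.
struct BlockShape {
  int64_t x;
  int64_t y;
};

// Smallest k such that 2^k >= n, for n >= 1.
int64_t ceil_log2(int64_t n) {
  int64_t k = 0;
  while ((int64_t{1} << k) < n) ++k;
  return k;
}

// Signed-integer sketch of the balancing logic, assuming 512 threads per block.
BlockShape inner_scan_block_shape(int64_t num_rows, int64_t row_size) {
  const int64_t diff = ceil_log2(row_size) - ceil_log2(num_rows);
  int64_t log_x = (9 + diff) / 2;                             // 9 == log2(512)
  log_x = std::min<int64_t>(std::max<int64_t>(4, log_x), 9);  // clamp to [4, 9]
  const int64_t x = int64_t{1} << log_x;
  return {x, 512 / x};
}
```

With `num_rows = 64` and `row_size = 512`, diff is 3 and the block is 64 x 8, a ratio of exactly $2^3$. With `num_rows = 1024` and `row_size = 1024`, diff is 0, the integer division gives a log of 4, and the block is 16 x 32 rather than a square. With `num_rows = 4` and `row_size = 4096`, the upper clamp applies and the block is 512 x 1.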
