【Pytorch】cumsum的实现逻辑

本文只记录cumsum的实现逻辑的CUDA部分,也即底层调用了CUDA的什么实现算子。

cpp 复制代码
void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      self.scalar_type(), "cumsum_cuda",
      [&]() {
        scalar_t init = 0;
        scan_dim<scalar_t>(
            self,
            result,
            dim,
            init,
            std::plus<scalar_t>());
      });
}

通过定位源码,找到了执行kernel的关键代码,可以看到,此代码内部调用了Pytorch定义的宏,核心调用是pytorch定义的名为scan_dim的模板函数。

该模板函数的定义位于:aten/src/ATen/native/cuda/ScanUtils.cuh

代码如下:

cpp 复制代码
template<typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,
     int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(result.is_contiguous());

  if (self.numel() == self.size(dim)) {
    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
  }
}

该函数内部最重要的是后面的条件结构,首先如果元素的总数和当前维度的元素个数相同,也即tensor是一维的,直接利用cub的前缀扫描方法,如果元素的总数和当前维度的元素个数不同,又分为最内层的维度,也即最后一维,以及其他情况。

cpp 复制代码
template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
                             int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dim_ < dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dim > dimension) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
                        scalar_t init, BinaryFunction binary_op) {
  int64_t ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int64_t row_size = self.size(ndim - 1);
  int64_t num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));

  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),
                                               at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_rows, row_size, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

可以看到Pytorch针对上述两种情况进行了自定义,因为cub的inclusive_scan针对的是一维张量而非多维张量。

在调用核函数前,首先要定义调用核函数的网络结构和线程块结构,pytorch默认的线程块大小是512的,那么如何将512个线程块进行二维切分以满足合适的比例呢,pytorch中的做法是像下面这样:

cpp 复制代码
template <typename integer>
constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {
  integer log_num_threads_x = 0;
  integer log_num_threads_y = 0;
  while (((integer)1 << log_num_threads_x) < row_size) {
    ++log_num_threads_x;
  }
  while (((integer)1 << log_num_threads_y) < num_rows) {
    ++log_num_threads_y;
  }
  // we want to keep the ratio between the x-threads and y-threads about the same as
  // the ratio between the row_size and num_rows, but the total number of threads in
  // a block should be about 512
  integer diff = log_num_threads_x - log_num_threads_y;
  // 9 is from log2(512)
  log_num_threads_x = ((integer)9 + diff) / (integer)2;
  // I found that in having larger log_num_threads_x can give significant speed up in some cases,
  // but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it
  // similar to the previous implementation
  // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block.
  log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);
  return log_num_threads_x;
}

使用对数进行计算是便于计算出的x的结果可以整除,关键点在于最后平衡二者的比例的那行代码。可以预见,在某些情况下由于待处理数据的大小超过512造成线程块不能够完全分配的情况,此时就需要顾及线程块的比例,那么如果两个维度上线程块的对数值分别为x和y,对应的线程数分别为X,Y,也即 X = 2 x X=2^x X=2x。此时X与Y的比例 X / Y X / Y X/Y 的结果也即 2 x − y 2^{x - y} 2x−y ,其实也就是 2 d i f f 2 ^ {diff} 2diff。那么如果将x变为(diff+9) / 2, y也就是 (9 - diff) / 2,二者相减也就是diff,因此保证了变换前后的比例。

相关推荐
泰迪智能科技0134 分钟前
高校深度学习视觉应用平台产品介绍
人工智能·深度学习
盛派网络小助手1 小时前
微信 SDK 更新 Sample,NCF 文档和模板更新,更多更新日志,欢迎解锁
开发语言·人工智能·后端·架构·c#
算法小白(真小白)1 小时前
低代码软件搭建自学第二天——构建拖拽功能
python·低代码·pyqt
唐小旭1 小时前
服务器建立-错误:pyenv环境建立后python版本不对
运维·服务器·python
007php0071 小时前
Go语言zero项目部署后启动失败问题分析与解决
java·服务器·网络·python·golang·php·ai编程
Eric.Lee20211 小时前
Paddle OCR 中英文检测识别 - python 实现
人工智能·opencv·计算机视觉·ocr检测
cd_farsight1 小时前
nlp初学者怎么入门?需要学习哪些?
人工智能·自然语言处理
AI明说1 小时前
评估大语言模型在药物基因组学问答任务中的表现:PGxQA
人工智能·语言模型·自然语言处理·数智药师·数智药学
Chinese Red Guest2 小时前
python
开发语言·python·pygame
Focus_Liu2 小时前
NLP-UIE(Universal Information Extraction)
人工智能·自然语言处理