PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例

今天在看pytorch源码,遇到的问题,记录一下 。

source:/lib/python3.10/site-packages/torch/amp/grad_scaler.py

go 复制代码
 torch._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

这段python代码是如何调用底层的cpp代码的

go 复制代码
Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  
  // 核函数调用
  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}

source: https://github.com/pytorch/pytorch/blob/0431d47eaaaceadab0c0f2400e186c02ef66d96c/aten/src/ATen/native/cuda/AmpKernels.cu#L222

在 PyTorch 框架中,Python 和底层 C++/CUDA 实现之间通过 PyTorch ATen 底层 API 和 PyBind11 来完成。torch._amp_update_scale_ 是一个 Python 函数,它实际上通过 PyTorch 的 C++扩展机制 调用底层 C++ 实现(_amp_update_scale_cuda_)。以下是调用流程的完整解读:


1. Python 调用 C++ 的机制

torch._amp_update_scale_ 是一个绑定在 PyTorch ATen C++核心函数上的 Python 接口。其调用流程如下:

  1. Python 调用 torch._amp_update_scale_ 函数:

    • Python 参数 _scale, _growth_tracker, 和 found_inf_combinedtorch.Tensor 对象。
    • 这些对象在 Python 中使用,但底层通过 PyObject 引用 C++ 的 at::Tensor 对象(ATen 核心张量)。
  2. 通过绑定跳转到 C++ 函数 _amp_update_scale_cuda_

    • torch._amp_update_scale_ 被注册到一个 dispatcher(分发器),它会根据运行设备(CPU 或 GPU)选择合适的后端实现。
    • 在 GPU 上运行时,分发器会调用底层 CUDA 实现 _amp_update_scale_cuda_

2. Python 到 C++ 的具体流程

  1. ATen 和 PyTorch 的 Operator 注册系统

    PyTorch 使用 torch::RegisterOperators 注册 C++ 函数 _amp_update_scale_cuda_,并将其绑定到 Python 的 torch._amp_update_scale_

    注册流程示例

    cpp 复制代码
    TORCH_LIBRARY_IMPL(aten, CUDA, m) {
        m.impl("_amp_update_scale_", &_amp_update_scale_cuda_);
    }
    • TORCH_LIBRARY_IMPL 用于将 CUDA 实现 _amp_update_scale_cuda_ 注册到 ATen。
    • Python 代码调用 torch._amp_update_scale_ 时,会被自动映射到 C++ 实现 _amp_update_scale_cuda_
  2. Python 的 Tensor 转换为 C++ 的 at::Tensor

    torch._amp_update_scale_ 被调用时,Python 中的 Tensor 对象通过 PyBind11 自动转换为对应的 at::Tensor 对象。例如:

    python 复制代码
    torch._amp_update_scale_(
        _scale,             # Python Tensor -> at::Tensor
        _growth_tracker,    # Python Tensor -> at::Tensor
        found_inf_combined, # Python Tensor -> at::Tensor
        self._growth_factor, # Python float -> C++ double
        self._backoff_factor, # Python float -> C++ double
        self._growth_interval # Python int -> C++ int64_t
    )
  3. 调用 C++ 函数 _amp_update_scale_cuda_

    • 参数从 Python 传递到 _amp_update_scale_cuda_,对应 current_scale, growth_tracker, found_inf 等。
    • 在 C++ 中,_amp_update_scale_cuda_ 函数会调用底层 CUDA 核心函数 amp_update_scale_cuda_kernel,执行缩放更新逻辑。

3. C++ 到 CUDA 核心函数的调用流程

_amp_update_scale_cuda_ 中,C++ 调用 CUDA 核心代码的主要流程是:

  1. 参数检查

    使用 TORCH_CHECK 确保 current_scale, growth_tracker, 和 found_inf 都是 CUDA 张量:

    cpp 复制代码
    TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
    TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
    TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  2. 启动 CUDA 核函数

    使用 CUDA 的核函数调用机制 <<<...>>> 启动 CUDA 内核函数:

    cpp 复制代码
    amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        current_scale.mutable_data_ptr<float>(), // 当前缩放因子的指针
        growth_tracker.mutable_data_ptr<int>(), // 成功步数计数器的指针
        found_inf.const_data_ptr<float>(),      // 梯度溢出的标志
        growth_factor,                          // 增长因子
        backoff_factor,                         // 回退因子
        growth_interval                         // 增长间隔
    );

    调用过程

    • CUDA 核函数 amp_update_scale_cuda_kernel 被调度到当前 GPU 的流(at::cuda::getCurrentCUDAStream())。
    • 各个张量(如 current_scale, growth_tracker)通过 .data_ptr<T>() 方法获取指针,传递给 CUDA 核函数。
  3. CUDA 核函数执行

    核函数 amp_update_scale_cuda_kernel 在 GPU 上执行,完成缩放因子的动态调整。逻辑详见问题中的 CUDA 实现。

  4. 内核启动检查

    启动内核后,通过 C10_CUDA_KERNEL_LAUNCH_CHECK() 检查 CUDA 内核是否成功运行。

    cpp 复制代码
    C10_CUDA_KERNEL_LAUNCH_CHECK();

4. 总结调用链路

完整调用链如下:

  1. Python 层

    python 复制代码
    torch._amp_update_scale_(
        _scale, _growth_tracker, found_inf_combined,
        self._growth_factor, self._backoff_factor, self._growth_interval
    )
    • Python 张量(torch.Tensor)通过 PyBind11 转换为 C++ 张量(at::Tensor)。
  2. C++ 层

    cpp 复制代码
    Tensor& _amp_update_scale_cuda_(
        Tensor& current_scale, Tensor& growth_tracker, const Tensor& found_inf,
        double growth_factor, double backoff_factor, int64_t growth_interval
    ) {
        // 调用 CUDA 核函数
        amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
            current_scale.mutable_data_ptr<float>(),
            growth_tracker.mutable_data_ptr<int>(),
            found_inf.const_data_ptr<float>(),
            growth_factor, backoff_factor, growth_interval
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        return current_scale;
    }
  3. CUDA 层

go 复制代码
// amp_update_scale_cuda_kernel is launched with a single thread to compute the new scale.
// The scale factor is maintained and updated on the GPU to avoid synchronization.
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval)
{
 // 核函数逻辑:根据是否溢出动态调整 current_scale 和 growth_tracker
  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite_ensure_cuda_math(new_scale)) {
          *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}

5. 补充说明

这种从 Python 到 C++ 再到 CUDA 的调用链是 PyTorch 的通用设计模式:

  • Python API 层:提供高层易用接口。
  • C++ ATen 层:实现设备无关的核心逻辑。
  • CUDA 内核层:实现高性能的设备特定操作。

后记

2025年1月2日15点22分于上海, 在GPT4o大模型辅助下完成。

相关推荐
power-辰南3 分钟前
Pytorch 三小时极限入门教程
人工智能·pytorch·深度学习
JINGWHALE118 分钟前
设计模式 结构型 代理模式(Proxy Pattern)与 常见技术框架应用 解析
前端·人工智能·后端·设计模式·性能优化·系统架构·代理模式
youcans_25 分钟前
【YOLO 项目实战】(12)红外/可见光多模态目标检测
人工智能·yolo·目标检测·计算机视觉·多模态
深蓝学院26 分钟前
Visual CoT:解锁视觉链式思维推理的潜能
人工智能·计算机视觉·目标跟踪
AI追随者27 分钟前
超越YOLO11!DEIM:先进的实时DETR目标检测
人工智能·深度学习·算法·目标检测·计算机视觉
卧式纯绿30 分钟前
自动驾驶3D目标检测综述(六)
人工智能·算法·目标检测·计算机视觉·3d·目标跟踪·自动驾驶
KeyPan1 小时前
【机器学习:一、机器学习简介】
人工智能·数码相机·算法·机器学习·计算机视觉
deardao1 小时前
【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 1:背景动机
人工智能·深度学习·神经网络·数据挖掘·极限标签分类
沐欣工作室_lvyiyi1 小时前
基于单片机的家庭智能垃圾桶(论文+源码)
人工智能·stm32·单片机·嵌入式硬件·单片机毕业设计·垃圾桶
湫ccc1 小时前
《Opencv》基础操作详解(4)
人工智能·opencv·计算机视觉