PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例

今天在看pytorch源码,遇到的问题,记录一下 。

source:/lib/python3.10/site-packages/torch/amp/grad_scaler.py

go 复制代码
 torch._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

这段python代码是如何调用底层的cpp代码的

go 复制代码
Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  
  // 核函数调用
  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}

source: https://github.com/pytorch/pytorch/blob/0431d47eaaaceadab0c0f2400e186c02ef66d96c/aten/src/ATen/native/cuda/AmpKernels.cu#L222

在 PyTorch 框架中,Python 和底层 C++/CUDA 实现之间通过 PyTorch ATen 底层 API 和 PyBind11 来完成。torch._amp_update_scale_ 是一个 Python 函数,它实际上通过 PyTorch 的 C++扩展机制 调用底层 C++ 实现(_amp_update_scale_cuda_)。以下是调用流程的完整解读:


1. Python 调用 C++ 的机制

torch._amp_update_scale_ 是一个绑定在 PyTorch ATen C++核心函数上的 Python 接口。其调用流程如下:

  1. Python 调用 torch._amp_update_scale_ 函数:

    • Python 参数 _scale, _growth_tracker, 和 found_inf_combinedtorch.Tensor 对象。
    • 这些对象在 Python 中使用,但底层通过 PyObject 引用 C++ 的 at::Tensor 对象(ATen 核心张量)。
  2. 通过绑定跳转到 C++ 函数 _amp_update_scale_cuda_

    • torch._amp_update_scale_ 被注册到一个 dispatcher(分发器),它会根据运行设备(CPU 或 GPU)选择合适的后端实现。
    • 在 GPU 上运行时,分发器会调用底层 CUDA 实现 _amp_update_scale_cuda_

2. Python 到 C++ 的具体流程

  1. ATen 和 PyTorch 的 Operator 注册系统

    PyTorch 使用 torch::RegisterOperators 注册 C++ 函数 _amp_update_scale_cuda_,并将其绑定到 Python 的 torch._amp_update_scale_

    注册流程示例

    cpp 复制代码
    TORCH_LIBRARY_IMPL(aten, CUDA, m) {
        m.impl("_amp_update_scale_", &_amp_update_scale_cuda_);
    }
    • TORCH_LIBRARY_IMPL 用于将 CUDA 实现 _amp_update_scale_cuda_ 注册到 ATen。
    • Python 代码调用 torch._amp_update_scale_ 时,会被自动映射到 C++ 实现 _amp_update_scale_cuda_
  2. Python 的 Tensor 转换为 C++ 的 at::Tensor

    torch._amp_update_scale_ 被调用时,Python 中的 Tensor 对象通过 PyBind11 自动转换为对应的 at::Tensor 对象。例如:

    python 复制代码
    torch._amp_update_scale_(
        _scale,             # Python Tensor -> at::Tensor
        _growth_tracker,    # Python Tensor -> at::Tensor
        found_inf_combined, # Python Tensor -> at::Tensor
        self._growth_factor, # Python float -> C++ double
        self._backoff_factor, # Python float -> C++ double
        self._growth_interval # Python int -> C++ int64_t
    )
  3. 调用 C++ 函数 _amp_update_scale_cuda_

    • 参数从 Python 传递到 _amp_update_scale_cuda_,对应 current_scale, growth_tracker, found_inf 等。
    • 在 C++ 中,_amp_update_scale_cuda_ 函数会调用底层 CUDA 核心函数 amp_update_scale_cuda_kernel,执行缩放更新逻辑。

3. C++ 到 CUDA 核心函数的调用流程

_amp_update_scale_cuda_ 中,C++ 调用 CUDA 核心代码的主要流程是:

  1. 参数检查

    使用 TORCH_CHECK 确保 current_scale, growth_tracker, 和 found_inf 都是 CUDA 张量:

    cpp 复制代码
    TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
    TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
    TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  2. 启动 CUDA 核函数

    使用 CUDA 的核函数调用机制 <<<...>>> 启动 CUDA 内核函数:

    cpp 复制代码
    amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        current_scale.mutable_data_ptr<float>(), // 当前缩放因子的指针
        growth_tracker.mutable_data_ptr<int>(), // 成功步数计数器的指针
        found_inf.const_data_ptr<float>(),      // 梯度溢出的标志
        growth_factor,                          // 增长因子
        backoff_factor,                         // 回退因子
        growth_interval                         // 增长间隔
    );

    调用过程

    • CUDA 核函数 amp_update_scale_cuda_kernel 被调度到当前 GPU 的流(at::cuda::getCurrentCUDAStream())。
    • 各个张量(如 current_scale, growth_tracker)通过 .data_ptr<T>() 方法获取指针,传递给 CUDA 核函数。
  3. CUDA 核函数执行

    核函数 amp_update_scale_cuda_kernel 在 GPU 上执行,完成缩放因子的动态调整。逻辑详见问题中的 CUDA 实现。

  4. 内核启动检查

    启动内核后,通过 C10_CUDA_KERNEL_LAUNCH_CHECK() 检查 CUDA 内核是否成功运行。

    cpp 复制代码
    C10_CUDA_KERNEL_LAUNCH_CHECK();

4. 总结调用链路

完整调用链如下:

  1. Python 层

    python 复制代码
    torch._amp_update_scale_(
        _scale, _growth_tracker, found_inf_combined,
        self._growth_factor, self._backoff_factor, self._growth_interval
    )
    • Python 张量(torch.Tensor)通过 PyBind11 转换为 C++ 张量(at::Tensor)。
  2. C++ 层

    cpp 复制代码
    Tensor& _amp_update_scale_cuda_(
        Tensor& current_scale, Tensor& growth_tracker, const Tensor& found_inf,
        double growth_factor, double backoff_factor, int64_t growth_interval
    ) {
        // 调用 CUDA 核函数
        amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
            current_scale.mutable_data_ptr<float>(),
            growth_tracker.mutable_data_ptr<int>(),
            found_inf.const_data_ptr<float>(),
            growth_factor, backoff_factor, growth_interval
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        return current_scale;
    }
  3. CUDA 层

go 复制代码
// amp_update_scale_cuda_kernel is launched with a single thread to compute the new scale.
// The scale factor is maintained and updated on the GPU to avoid synchronization.
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval)
{
 // 核函数逻辑:根据是否溢出动态调整 current_scale 和 growth_tracker
  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite_ensure_cuda_math(new_scale)) {
          *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}

5. 补充说明

这种从 Python 到 C++ 再到 CUDA 的调用链是 PyTorch 的通用设计模式:

  • Python API 层:提供高层易用接口。
  • C++ ATen 层:实现设备无关的核心逻辑。
  • CUDA 内核层:实现高性能的设备特定操作。

后记

2025年1月2日15点22分于上海, 在GPT4o大模型辅助下完成。

相关推荐
运器12313 分钟前
【一起来学AI大模型】PyTorch DataLoader 实战指南
大数据·人工智能·pytorch·python·深度学习·ai·ai编程
超龄超能程序猿27 分钟前
(5)机器学习小白入门 YOLOv:数据需求与图像不足应对策略
人工智能·python·机器学习·numpy·pandas·scipy
卷福同学28 分钟前
【AI编程】AI+高德MCP不到10分钟搞定上海三日游
人工智能·算法·程序员
帅次36 分钟前
系统分析师-计算机系统-输入输出系统
人工智能·分布式·深度学习·神经网络·架构·系统架构·硬件架构
AndrewHZ1 小时前
【图像处理基石】如何入门大规模三维重建?
人工智能·深度学习·大模型·llm·三维重建·立体视觉·大规模三维重建
5G行业应用1 小时前
【赠书福利,回馈公号读者】《智慧城市与智能网联汽车,融合创新发展之路》
人工智能·汽车·智慧城市
Tanecious.1 小时前
C++--红黑树
开发语言·c++
悟空胆好小1 小时前
分音塔科技(BABEL Technology) 的公司背景、股权构成、产品类型及技术能力的全方位解读
网络·人工智能·科技·嵌入式硬件
探讨探讨AGV1 小时前
以科技赋能未来,科聪持续支持青年创新实践 —— 第七届“科聪杯”浙江省大学生智能机器人创意竞赛圆满落幕
人工智能·科技·机器人
cwn_2 小时前
回归(多项式回归)
人工智能·机器学习·数据挖掘·回归