Paadle Custom device 代码阅读记录

Paddle Custom device的仓库位于:github.com/PaddlePaddl... 其示例代码为在backends/custom_cpu,本文主要讲解一下custom_cpu使用的paddle相关的代码,介绍paddle是如何实现custom device的。

1 Custom cpu算子注册代码

以add算子为例,其注册代码为于backends/custom_cpu/kernels/elementwise_kernel.cc,代码如下:

cpp 复制代码
template <typename T>
void AddKernel(const phi::Context& dev_ctx,
               const phi::DenseTensor& x,
               const phi::DenseTensor& y,
               phi::DenseTensor* out) {
  int axis = -1;
  custom_kernel::AddRawKernel<T>(dev_ctx, x, y, axis, out);
}

PD_BUILD_PHI_KERNEL(add,
                    custom_cpu,
                    ALL_LAYOUT,
                    custom_kernel::AddKernel,
                    int32_t,
                    int64_t,
                    float,
                    double) {}

注意,该文件include的头文件中,和paddle相关的只有paddle/phi/capi/all.h,该文件内容如下:

h 复制代码
#pragma once

#if !defined(_WIN32)

#include "paddle/phi/capi/include/c_data_type.h"
#include "paddle/phi/capi/include/c_device_context.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_int_array.h"
#include "paddle/phi/capi/include/c_kernel_context.h"
#include "paddle/phi/capi/include/c_kernel_factory.h"
#include "paddle/phi/capi/include/c_kernel_registry.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"
#include "paddle/phi/capi/include/c_place.h"
#include "paddle/phi/capi/include/c_scalar.h"
#include "paddle/phi/capi/include/c_tensor.h"
#include "paddle/phi/capi/include/data_type.h"
#include "paddle/phi/capi/include/kernel_registry.h"

#endif

因此paddle为custom device提供的头文件全部都在paddle/phi/capi文件夹中,由于只include了capi文件夹中的头文件,因此AddKernel函数参数中的phi::DenseTensor数据类型实际上并不是paddle/phi/core/dense_tensor.h中定义的phi::DenseTensor,而是paddle/phi/capi/include/kernel_registry.h中如下定义的数据类型:

cpp 复制代码
namespace phi {
.....
using DenseTensor = capi::DenseTensor;
.....
}

因此,custom_cpu中算子使用的phi::DenseTensor实际上是paddle/phi/capi/include/wrapper_base.h中定义的phi::capi::DenseTensor,同理,custom_cpu中算子使用的phi::Context实际上是paddle/phi/capi/include/wrapper_base.h中定义的phi::capi::DeviceContext

2 phi::capi::DenseTensor讲解

该class的定义位于paddle/phi/capi/include/wrapper_base.h,定义如下:

cpp 复制代码
template <typename T>
class WrapperBase {
 public:
  explicit WrapperBase(T* ptr, bool own = false) : data_(ptr), own_(own) {}

  inline T* raw_data() const { return data_; }

  inline bool own_data() const { return own_; }

  inline void reset(const T* ptr) { data_ = ptr; }

 private:
  T* data_;
  bool own_;
};


// PD_Tensor定义于paddle/phi/capi/include/c_tensor.h
// 定义内容为:typedef struct PD_Tensor PD_Tensor;
// 即PD_Tensor为一个空结构体
class DenseTensor : public WrapperBase<PD_Tensor> { 
//继承WrapperBase<PD_Tensor>,因此DenseTensor有一个PD_Tensor类型的指针作为私有成员变量
 public:
  DenseTensor() : WrapperBase(PD_NewTensor(), true) {}

  explicit DenseTensor(PD_Tensor* tensor) : WrapperBase(tensor) {}

  ~DenseTensor() {
    if (own_data()) {
      PD_DeleteTensor(raw_data());
    }
  }

  size_t offset() const {
    C_Status status;
    //PD_TensorGetOffset函数传入的第一个参数是PD_Tensor类型的指针
    auto offset = PD_TensorGetOffset(raw_data(), &status);
    PD_CHECK_STATUS(status);
    return offset;
  }
  
............................................
}


//以下是定义于paddle/phi/capi/lib/c_tensor.cc的PD_TensorGetOffset函数
size_t PD_TensorGetOffset(const PD_Tensor* tensor, PD_Status* status) {
  if (status) {
    if (!tensor) {
      *status = C_FAILED;
      return 0;
    }
    *status = C_SUCCESS;
  }

  auto cc_tensor = reinterpret_cast<const phi::DenseTensor*>(tensor); 将指针转为phi::DenseTensor类型并调用其方法
  return cc_tensor->offset();
}
相关推荐
羑悻的小杀马特8 分钟前
从混沌到秩序:数据科学的热力学第二定律破局——线性回归的熵减模型 × 最小二乘的能量最小化 × 梯度下降的负反馈控制系统,用物理定律重构智能算法的统一场论
人工智能·算法·机器学习
仟濹24 分钟前
【数据结构】「栈」(顺序栈、共享栈、链栈)
c语言·数据结构·算法
青梅主码27 分钟前
2025年7月全球大模型最新排名发布!企业与个人该如何选择最适合你的 AI 得力助手?
人工智能·算法
许愿与你永世安宁37 分钟前
强化学习 (11)随机近似
人工智能·算法·强化学习·梯度下降·随机近似
凤年徐1 小时前
【数据结构】栈和队列-----数据结构中的双生花
c语言·开发语言·数据结构·c++·笔记·算法·链表
KoiHeng4 小时前
部分排序算法的Java模拟实现(复习向,非0基础)
java·算法·排序算法
艾莉丝努力练剑10 小时前
【数据结构与算法】数据结构初阶:详解顺序表和链表(四)——单链表(下)
c语言·开发语言·数据结构·学习·算法·链表
yngsqq12 小时前
移动碰撞法 ——套料排版算法——CAD c#
算法
秋说13 小时前
【PTA数据结构 | C语言版】根据层序序列重构二叉树
c语言·数据结构·算法
秋说14 小时前
【PTA数据结构 | C语言版】前序遍历二叉树
c语言·数据结构·算法