PaddlePaddle算子注册原理阅读记录

PaddlePaddle通过REGISTER_OPERATOR宏来进行算子注册,以paddle/fluid/operators/gru_unit_op.cc为例,其注册代码如下:

cpp 复制代码
REGISTER_OPERATOR(gru_unit,
                  ops::GRUUnitOp,           //在该文件中定义的class
                  ops::GRUUnitOpMaker,      //在该文件中定义的class
                  ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>,   //在该文件中定义的class
                  ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>);

其中ops::GRUUnitOpops::GRUUnitOpMakerops::GRUUnitGradOpMaker都是在该文件中定义的class。首先看一下REGISTER_OPERATOR宏,其定义如下:

cpp 复制代码
#define REGISTER_OPERATOR(op_type, op_class, ...)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
      __reg_op__##op_type,                                               \
      "REGISTER_OPERATOR must be called in global namespace");           \
  static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
      __op_registrar_##op_type##__(#op_type);                            \
  int TouchOpRegistrar_##op_type() {                                     \
    __op_registrar_##op_type##__.Touch();                                \
    return 0;                                                            \
  }

注册gru_unit的宏展开后代码如下:

cpp 复制代码
//STATIC_ASSERT_GLOBAL_NAMESPACE检查REGISTER_OPERATOR宏是否在全局空间中被
//使用,如果不是在全局空间中使用则在编译期报错                              
struct __test_global_namespace___reg_op__gru_unit__ {};
static_assert(
    std ::is_same<::__test_global_namespace___reg_op__gru_unit__,
                  __test_global_namespace___reg_op__gru_unit__>::value,
    "REGISTER_OPERATOR must be called in global namespace");
static ::paddle ::framework ::OperatorRegistrar<
    ops ::GRUUnitOp,
    ops ::GRUUnitOpMaker,
    ops ::GRUUnitGradOpMaker<paddle ::framework ::OpDesc>,
    ops ::GRUUnitGradOpMaker<paddle ::imperative ::OpBase>>
    __op_registrar_gru_unit__("gru_unit");
int TouchOpRegistrar_gru_unit() {
  __op_registrar_gru_unit__.Touch();
  return 0;
};

::paddle ::framework ::OperatorRegistrar的定义如下:

cpp 复制代码
class Registrar {
 public:
  // In our design, various kinds of classes, e.g., operators and kernels,
  // have their corresponding registry and registrar. The action of
  // registration is in the constructor of a global registrar variable, which
  // are not used in the code that calls package framework, and would
  // be removed from the generated binary file by the linker. To avoid such
  // removal, we add Touch to all registrar classes and make USE_OP macros to
  // call this method. So, as long as the callee code calls USE_OP, the global
  // registrar variable won't be removed by the linker.
  void Touch() {}
};

template <typename... ARGS>
struct OperatorRegistrar : public Registrar {
  //在构造函数中完成注册
  explicit OperatorRegistrar(const char* op_type) {
    //OpInfoMap主要是一个key为std::string,value为OpInfo的字典,这里要求这个字典之前没有注册该算子
    PADDLE_ENFORCE_EQ(
        OpInfoMap::Instance().Has(op_type),
        false,
        platform::errors::AlreadyExists(
            "Operator '%s' is registered more than once.", op_type));
    static_assert(sizeof...(ARGS) != 0,
                  "OperatorRegistrar should be invoked at least by OpClass");
    OpInfo info;
    details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info); //算子信息写入OpInfo
    OpInfoMap::Instance().Insert(op_type, info); //把OpInfo注册到OpInfoMap
  }
};
相关推荐
上官胡闹2 个月前
基于vLLM的PaddleOCR-VL部署指南
人工智能·百度飞桨
战场小包2 个月前
PaddleOCR-VL,超强文字识别能力,PDF的拯救者
人工智能·百度飞桨
百度Geek说8 个月前
飞桨新一代框架3.0正式发布:加速大模型时代的技术创新与产业应用
百度·百度飞桨
Harry技术9 个月前
基于PaddleNLP使用DeepSeek-R1搭建智能体
百度飞桨·deepseek
冒泡的肥皂1 年前
java表格识别PaddleOcr总结
java·后端·百度飞桨
飞桨PaddlePaddle2 年前
一站式解读多模态——Transformer、Embedding、主流模型与通用任务实战(下)
百度·百度飞桨
飞桨PaddlePaddle2 年前
一站式解读多模态——Transformer、Embedding、主流模型与通用任务实战(上)
百度·百度飞桨
飞桨PaddlePaddle2 年前
文心一言赋能问卷生成,打造高效问卷调研工具
百度·百度飞桨
铁皮鸭子2 年前
PaddleOCR 服务化部署(基于PaddleHub Serving)
机器学习·ocr·paddlepaddle·百度飞桨·表格识别·paddleocr