PaddlePaddle算子注册原理阅读记录

PaddlePaddle通过REGISTER_OPERATOR宏来进行算子注册,以paddle/fluid/operators/gru_unit_op.cc为例,其注册代码如下:

cpp 复制代码
REGISTER_OPERATOR(gru_unit,
                  ops::GRUUnitOp,           //在该文件中定义的class
                  ops::GRUUnitOpMaker,      //在该文件中定义的class
                  ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>,   //在该文件中定义的class
                  ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>);

其中ops::GRUUnitOpops::GRUUnitOpMakerops::GRUUnitGradOpMaker都是在该文件中定义的class。首先看一下REGISTER_OPERATOR宏,其定义如下:

cpp 复制代码
#define REGISTER_OPERATOR(op_type, op_class, ...)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
      __reg_op__##op_type,                                               \
      "REGISTER_OPERATOR must be called in global namespace");           \
  static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
      __op_registrar_##op_type##__(#op_type);                            \
  int TouchOpRegistrar_##op_type() {                                     \
    __op_registrar_##op_type##__.Touch();                                \
    return 0;                                                            \
  }

注册gru_unit的宏展开后代码如下:

cpp 复制代码
//STATIC_ASSERT_GLOBAL_NAMESPACE检查REGISTER_OPERATOR宏是否在全局空间中被
//使用,如果不是在全局空间中使用则在编译期报错                              
struct __test_global_namespace___reg_op__gru_unit__ {};
static_assert(
    std ::is_same<::__test_global_namespace___reg_op__gru_unit__,
                  __test_global_namespace___reg_op__gru_unit__>::value,
    "REGISTER_OPERATOR must be called in global namespace");
static ::paddle ::framework ::OperatorRegistrar<
    ops ::GRUUnitOp,
    ops ::GRUUnitOpMaker,
    ops ::GRUUnitGradOpMaker<paddle ::framework ::OpDesc>,
    ops ::GRUUnitGradOpMaker<paddle ::imperative ::OpBase>>
    __op_registrar_gru_unit__("gru_unit");
int TouchOpRegistrar_gru_unit() {
  __op_registrar_gru_unit__.Touch();
  return 0;
};

::paddle ::framework ::OperatorRegistrar的定义如下:

cpp 复制代码
class Registrar {
 public:
  // In our design, various kinds of classes, e.g., operators and kernels,
  // have their corresponding registry and registrar. The action of
  // registration is in the constructor of a global registrar variable, which
  // are not used in the code that calls package framework, and would
  // be removed from the generated binary file by the linker. To avoid such
  // removal, we add Touch to all registrar classes and make USE_OP macros to
  // call this method. So, as long as the callee code calls USE_OP, the global
  // registrar variable won't be removed by the linker.
  void Touch() {}
};

template <typename... ARGS>
struct OperatorRegistrar : public Registrar {
  //在构造函数中完成注册
  explicit OperatorRegistrar(const char* op_type) {
    //OpInfoMap主要是一个key为std::string,value为OpInfo的字典,这里要求这个字典之前没有注册该算子
    PADDLE_ENFORCE_EQ(
        OpInfoMap::Instance().Has(op_type),
        false,
        platform::errors::AlreadyExists(
            "Operator '%s' is registered more than once.", op_type));
    static_assert(sizeof...(ARGS) != 0,
                  "OperatorRegistrar should be invoked at least by OpClass");
    OpInfo info;
    details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info); //算子信息写入OpInfo
    OpInfoMap::Instance().Insert(op_type, info); //把OpInfo注册到OpInfoMap
  }
};
相关推荐
冒泡的肥皂2 个月前
java表格识别PaddleOcr总结
java·后端·百度飞桨
飞桨PaddlePaddle7 个月前
一站式解读多模态——Transformer、Embedding、主流模型与通用任务实战(下)
百度·百度飞桨
飞桨PaddlePaddle7 个月前
一站式解读多模态——Transformer、Embedding、主流模型与通用任务实战(上)
百度·百度飞桨
飞桨PaddlePaddle8 个月前
文心一言赋能问卷生成,打造高效问卷调研工具
百度·百度飞桨
铁皮鸭子8 个月前
PaddleOCR 服务化部署(基于PaddleHub Serving)
机器学习·ocr·paddlepaddle·百度飞桨·表格识别·paddleocr
飞桨PaddlePaddle8 个月前
RAG一文读懂!概念、场景、优势、对比微调与项目代码示例
百度·百度飞桨
飞桨PaddlePaddle8 个月前
文心一言变身虚拟患者,助力医学生轻松开启「实践模式」
百度·文心一言·百度飞桨
飞桨PaddlePaddle9 个月前
AI Agent深入浅出——以ERNIE SDK和多工具智能编排为例
百度·llm·百度飞桨
软工菜鸡9 个月前
《零基础实践深度学习》1.4.1飞桨产业级深度学习开源开放平台介绍
人工智能·深度学习·机器学习·ai·百度飞桨