PaddlePaddle 通过 `REGISTER_OPERATOR` 宏来进行算子注册。以 `paddle/fluid/operators/gru_unit_op.cc` 为例，其注册代码如下：
cpp
// Register the gru_unit operator. The first class argument is the operator
// class; the remaining classes are forwarded, in order, as extra template
// arguments of OperatorRegistrar (see the macro definition that follows).
REGISTER_OPERATOR(gru_unit,
ops::GRUUnitOp, // class defined in this file
ops::GRUUnitOpMaker, // class defined in this file
ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>, // class template defined in this file (OpDesc instantiation)
ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>); // same class template (OpBase instantiation)
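其中 `ops::GRUUnitOp`、`ops::GRUUnitOpMaker` 和 `ops::GRUUnitGradOpMaker` 都是在该文件中定义的类（`GRUUnitGradOpMaker` 是类模板，这里分别以 `paddle::framework::OpDesc` 和 `paddle::imperative::OpBase` 实例化）。首先看一下 `REGISTER_OPERATOR` 宏，其定义如下：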
cpp
// Registers the operator named `op_type`. The macro expands to three things:
//  1. STATIC_ASSERT_GLOBAL_NAMESPACE: a compile-time check that the macro is
//     used at global namespace scope.
//  2. A `static` OperatorRegistrar object, templated on `op_class` followed by
//     any extra classes passed in `...`. Its constructor performs the actual
//     registration. `##__VA_ARGS__` is the GNU extension that drops the
//     preceding comma when no extra classes are given.
//  3. A non-static function TouchOpRegistrar_<op_type>() that calls Touch() on
//     the registrar object and returns 0. Because it has external linkage,
//     code elsewhere can reference it; presumably USE_OP does so, so that the
//     linker keeps the registrar object (see the comment on Registrar::Touch).
// Comments are kept outside the macro body: a `//` comment on a continued line
// would swallow the trailing backslash.
#define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \
static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
注册 `gru_unit` 时，该宏展开后的代码如下（`std ::is_same` 这类记号之间的多余空格不影响语义；函数末尾的 `;` 来自宏调用处的分号）：
cpp
// Expansion of STATIC_ASSERT_GLOBAL_NAMESPACE: checks at compile time that
// REGISTER_OPERATOR is used at global namespace scope. If the macro were used
// inside a namespace, the struct below would be declared in that namespace.
// The fully qualified ::__test_global_namespace___reg_op__gru_unit__ would then
// not name it, and compilation would fail.
struct __test_global_namespace___reg_op__gru_unit__ {};
static_assert(
std ::is_same<::__test_global_namespace___reg_op__gru_unit__,
__test_global_namespace___reg_op__gru_unit__>::value,
"REGISTER_OPERATOR must be called in global namespace");
// File-local (static) registrar object for "gru_unit". Its constructor
// registers the operator when the object is initialized. The template
// arguments are op_class followed by the extra classes from the macro call.
static ::paddle ::framework ::OperatorRegistrar<
ops ::GRUUnitOp,
ops ::GRUUnitOpMaker,
ops ::GRUUnitGradOpMaker<paddle ::framework ::OpDesc>,
ops ::GRUUnitGradOpMaker<paddle ::imperative ::OpBase>>
__op_registrar_gru_unit__("gru_unit");
// Externally visible hook that references the registrar object; see the
// comment on Registrar::Touch for why it exists.
int TouchOpRegistrar_gru_unit() {
__op_registrar_gru_unit__.Touch();
return 0;
}; // The trailing ';' comes from the semicolon after the macro invocation (an empty declaration).
`::paddle::framework::OperatorRegistrar` 及其基类 `Registrar` 的定义如下：
cpp
// Common base class of the registrar classes (OperatorRegistrar derives from it).
class Registrar {
public:
// In this design, each kind of registered entity (e.g. operators and
// kernels) has its own registry and registrar class. Registration happens in
// the constructor of a global registrar variable. Code that uses the
// framework package never references such variables, so the linker may drop
// them from the generated binary. To prevent this, every registrar class
// provides Touch(), and the USE_OP macros call it. As long as the calling
// code uses USE_OP, the global registrar variable is not removed by the
// linker. Touch() itself intentionally does nothing.
void Touch() {}
};
// Registrar for operators.
// ARGS: the operator class followed by any extra classes passed to
// REGISTER_OPERATOR (e.g. the OpMaker and GradOpMaker instantiations).
template <typename... ARGS>
struct OperatorRegistrar : public Registrar {
// Registration is done in the constructor.
explicit OperatorRegistrar(const char* op_type) {
// OpInfoMap is essentially a map from std::string to OpInfo. Require that
// op_type has not been registered yet; otherwise fail with AlreadyExists.
PADDLE_ENFORCE_EQ(
OpInfoMap::Instance().Has(op_type),
false,
platform::errors::AlreadyExists(
"Operator '%s' is registered more than once.", op_type));
// Compile-time guard: at least the operator class must be supplied.
static_assert(sizeof...(ARGS) != 0,
"OperatorRegistrar should be invoked at least by OpClass");
OpInfo info;
// Walk the types in ARGS and write the operator information into `info`.
// NOTE(review): <0, false, ...> presumably means "start at index 0, not at
// the end yet"; confirm against details::OperatorRegistrarRecursive.
details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info); // fill OpInfo
OpInfoMap::Instance().Insert(op_type, info); // register the OpInfo in OpInfoMap
}
};