Note: this is a beginner-level article, corrections are welcome! Everything below is based on PyTorch 2.0.0.
In the official PyTorch tutorial (pytorch.org/tutorials/a...), the main way to register an operator is:
```cpp
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
m.impl(<myadd_schema>, &myadd_autograd);
}
```
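Here `<myadd_schema>` is a placeholder for the operator's schema and `myadd_autograd` is the C++ kernel from the tutorial. For orientation, here is a hedged sketch of what a complete registration could look like; the namespace `myops`, the schema string, and the function signature are illustrative, not the tutorial's exact code, and you may need extra includes depending on your setup:

```cpp
#include <torch/library.h>  // TORCH_LIBRARY / TORCH_LIBRARY_IMPL

// Illustrative kernel: an autograd-aware implementation defined elsewhere.
at::Tensor myadd_autograd(const at::Tensor& self, const at::Tensor& other);

// Declare the operator's schema once (no implementation here).
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}

// Register one implementation of that operator for a specific dispatch key.
TORCH_LIBRARY_IMPL(myops, AutogradPrivateUse1, m) {
  m.impl("myadd", &myadd_autograd);
}
```

The rest of this article examines what a block like the second one expands into.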
In the PyTorch source tree, torch/library.h (/home/pytorch/torch/library.h in my checkout) defines the `TORCH_LIBRARY_IMPL` macro:
```cpp
#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)
```
The `_TORCH_LIBRARY_IMPL` macro is defined as follows:
```cpp
#define _TORCH_LIBRARY_IMPL(ns, k, m, uid) \
static void C10_CONCATENATE( \
TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&); \
static const torch::detail::TorchLibraryInit C10_CONCATENATE( \
TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
torch::Library::IMPL, \
c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
c10::DispatchKey::k)>( \
[]() { \
return &C10_CONCATENATE( \
TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid); \
}, \
[]() { return [](torch::Library&) -> void {}; }), \
#ns, \
c10::make_optional(c10::DispatchKey::k), \
__FILE__, \
__LINE__); \
void C10_CONCATENATE( \
TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)
```
Let's first look at `C10_UID`, which is defined as:
```cpp
#define C10_UID __COUNTER__
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__)
```
So `C10_UID` expands to `__COUNTER__`, i.e. an integer that the preprocessor increments on every use, which effectively gives a unique ID within the translation unit.
`C10_CONCATENATE` is defined as follows:
```cpp
#define C10_CONCATENATE_IMPL(s1, s2) s1##s2
#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2)
```
As you can see, it simply pastes two tokens together; if that is unclear, look up the `##` operator of the C/C++ preprocessor.
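As a side note, the extra `C10_CONCATENATE_IMPL` level of indirection is what forces the arguments (in particular the `__COUNTER__` coming from `C10_UID`) to be macro-expanded before the paste happens. A minimal standalone sketch of the same pattern, with made-up macro names (`__COUNTER__` is a compiler extension supported by GCC, Clang and MSVC):

```cpp
#include <cstdio>

// Re-creation of the C10_CONCATENATE / C10_UID pattern with made-up names.
#define MY_CONCAT_IMPL(s1, s2) s1##s2
// The extra level of indirection makes the preprocessor expand s2
// (e.g. __COUNTER__ -> 0, 1, 2, ...) before the token paste happens.
#define MY_CONCAT(s1, s2) MY_CONCAT_IMPL(s1, s2)
#define MY_UID __COUNTER__

// Each use produces a distinct identifier (reg_0, reg_1, ...), so the two
// definitions below do not clash. Without the indirection, the argument
// would not be expanded before the paste and both lines would produce the
// same identifier, causing a redefinition error.
static int MY_CONCAT(reg_, MY_UID) = 1;
static int MY_CONCAT(reg_, MY_UID) = 2;

int main() {
  std::printf("both statics compiled without a name clash\n");
  return 0;
}
```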
The definition of `_TORCH_LIBRARY_IMPL` can be broken down into three parts:
- Declare a static function:
```cpp
static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);
```
The function name is `TORCH_LIBRARY_IMPL_init_` + ns + k + uid. Assuming the UID of `TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m)` is 20, the function name becomes `TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20`.
- Define a constant with internal linkage (visible only inside this .cpp file):
```cpp
static const torch::detail::TorchLibraryInit C10_CONCATENATE( \
TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
torch::Library::IMPL, \
c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
c10::DispatchKey::k)>( \
[]() { \
return &C10_CONCATENATE( \
TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid); \
}, \
[]() { return [](torch::Library&) -> void {}; }), \
#ns, \
c10::make_optional(c10::DispatchKey::k), \
__FILE__, \
__LINE__); \
```
This constant has type `torch::detail::TorchLibraryInit` (declared `static const`). Sticking with the example above, its name is `TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20`, which differs from the name of the static function declared above only by the extra `static_`. After macro expansion, this part becomes:
```cpp
static const torch::detail::TorchLibraryInit // type of the static constant
TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(
    torch::Library::IMPL, // argument 1: Library::Kind
    c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(
        []() { return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20; },
        []() { return [](torch::Library&) -> void {}; }
    ), // argument 2: InitFn*
    "aten", // argument 3: const char*
    c10::make_optional(c10::DispatchKey::AutogradPrivateUse1), // argument 4: c10::optional<c10::DispatchKey>
    __FILE__, // argument 5: const char*
    __LINE__); // argument 6: uint32_t
```
The `TorchLibraryInit` class is defined as follows:
```cpp
class TorchLibraryInit final {
private:
using InitFn = void(Library&);
Library lib_;
public:
TorchLibraryInit(
Library::Kind kind,
InitFn* fn,
const char* ns,
c10::optional<c10::DispatchKey> k,
const char* file,
uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
```
It contains a single private member variable of type `Library`. Note that in its constructor, `lib_` is first initialized from `kind, ns, k, file, line`, and then the `InitFn*` that was passed in (a pointer to a function of type `void(Library&)`) is invoked on `lib_`; that call is what actually fills in the registrations.
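In other words, `TorchLibraryInit` is an instance of the classic static-registration idiom: a file-scope static object whose constructor runs during static initialization (before `main()`) and performs registration as a side effect. A stripped-down sketch of the same idea, with made-up names and a toy registry standing in for the dispatcher:

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// A toy registry standing in for torch::Library / the dispatcher.
using Kernel = std::function<void()>;
std::unordered_map<std::string, Kernel>& registry() {
  static std::unordered_map<std::string, Kernel> r;  // constructed on first use
  return r;
}

// Analogue of TorchLibraryInit: the constructor takes an init function
// and calls it immediately, so registration happens during static init.
struct RegistrarInit {
  using InitFn = void(std::unordered_map<std::string, Kernel>&);
  explicit RegistrarInit(InitFn* fn) { fn(registry()); }
};

// Analogue of the function body that TORCH_LIBRARY_IMPL lets you write.
static void my_init(std::unordered_map<std::string, Kernel>& m) {
  m["myadd"] = [] { std::cout << "myadd kernel called\n"; };
}

// Analogue of the static const TorchLibraryInit variable: defining it at
// namespace scope triggers my_init before main() runs.
static const RegistrarInit my_registrar(&my_init);

int main() {
  registry()["myadd"]();  // the kernel registered during static init
}
```

The real mechanism differs in that each `TorchLibraryInit` owns its own `Library` object (`lib_`) rather than writing to a global map, but the trick of letting a static object's constructor do the work is the same.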
When `TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20` is defined, the first argument (`Library::Kind kind`) is `torch::Library::IMPL`, and the second argument is:
```cpp
c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(
[]() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},
[]() { return [](torch::Library&) -> void {}; }
), // argument 2: InitFn*
```
Let's first look at the template argument, `c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)`, which is defined as:
```cpp
constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
#ifdef C10_MOBILE
return true;
// Disabled for now: to be enabled later!
// return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
#else
return true;
#endif
}
```
As you can see, it currently just returns `true` unconditionally, so the second argument becomes:
```cpp
c10::guts::if_constexpr<true>(
[]() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},
[]() { return [](torch::Library&) -> void {}; }
), // argument 2: InitFn*
```
`if_constexpr` is defined as follows:
```cpp
template <bool Condition, class ThenCallback, class ElseCallback>
decltype(auto) if_constexpr(
ThenCallback&& thenCallback,
ElseCallback&& elseCallback) {
#if defined(__cpp_if_constexpr)
// If we have C++17, just use it's "if constexpr" feature instead of wrapping
// it. This will give us better error messages.
if constexpr (Condition) {
if constexpr (detail::function_takes_identity_argument<
ThenCallback>::value) {
// Note that we use static_cast<T&&>(t) instead of std::forward (or
// ::std::forward) because using the latter produces some compilation
// errors about ambiguous `std` on MSVC when using C++17. This static_cast
// is just what std::forward is doing under the hood, and is equivalent.
return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
} else {
return static_cast<ThenCallback&&>(thenCallback)();
}
} else {
if constexpr (detail::function_takes_identity_argument<
ElseCallback>::value) {
return static_cast<ElseCallback&&>(elseCallback)(detail::_identity());
} else {
return static_cast<ElseCallback&&>(elseCallback)();
}
}
#else
// C++14 implementation of if constexpr
return detail::_if_constexpr<Condition>::call(
static_cast<ThenCallback&&>(thenCallback),
static_cast<ElseCallback&&>(elseCallback));
#endif
}
```
This is getting a bit fancy. The comment in the source explains it directly:
```cpp
Example 1: simple constexpr if/then/else
template<int arg> int increment_absolute_value() {
int result = arg;
if_constexpr<(arg > 0)>(
[&] { ++result; }, // then-case
[&] { --result; } // else-case
);
return result;
}
```
So this is just a compile-time if/else. Since the template argument here is `true`, the second constructor argument evaluates to a pointer to the static function declared in the first part, `TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20`. The remaining arguments need no further explanation, except that the fourth one, `c10::make_optional(c10::DispatchKey::AutogradPrivateUse1)`, is a bit more involved.
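With C++17, the whole `if_constexpr` dance boils down to a plain `if constexpr` that picks between the real init function and a do-nothing one. A minimal sketch of just that selection, with invented names (this is only the shape of the idea, not PyTorch code):

```cpp
#include <iostream>

struct Library {};              // stand-in for torch::Library
using InitFn = void(Library&);

static void real_init(Library&) { std::cout << "registering kernels\n"; }

// Mirrors what if_constexpr<Allowed>(then, else) computes in the macro:
// either a pointer to the real init function, or a pointer to a
// do-nothing function when the dispatch key is not allowlisted.
template <bool Allowed>
InitFn* select_init() {
  if constexpr (Allowed) {
    return &real_init;
  } else {
    // A capture-less lambda converts to a plain function pointer,
    // just like the `[](torch::Library&) -> void {}` in the macro.
    return +[](Library&) -> void {};
  }
}

int main() {
  Library lib;
  select_init<true>()(lib);   // runs real_init
  select_init<false>()(lib);  // runs the no-op
}
```

The capture-less else-lambda matters: it converts to the same `InitFn*` function-pointer type that the `TorchLibraryInit` constructor expects.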
- Provide the body of the static function declared in the first step. After macro expansion it becomes:
```cpp
void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library& m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}
```
To recap, the original (unexpanded) code is:
```cpp
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
m.impl(<myadd_schema>, &myadd_autograd);
}
```
After macro expansion plus simplification it becomes:
```cpp
static void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m);
static const torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(
    torch::Library::IMPL, // argument 1: Library::Kind
    &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20, // argument 2: InitFn*
    "aten", // argument 3: const char*
    c10::make_optional(c10::DispatchKey::AutogradPrivateUse1), // argument 4: c10::optional<c10::DispatchKey>
    __FILE__, // argument 5: const char*
    __LINE__); // argument 6: uint32_t
void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m){
m.impl(<myadd_schema>, &myadd_autograd);
}
// definition of TorchLibraryInit, from library.h
class TorchLibraryInit final {
private:
using InitFn = void(Library&);
Library lib_;
public:
TorchLibraryInit(
Library::Kind kind,
InitFn* fn,
const char* ns,
c10::optional<c10::DispatchKey> k,
const char* file,
uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
```
To summarize so far:
① The first part declares a static function, `TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20`.
② The second part defines a static constant of type `torch::detail::TorchLibraryInit`, named `TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20`. It holds a member variable of type `Library`, which is initialized from the constructor arguments and then set up by the static function declared in the first part.
③ The third part provides the body of the function declared in the first part.
Note that this function performs the operator registration by calling the `impl` member function of its `torch::Library` parameter, and the actual argument passed to it is the private member variable of the static constant from the second part. That static constant's name follows the pattern `TORCH_LIBRARY_IMPL_static_init_<ns>_<k>_<uid>`, i.e. it depends on the namespace, the dispatch key (which usually corresponds to a backend such as CPU, CUDA, or a custom device), and the UID.
The constructor of `TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20` first constructs its private member `lib_` and then calls `TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(lib_)`, which in turn calls the `impl` method of `lib_` to perform the registration.
Next, let's look at the `impl` method of the `torch::Library` class. It is defined as follows:
```cpp
/// Register an implementation for an operator. You may register multiple
/// implementations for a single operator at different dispatch keys
/// (see torch::dispatch()). Implementations must have a corresponding
/// declaration (from def()), otherwise they are invalid. If you plan
/// to register multiple implementations, DO NOT provide a function
/// implementation when you def() the operator.
///
/// \param name The name of the operator to implement. Do NOT provide
/// schema here.
/// \param raw_f The C++ function that implements this operator. Any
/// valid constructor of torch::CppFunction is accepted here;
/// typically you provide a function pointer or lambda.
///
/// ```
/// // Example:
/// TORCH_LIBRARY_IMPL(myops, CUDA, m) {
/// m.impl("add", add_cuda);
/// }
/// ```
template <typename Name, typename Func>
Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
// TODO: need to raise an error when you impl a function that has a
// catch all def
#if defined C10_MOBILE
CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
CppFunction f(std::forward<Func>(raw_f));
#endif
return _impl(name, std::move(f), rv);
}
```
............... Off to the gym; to be continued.
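Until the follow-up, one thing worth noting from the signature above: `raw_f` can be anything a `torch::CppFunction` can be constructed from (typically a function pointer or a lambda), and `impl` returns `Library&`, so several kernels can be registered in one block. A hedged sketch with made-up operator and function names:

```cpp
#include <torch/library.h>

// Made-up CUDA kernels for illustration.
at::Tensor add_cuda(const at::Tensor& self, const at::Tensor& other);
at::Tensor mul_cuda(const at::Tensor& self, const at::Tensor& other);

TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  // Function pointer.
  m.impl("add", add_cuda);
  // A lambda works too, since CppFunction accepts any suitable callable.
  m.impl("mul", [](const at::Tensor& self, const at::Tensor& other) {
    return mul_cuda(self, other);
  });
}
```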