pytorch中算子注册原理

注:新手文章,欢迎指正!以下内容基于pytorch2.0.0

pytorch的官方教程pytorch.org/tutorials/a... 中,写了注册算子的主要方式是:

cpp 复制代码
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

pytorch代码中,/home/pytorch/torch/library.h中定义了TORCH_LIBRARY_IMPL宏:

cpp 复制代码
#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

_TORCH_LIBRARY_IMPL宏的定义如下:

cpp 复制代码
#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                             \
  static void C10_CONCATENATE(                                         \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);    \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(        \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
      torch::Library::IMPL,                                            \
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
          c10::DispatchKey::k)>(                                       \
          []() {                                                       \
            return &C10_CONCATENATE(                                   \
                TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);           \
          },                                                           \
          []() { return [](torch::Library&) -> void {}; }),            \
      #ns,                                                             \
      c10::make_optional(c10::DispatchKey::k),                         \
      __FILE__,                                                        \
      __LINE__);                                                       \
  void C10_CONCATENATE(                                                \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

首先看C10_UID,其定义为:

cpp 复制代码
#define C10_UID __COUNTER__
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__)

因此其实际上为一个全局唯一的ID号。

C10_CONCATENATE的定义如下:

cpp 复制代码
#define C10_CONCATENATE_IMPL(s1, s2) s1##s2
#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2)

可见其就是连接了两个字符串,如果看不懂可以查一下##在C/C++预处理中的作用。

_TORCH_LIBRARY_IMPL的定义可以被分为以下三个部分:

  1. 声明一个静态函数:
cpp 复制代码
static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);

函数名为TORCH_LIBRARY_IMPL_init_+ns+k+uid,假设TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m)的UID为20,那么函数名为: TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20

  1. 定义一个cpp文件内部的常量:
cpp 复制代码
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(        \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
      torch::Library::IMPL,                                            \
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
          c10::DispatchKey::k)>(                                       \
          []() {                                                       \
            return &C10_CONCATENATE(                                   \
                TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);           \
          },                                                           \
          []() { return [](torch::Library&) -> void {}; }),            \
      #ns,                                                             \
      c10::make_optional(c10::DispatchKey::k),                         \
      __FILE__,                                                        \
      __LINE__);                                                       \

该常量类型为static const torch::detail::TorchLibraryInit,仍然以上面的例子为例,其名字为: TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20,其和上面定义的静态函数的名字的差别就是多了一个static字符串。宏展开后,整段代码为如下:

cpp 复制代码
  static const torch::detail::TorchLibraryInit                     //返回类型              
  TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(          
      torch::Library::IMPL,                                        //参数1,Library::Kind类型    
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(                                       
          []() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},           
          []() { return [](torch::Library&) -> void {}; }
          ),                                                      //参数2,InitFn*类型
      "aten",                                                     //参数3,const char*类型 
      c10::make_optional(c10::DispatchKey::AutogradPrivateUse1),  //参数4,c10::optional<c10::DispatchKey>类型          
      __FILE__,                                                   //参数5,const char*类型
      __LINE__);                                                  //参数6,uint32_t类型

TorchLibraryInit的类定义如下:

cpp 复制代码
class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

其有只包含一个Library类型的私有成员变量,注意其初始构造函数中,会先用kind, ns, k, file, line初始化lib_,再用传入的InitFn类型,也就是void(Library&)类型的函数初始化这个私有成员变量lib_

在定义TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20的时候,第一个参数Library::Kind kindtorch::Library::IMPL,第二个参数为

rust 复制代码
c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(                                       
          []() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},           
          []() { return [](torch::Library&) -> void {}; }
          ),                                                      //参数2,InitFn*类型

首先看模板参数c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1),其定义为:

cpp 复制代码
constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
#ifdef C10_MOBILE
 return true;
 // Disabled for now: to be enabled later!
 // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
#else
 return true;
#endif
} 

可见其目前无脑返回true,因此第二个参数变成:

cpp 复制代码
c10::guts::if_constexpr<true>(                                       
         []() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},           
         []() { return [](torch::Library&) -> void {}; }
         ),                                                      //参数2,InitFn*类型

if_constexpr的定义如下:

cpp 复制代码
template <bool Condition, class ThenCallback, class ElseCallback>
decltype(auto) if_constexpr(
   ThenCallback&& thenCallback,
   ElseCallback&& elseCallback) {
#if defined(__cpp_if_constexpr)
 // If we have C++17, just use it's "if constexpr" feature instead of wrapping
 // it. This will give us better error messages.
 if constexpr (Condition) {
   if constexpr (detail::function_takes_identity_argument<
                     ThenCallback>::value) {
     // Note that we use static_cast<T&&>(t) instead of std::forward (or
     // ::std::forward) because using the latter produces some compilation
     // errors about ambiguous `std` on MSVC when using C++17. This static_cast
     // is just what std::forward is doing under the hood, and is equivalent.
     return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
   } else {
     return static_cast<ThenCallback&&>(thenCallback)();
   }
 } else {
   if constexpr (detail::function_takes_identity_argument<
                     ElseCallback>::value) {
     return static_cast<ElseCallback&&>(elseCallback)(detail::_identity());
   } else {
     return static_cast<ElseCallback&&>(elseCallback)();
   }
 }
#else
 // C++14 implementation of if constexpr
 return detail::_if_constexpr<Condition>::call(
     static_cast<ThenCallback&&>(thenCallback),
     static_cast<ElseCallback&&>(elseCallback));
#endif
}

这里有点炫技的味道了,直接看注释:

cpp 复制代码
Example 1: simple constexpr if/then/else
 template<int arg> int increment_absolute_value() {
   int result = arg;
   if_constexpr<(arg > 0)>(
     [&] { ++result; }  // then-case
     [&] { --result; }  // else-case
   );
   return result;
 }

所以这就是一个简单的模板编译期if else,由于其模板参数为true,因此第二个参数就是第一部分定义的静态函数TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20,之后的参数就不再赘述了,值得注意的是,第四个参数c10::make_optional(c10::DispatchKey::AutogradPrivateUse1)颇为复杂。

  1. 正式定义第一步声明的静态函数,宏展开后为:

    javascript 复制代码
    void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m){ 
        m.impl(<myadd_schema>, &myadd_autograd); 
    }

整个代码简化之前为:

cpp 复制代码
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

宏展开+简化后为:

cpp 复制代码
static void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m);

static const torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(          
  torch::Library::IMPL,                                       //参数1,Library::Kind类型    
  &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20,       //参数2,InitFn*类型
  "aten",                                                     //参数3,const char*类型 
  c10::make_optional(c10::DispatchKey::AutogradPrivateUse1),  //参数4,c10::optional<c10::DispatchKey>类型          
  __FILE__,                                                   //参数5,const char*类型
  __LINE__);                                                  //参数6,uint32_t类型
  
void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m){ 
   m.impl(<myadd_schema>, &myadd_autograd); 
}

//TorchLibraryInit的定义,在library.h中定义
class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

到这里总结一下:

① 第一部分声明了一个静态函数TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20

② 第二部分声明了一个torch::detail::TorchLibraryInit类型的静态常量TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20,在有一个Library类型的成员变量,通过传入的参数和第一部分声明的静态函数来初始化这个成员变量。

③ 第三部分则是实现了第一部分声明的函数。

注意这个函数通过调用torch::Library类型参数的impl成员函数来实现算子注册,而传入的实参实际上第二部分声明的静态常量的私有成员变量,而第二部分的静态常量名称为TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_##uid,也就是取决于命名空间(namespace)、设备(cpu or cuda or XXX)以及UID。

TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20的初始构造函数利用TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20来初始化其私有成员变量lib_,初始化方法为调用其私有成员变量lib_impl方法。

下面讲解torch::Library类的impl方法,其定义如下:

cpp 复制代码
  /// Register an implementation for an operator.  You may register multiple
  /// implementations for a single operator at different dispatch keys
  /// (see torch::dispatch()).  Implementations must have a corresponding
  /// declaration (from def()), otherwise they are invalid.  If you plan
  /// to register multiple implementations, DO NOT provide a function
  /// implementation when you def() the operator.
  ///
  /// \param name The name of the operator to implement.  Do NOT provide
  ///   schema here.
  /// \param raw_f The C++ function that implements this operator.  Any
  ///   valid constructor of torch::CppFunction is accepted here;
  ///   typically you provide a function pointer or lambda.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  ///   m.impl("add", add_cuda);
  /// }
  /// ```
  template <typename Name, typename Func>
  Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
#if defined C10_MOBILE
    CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
    CppFunction f(std::forward<Func>(raw_f));
#endif
    return _impl(name, std::move(f), rv);
  }

...............健身去了,未完待续

相关推荐
sp_fyf_20246 分钟前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-11-01
人工智能·深度学习·神经网络·算法·机器学习·语言模型·数据挖掘
多吃轻食10 分钟前
大模型微调技术 --> 脉络
人工智能·深度学习·神经网络·自然语言处理·embedding
北京搜维尔科技有限公司1 小时前
搜维尔科技:【应用】Xsens在荷兰车辆管理局人体工程学评估中的应用
人工智能·安全
说私域1 小时前
基于开源 AI 智能名片 S2B2C 商城小程序的视频号交易小程序优化研究
人工智能·小程序·零售
YRr YRr1 小时前
深度学习:Transformer Decoder详解
人工智能·深度学习·transformer
知来者逆1 小时前
研究大语言模型在心理保健智能顾问的有效性和挑战
人工智能·神经网络·机器学习·语言模型·自然语言处理
云起无垠1 小时前
技术分享 | 大语言模型赋能软件测试:开启智能软件安全新时代
人工智能·安全·语言模型
老艾的AI世界1 小时前
新一代AI换脸更自然,DeepLiveCam下载介绍(可直播)
图像处理·人工智能·深度学习·神经网络·目标检测·机器学习·ai换脸·视频换脸·直播换脸·图片换脸
翔云API2 小时前
PHP静默活体识别API接口应用场景与集成方案
人工智能
浊酒南街2 小时前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)4.9-4.10
人工智能·深度学习·神经网络·cnn