How mmcv registers and dispatches operators

Following PyTorch's design, mmcv uses a dispatcher that routes each operator call to a different backend function based on information about the call (in mmcv's case, the device the input tensors live on), so the operator-level code can be reused across devices. Taking upfirdn2d as an example, this post walks through how mmcv registers and dispatches its operators.

1. Binding Python and C++

mmcv uses pybind11 to bind the C++ backend code to the Python frontend. The bindings live in /home/mmcv/mmcv/ops/csrc/pytorch/pybind.cpp; below is an excerpt:

cpp
Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx, int upy,
                 int downx, int downy, int padx0, int padx1, int pady0,
                 int pady1, bool flip, float gain);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
        py::arg("filter"), py::arg("upx"), py::arg("upy"), py::arg("downx"),
        py::arg("downy"), py::arg("padx0"), py::arg("padx1"), py::arg("pady0"),
        py::arg("pady1"), py::arg("flip"), py::arg("gain"));  // expose upfirdn2d to the Python frontend
  m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
        "fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"),
        py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
        py::arg("scale"));
  ........

2. The implementation of upfirdn2d

The implementation of upfirdn2d is shown below:

cpp
torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
                                int upx, int upy, int downx, int downy,
                                int padx0, int padx1, int pady0, int pady1,
                                bool flip, float gain) {
  return DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, filter, upx, upy, downx,
                              downy, padx0, padx1, pady0, pady1, flip, gain);
}

torch::Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx,
                        int upy, int downx, int downy, int padx0, int padx1,
                        int pady0, int pady1, bool flip, float gain) {
  return upfirdn2d_op_impl(input, filter, upx, upy, downx, downy, padx0, padx1,
                           pady0, pady1, flip, gain);
}

As you can see, upfirdn2d_op_impl produces its result through the DISPATCH_DEVICE_IMPL macro.

3. The DISPATCH_DEVICE_IMPL macro: operator dispatch

The code related to this macro is shown below:

cpp
// Note: the first argument passed to this macro is the function name itself;
// the remaining arguments are the function's parameters.
#define DISPATCH_DEVICE_IMPL(key, ...) \
  Dispatch(DEVICE_REGISTRY(key), #key, __VA_ARGS__)

Therefore, the call

cpp
DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, filter, \
    upx, upy, downx,downy, padx0, padx1, pady0, pady1, flip, gain)

expands to

cpp
Dispatch(DEVICE_REGISTRY(upfirdn2d_op_impl), "upfirdn2d_op_impl", input,  \
    filter, upx, upy, downx,downy, padx0, padx1, pady0, pady1, flip, gain)

DEVICE_REGISTRY is implemented as follows:

cpp
#define DEVICE_REGISTRY(key) DeviceRegistry<decltype(&(key)), key>::instance()

// which, for our example, expands to
DeviceRegistry<decltype(&(upfirdn2d_op_impl)), upfirdn2d_op_impl>::instance()

Part of the DeviceRegistry definition is shown below:

cpp
// Registry
template <typename F, F f>
class DeviceRegistry;

template <typename Ret, typename... Args, Ret (*f)(Args...)>
class DeviceRegistry<Ret (*)(Args...), f> {
 public:
  using FunctionType = Ret (*)(Args...);  // the function pointer type
  static DeviceRegistry& instance() {
    static DeviceRegistry inst;
    return inst;
  }

  // look up the backend function registered for the given device type
  FunctionType Find(at::DeviceType device) const {
    return funcs_[int8_t(device)];
  }

 private:
  FunctionType funcs_[MAX_DEVICE_TYPES];  // one slot per device type, holding that device's implementation
};

Putting this together, we arrive at the following conclusion:

cpp
Dispatch(DEVICE_REGISTRY(upfirdn2d_op_impl), "upfirdn2d_op_impl", input,  \
    filter, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip, gain)
// DEVICE_REGISTRY(upfirdn2d_op_impl) simply fetches the singleton of the
// DeviceRegistry instantiated for upfirdn2d_op_impl.
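
Note that the function pointer itself is a non-type template parameter, so every op instantiates its own, unrelated DeviceRegistry type, and hence its own singleton: one registry per op, with one slot per device inside it. A minimal sketch of this property (the toy functions foo and bar are made up for illustration and are not mmcv code):

cpp
#include <type_traits>

int foo(int x) { return x; }
int bar(int x) { return x + 1; }

// Two instantiations of the DeviceRegistry template shown above are two
// unrelated types, each with its own static inside instance(),
// i.e. one registry singleton per op.
static_assert(!std::is_same<DeviceRegistry<decltype(&foo), foo>,
                            DeviceRegistry<decltype(&bar), bar>>::value,
              "each op gets its own registry type");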

(Note that mmcv implements this singleton in a header file, which is quite an unsafe practice; see www.zhihu.com/question/42... for details.)
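
To make the concern concrete, below is a minimal sketch of the pattern in question; it is not mmcv code, and the failure mode described in the comment assumes the header is compiled into more than one shared library built with hidden symbol visibility:

cpp
// singleton.h -- same shape as DeviceRegistry::instance() above
struct Registry {
  static Registry& instance() {
    static Registry inst;  // function-local static defined in a header
    return inst;
  }
};
// If this header is compiled into two different shared libraries built with
// hidden symbol visibility, each library can end up with its own copy of
// `inst`; an op registered through one copy is then invisible to dispatch
// performed through the other.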

Dispatch is implemented as follows:

cpp
template <typename R, typename... Args>
auto Dispatch(const R& registry, const char* name, Args&&... args) {
  // get the device of the first tensor argument
  auto device = GetFirstTensorDevice(std::forward<Args>(args)...);
  // make sure all input tensors live on the same device (inconsist.first points
  // at the first mismatching argument, or past the end if all match)
  auto inconsist =
      CheckDeviceConsistency(device, 0, std::forward<Args>(args)...);
  TORCH_CHECK(inconsist.first >= int(sizeof...(Args)), name, ": at param ",
              inconsist.first,
              ", inconsistent device: ", GetDeviceStr(inconsist.second).c_str(),
              " vs ", GetDeviceStr(device).c_str(), "\n")
  // fetch the function pointer registered for this device
  auto f_ptr = registry.Find(device.type());
  TORCH_CHECK(f_ptr != nullptr, name, ": implementation for device ",
              GetDeviceStr(device).c_str(), " not found.\n")
  // call the backend implementation through the pointer and return its result
  return f_ptr(std::forward<Args>(args)...);
}

To summarize so far: mmcv uses pybind11 to bind the Python-side upfirdn2d to the C++ backend function upfirdn2d; the C++ upfirdn2d then calls upfirdn2d_op_impl, and after macro expansion upfirdn2d_op_impl is defined as follows:

cpp
torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
                                int upx, int upy, int downx, int downy,
                                int padx0, int padx1, int pady0, int pady1,
                                bool flip, float gain) {
  // Dispatch finds the backend function for the current device via the
  // DeviceRegistry singleton, calls it, and returns the result.
  return Dispatch(
      // the DeviceRegistry singleton instantiated for this function
      DeviceRegistry<decltype(&(upfirdn2d_op_impl)),
                     upfirdn2d_op_impl>::instance(),
      "upfirdn2d_op_impl", input, filter, upx, upy, downx, downy,
      padx0, padx1, pady0, pady1, flip, gain);
}

4. REGISTER_DEVICE_IMPL: operator registration

/home/mmcv/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp uses REGISTER_DEVICE_IMPL to register the CUDA implementation upfirdn2d_op as the CUDA backend of upfirdn2d_op_impl, as the following code shows:

cpp
torch::Tensor upfirdn2d_op(torch::Tensor input, torch::Tensor filter, int upx,
                           int upy, int downx, int downy, int padx0, int padx1,
                           int pady0, int pady1, bool flip, float gain);

torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
                                int upx, int upy, int downx, int downy,
                                int padx0, int padx1, int pady0, int pady1,
                                bool flip, float gain);
// As described in the previous section, upfirdn2d_op_impl is the C++ dispatch
// entry point that gets called; upfirdn2d_op is the CUDA function that actually
// implements the operator.
REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, CUDA, upfirdn2d_op);

REGISTER_DEVICE_IMPL is defined as follows:

cpp
#define REGISTER_DEVICE_IMPL(key, device, value)           \
  struct key##_##device##_registerer {                     \
    key##_##device##_registerer() {                        \
      DEVICE_REGISTRY(key).Register(at::k##device, value); \
    }                                                      \
  };                                                       \
  static key##_##device##_registerer _##key##_##device##_registerer;

As you can see, REGISTER_DEVICE_IMPL generates a struct named after the dispatch function and the device. The struct's constructor fetches the DeviceRegistry singleton corresponding to the dispatch function and calls its Register method to register the concrete implementation. Finally, a static file-scope variable of that struct type is defined, so the registration runs as soon as the library is loaded. The macro-expanded code is shown below:

cpp
// before macro expansion
REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, CUDA, upfirdn2d_op);

// after macro expansion
struct upfirdn2d_op_impl_CUDA_registerer {
  upfirdn2d_op_impl_CUDA_registerer() {
    DeviceRegistry<decltype(&(upfirdn2d_op_impl)),   // fetch the singleton
                   upfirdn2d_op_impl>::instance()
        .Register(at::kCUDA, upfirdn2d_op);          // register the CUDA backend
  }
};
static upfirdn2d_op_impl_CUDA_registerer _upfirdn2d_op_impl_CUDA_registerer;

// Finally, DeviceRegistry's Register method: it simply stores the function
// pointer into the array slot for the given device.
void Register(at::DeviceType device, FunctionType function) {
  funcs_[int8_t(device)] = function;
}
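
Putting everything together, the pattern boils down to three pieces: a per-op registry keyed by device, a dispatcher that looks up the registered backend, and a static registerer object that fills the registry when the library is loaded. The following self-contained toy program mirrors that structure; all names in it are made up for illustration (it is not mmcv code), it uses plain ints instead of tensors, and it picks the device explicitly instead of reading it from the arguments:

cpp
#include <cstdio>

enum class Device { CPU = 0, CUDA = 1, MAX = 8 };

// per-op registry: the function pointer is a non-type template parameter,
// so every op gets its own singleton
template <typename F, F f>
class ToyRegistry;

template <typename Ret, typename... Args, Ret (*f)(Args...)>
class ToyRegistry<Ret (*)(Args...), f> {
 public:
  using Fn = Ret (*)(Args...);
  static ToyRegistry& instance() {
    static ToyRegistry inst;
    return inst;
  }
  void Register(Device d, Fn fn) { funcs_[int(d)] = fn; }
  Fn Find(Device d) const { return funcs_[int(d)]; }

 private:
  Fn funcs_[int(Device::MAX)] = {};  // one slot per device
};

#define TOY_REGISTRY(key) ToyRegistry<decltype(&(key)), key>::instance()

// mmcv derives the device from the first tensor argument; the toy version
// just takes it explicitly
#define TOY_DISPATCH(key, device, ...) TOY_REGISTRY(key).Find(device)(__VA_ARGS__)

#define TOY_REGISTER(key, device, value)                 \
  struct key##_##device##_registerer {                   \
    key##_##device##_registerer() {                      \
      TOY_REGISTRY(key).Register(Device::device, value); \
    }                                                    \
  };                                                     \
  static key##_##device##_registerer _##key##_##device##_registerer;

// the dispatch entry point (the analogue of upfirdn2d_op_impl)
int add_op_impl(int a, int b) {
  return TOY_DISPATCH(add_op_impl, Device::CPU, a, b);
}

// a concrete CPU backend (the analogue of upfirdn2d_op), registered at load time
int add_op_cpu(int a, int b) { return a + b; }
TOY_REGISTER(add_op_impl, CPU, add_op_cpu)

int main() {
  std::printf("%d\n", add_op_impl(1, 2));  // prints 3 via the registered backend
  return 0;
}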