ncrelu.cpp
cpp
#include <torch/extension.h> // 头文件引用部分
namespace py = pybind11;
torch::Tensor ncrelu_forward(torch::Tensor input) {
auto pos = input.clamp_min(0); // 具体实现部分
auto neg = input.clamp_max(0);
return torch::cat({pos, neg}, 1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // 绑定部分
m.def("forward", &ncrelu_forward, py::arg("input"), "NCReLU forward");
}
python
from setuptools import setup
from torch.utils import cpp_extension
setup(
name='ncrelu_cpp',
version='1.0',# 编译后的链接库名称
py_modules=['ncrelu_cpp'],
ext_modules=[
cpp_extension.CppExtension(
'ncrelu_cpp', ['ncrelu.cpp'],
extra_compile_args={'cxx': ['-O2']}
# 待编译文件,及编译函数
)
],
cmdclass={ # 执行编译命令设置
'build_ext': cpp_extension.BuildExtension
}
)
python
import torch
import ncrelu_cpp
import sys
print(sys.path)
a = torch.randn(4,3)
print(a)
b = ncrelu_cpp.forward(a)
python setup.py install
或pip install .
但是在Windows平台下不知道为什么会报错找不到包,或者找不到函数,很奇怪,但是正常运行没有任何问题