自定义实现C++拓展pytorch功能

ncrelu.cpp

cpp 复制代码
#include <torch/extension.h>					// 头文件引用部分

namespace py = pybind11;

torch::Tensor ncrelu_forward(torch::Tensor input) {
    auto pos = input.clamp_min(0);				       // 具体实现部分
    auto neg = input.clamp_max(0);
    return torch::cat({pos, neg}, 1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {	// 绑定部分
    m.def("forward", &ncrelu_forward, py::arg("input"), "NCReLU forward");
}

setup.py

python 复制代码
from setuptools import setup
from torch.utils import cpp_extension


setup(
    name='ncrelu_cpp',
    version='1.0',# 编译后的链接库名称
    py_modules=['ncrelu_cpp'],
    ext_modules=[
        cpp_extension.CppExtension(
            'ncrelu_cpp', ['ncrelu.cpp'],
            extra_compile_args={'cxx': ['-O2']}
            # 待编译文件,及编译函数
        )
    ],
    cmdclass={						       # 执行编译命令设置
        'build_ext': cpp_extension.BuildExtension
    }
)

test.py

python 复制代码
import torch
import ncrelu_cpp
import sys
print(sys.path)
a = torch.randn(4,3)
print(a)
b = ncrelu_cpp.forward(a)

python setup.py install

或pip install .

但是在Windows平台下不知道为什么会报错找不到包,或者找不到函数,很奇怪,但是正常运行没有任何问题

相关推荐
vyuvyucd4 分钟前
MPPI算法实战:机器人避障与仿真
python
计算机徐师兄4 分钟前
Python基于Flask的广东旅游数据分析系统(附源码,文档说明)
python·flask·旅游数据分析·广东旅游数据分析系统·python广东数据分析系统·python广东旅游数据分析·python旅游数据分析系统
YxVoyager5 分钟前
Qt C++ :QRegularExpression 正则表达式使用详解
c++·qt·正则表达式
jarreyer6 分钟前
数据项目分析标准化流程
开发语言·python·机器学习
闻缺陷则喜何志丹7 分钟前
【回文 字符串】3677 统计二进制回文数字的数目|2223
c++·算法·字符串·力扣·回文
GZKPeng9 分钟前
pytorch +cuda成功安装后, torch.cuda.is_available 是False
人工智能·pytorch·python
李余博睿(新疆)9 分钟前
c++分治算法
c++
我的xiaodoujiao10 分钟前
使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 39--生成 Allure测试报告
python·学习·测试工具·pytest
陈小桔14 分钟前
logging模块-python
开发语言·python
oioihoii14 分钟前
Protocol Buffers 编码原理深度解析
c++