自定义实现C++拓展pytorch功能

ncrelu.cpp

cpp 复制代码
#include <torch/extension.h>					// 头文件引用部分

namespace py = pybind11;

torch::Tensor ncrelu_forward(torch::Tensor input) {
    auto pos = input.clamp_min(0);				       // 具体实现部分
    auto neg = input.clamp_max(0);
    return torch::cat({pos, neg}, 1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {	// 绑定部分
    m.def("forward", &ncrelu_forward, py::arg("input"), "NCReLU forward");
}

setup.py

python 复制代码
from setuptools import setup
from torch.utils import cpp_extension


setup(
    name='ncrelu_cpp',
    version='1.0',# 编译后的链接库名称
    py_modules=['ncrelu_cpp'],
    ext_modules=[
        cpp_extension.CppExtension(
            'ncrelu_cpp', ['ncrelu.cpp'],
            extra_compile_args={'cxx': ['-O2']}
            # 待编译文件,及编译函数
        )
    ],
    cmdclass={						       # 执行编译命令设置
        'build_ext': cpp_extension.BuildExtension
    }
)

test.py

python 复制代码
import torch
import ncrelu_cpp
import sys
print(sys.path)
a = torch.randn(4,3)
print(a)
b = ncrelu_cpp.forward(a)

python setup.py install

或pip install .

但是在Windows平台下不知道为什么会报错找不到包,或者找不到函数,很奇怪,但是正常运行没有任何问题

相关推荐
克里普crirp3 分钟前
电离层TEC地图中添加晨昏线/昼夜转换线
python
Dxy12393102163 分钟前
Python使用PyEnchant详解:打造高效拼写检查工具
开发语言·python
架构师老Y12 分钟前
011、消息队列应用:RabbitMQ、Kafka与Celery
python·架构·kafka·rabbitmq·ruby
枫叶林FYL17 分钟前
【Python高级工程与架构实战】项目四:生产级LLM Agent框架:基于PydanticAI的类型安全企业级实现
人工智能·python·自然语言处理
龙腾AI白云19 分钟前
多模大模型应用实战:智能问答系统开发
python·机器学习·数据分析·django·tornado
Hommy8829 分钟前
【开源剪映小助手】配置与部署
python·开源·aigc·剪映小助手
chh56330 分钟前
C++--内存管理
java·c语言·c++·windows·学习·面试
Yungoal31 分钟前
C++ 标准模板库STL(Standard Template Library)
c++·哈希算法·散列表
我真不是小鱼33 分钟前
cpp刷题打卡记录27——无重复字符的最长子串 & 找到字符串中所有字母的异位词
数据结构·c++·算法·leetcode
V搜xhliang024637 分钟前
基于¹⁸F-FDG PET/CT的深度学习-影像组学-临床模型预测非小细胞肺癌脉管侵犯的价值
大数据·人工智能·python·深度学习·机器学习·机器人