pytorch使用c++/cuda扩展

1、编写:c++/cuda拓展源文件

pybind11_demo/

├── setup.py

├── example.cpp

└── test.py

example.cpp

复制代码
#include <torch/extension.h>
#include <vector>

// Forward declaration of the function
torch::Tensor custom_add(torch::Tensor a, torch::Tensor b);

// The actual implementation
torch::Tensor custom_add(torch::Tensor a, torch::Tensor b) {
    // Simple element-wise addition
    return a + b;
}

// Pybind11 module definition
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_add", &custom_add, "A function that adds two tensors");
}

PyTorch中的PYBIND11_MODULE

PYBIND11_MODULE是Pybind11库中的一个宏,它用于定义一个Python模块,并将C++类、函数或其他对象绑定到该模块。这使得Python可以直接调用C++编写的函数和类,极大地提高了Python的性能,尤其是当计算密集型任务需要底层C++实现时。

2、编译:setuptools指导c++/cuda拓展的编译

setup.py

复制代码
from setuptools import setup, Extension
from torch.utils.cpp_extension import CppExtension, BuildExtension,CUDAExtension



setup(
    name='python_demo', # python包的名称
    ext_modules=[
        CppExtension(
            name='demo', # 扩展模块名称,后面import使用
            sources=['example.cpp'],
            extra_compile_args={'CXX': ['-w', '-std=c++14']}
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)


# python setup.py install
# or for development:
# python setup.py develop

指定构建命令

复制代码
cmdclass={  
    'build_ext': BuildExtension  
}

cmdclass是一个字典,用于指定自定义的构建命令。

'build_ext'是setuptools中的一个标准构建命令,用于构建扩展模块。

BuildExtension是PyTorch提供的BuildExtension类,它扩展了setuptools的build_ext命令,以支持C++和CUDA扩展的编译。

3、python调用编译完成的库

test.py

复制代码
import torch
import demo  # The name you specified in setup.py

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

result = demo.custom_add(a, b)
print(result)  # Should output tensor([5., 7., 9.])


# python test.py

参考

https://zhuanlan.zhihu.com/p/459955492

深入解析PyTorch中的PYBIND11_MODULE:功能与实现_pytorch pybind11-CSDN博客

相关推荐
wan9yu5 分钟前
为什么你需要给 LLM 的数据"加密"而不是"脱敏"?我写了一个开源工具
python
摇滚侠14 分钟前
你是一名 java 程序员,总结定义数组的方式
java·开发语言·python
这个名有人用不32 分钟前
解决 uv 虚拟环境使用 pip 命令提示command not found的办法
python·pip·uv·claude code
Oueii1 小时前
掌握Python魔法方法(Magic Methods)
jvm·数据库·python
2501_908329851 小时前
使用Python自动收发邮件
jvm·数据库·python
2501_908329852 小时前
NumPy入门:高性能科学计算的基础
jvm·数据库·python
2401_874732532 小时前
Python Web爬虫入门:使用Requests和BeautifulSoup
jvm·数据库·python
平常心cyk3 小时前
Python基础快速复习——集合和字典
开发语言·数据结构·python
阿钱真强道3 小时前
34 Python 离群点检测:什么是离群点?为什么要做异常检测?
python·sklearn·异常检测·异常·离群点检测