pytorch使用c++/cuda扩展

1、编写:c++/cuda拓展源文件

pybind11_demo/

├── setup.py

├── example.cpp

└── test.py

example.cpp

复制代码
#include <torch/extension.h>
#include <vector>

// Forward declaration of the function
torch::Tensor custom_add(torch::Tensor a, torch::Tensor b);

// The actual implementation
torch::Tensor custom_add(torch::Tensor a, torch::Tensor b) {
    // Simple element-wise addition
    return a + b;
}

// Pybind11 module definition
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_add", &custom_add, "A function that adds two tensors");
}

PyTorch中的PYBIND11_MODULE

PYBIND11_MODULE是Pybind11库中的一个宏,它用于定义一个Python模块,并将C++类、函数或其他对象绑定到该模块。这使得Python可以直接调用C++编写的函数和类,极大地提高了Python的性能,尤其是当计算密集型任务需要底层C++实现时。

2、编译:setuptools指导c++/cuda拓展的编译

setup.py

复制代码
from setuptools import setup, Extension
from torch.utils.cpp_extension import CppExtension, BuildExtension,CUDAExtension



setup(
    name='python_demo', # python包的名称
    ext_modules=[
        CppExtension(
            name='demo', # 扩展模块名称,后面import使用
            sources=['example.cpp'],
            extra_compile_args={'CXX': ['-w', '-std=c++14']}
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)


# python setup.py install
# or for development:
# python setup.py develop

指定构建命令

复制代码
cmdclass={  
    'build_ext': BuildExtension  
}

cmdclass是一个字典,用于指定自定义的构建命令。

'build_ext'是setuptools中的一个标准构建命令,用于构建扩展模块。

BuildExtension是PyTorch提供的BuildExtension类,它扩展了setuptools的build_ext命令,以支持C++和CUDA扩展的编译。

3、python调用编译完成的库

test.py

复制代码
import torch
import demo  # The name you specified in setup.py

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

result = demo.custom_add(a, b)
print(result)  # Should output tensor([5., 7., 9.])


# python test.py

参考

https://zhuanlan.zhihu.com/p/459955492

深入解析PyTorch中的PYBIND11_MODULE:功能与实现_pytorch pybind11-CSDN博客

相关推荐
专注VB编程开发20年5 分钟前
python图片验证码识别selenium爬虫--超级鹰实现自动登录,滑块,点击
数据库·python·mysql
iFeng的小屋13 分钟前
【2026最新当当网爬虫分享】用Python爬取千本日本相关图书,自动分析价格分布!
开发语言·爬虫·python
民乐团扒谱机16 分钟前
【微科普】3D 演奏蠕虫分析图:解码音乐表演情感的 “可视化语言”
python·可视化·音乐·3d图·3d蠕虫
Network_Engineer18 分钟前
从零手写LSTM:从门控原理到PyTorch源码级实现
人工智能·pytorch·lstm
芝士爱知识a22 分钟前
AlphaGBM 深度解析:下一代基于 AI 与蒙特卡洛的智能期权分析平台
数据结构·人工智能·python·股票·alphagbm·ai 驱动的智能期权分析·期权
52Hz1181 小时前
力扣230.二叉搜索树中第k小的元素、199.二叉树的右视图、114.二叉树展开为链表
python·算法·leetcode
喵手1 小时前
Python爬虫实战:网页截图归档完全指南 - 构建生产级页面存证与历史回溯系统!
爬虫·python·爬虫实战·零基础python爬虫教学·网页截图归档·历史回溯·生产级方案
张3蜂1 小时前
Python 四大 Web 框架对比解析:FastAPI、Django、Flask 与 Tornado
前端·python·fastapi
2601_948374571 小时前
商用电子秤怎么选
大数据·python