上一个 帖子主要分享了如何 去将 C++ 程序 打包成一个package。 我们最后的 目的实际上是想把 CUDA 的程序 打包成 一个 Package , C++ 程序只是起到了桥梁的作用:
首先:CUDA 程序 和 C++ 的程序一样, 都有一个 .cu 的源文件和 一个 .h 的头文件 。
我们的文件 包含 Cpp 文件组成,负责当作 CUDA 和 Python 的桥梁。 还有 对应的 CUDA 的源代码文件和 头文件。将这个cpp 文件命名成 ext.cpp.
cpp
#include <torch/extension.h>
#include "interpolation_kernel.h" ## CUDA 的头文件
PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){
m.def("trilinear_interpolation",&trilinear_interpolation);
}
cpp_properities.json 配置文件
bash
{
"configurations": [
{
"name": "Linux",
"includePath": [
"${workspaceFolder}/**",
"/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8",
"/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8/site-packages/torch/include/",
"/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/"
],
"defines": [],
"compilerPath": "/usr/bin/gcc",
"cStandard": "c17",
"cppStandard": "gnu++14",
"intelliSenseMode": "linux-gcc-x64"
}
],
"version": 4
CUDA 部分:
CUDA 的头文件 *** interpolation_kernel.h ***
cpp
#include <torch/extension.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor trilinear_interpolation(torch::Tensor feats, torch::Tensor point);
对应的 源代码 文件*** interpolation_kernel.cu ***
include 的 头文件 和源代码文件 尽量放在同一级的 目录
cpp
#include <torch/extension.h>
#include "interpolation_kernel.h"
torch::Tensor trilinear_interpolation(torch::Tensor feats, torch::Tensor points){
CHECK_CUDA(feats);
CHECK_CUDA(points);
return feats;
}
配置文件 setup.py 部分:
配置文件的 包含 ** *.cpp 文件 和 *.cu 文件 **
其他的部分应该 尽量不去改变。
python
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
import glob
os.path.dirname(os.path.abspath(__file__))
setup(
name="cuda_tutorial",
version='1.0',
ext_modules=[
CUDAExtension(
name='cuda_tutorial',
sources=["interpolation_kernel.cu","ext.cpp"],
extra_compile_args={'cxx': ['-O2'],
'nvcc': ['-O2']}
)
],
cmdclass={
'build_ext': BuildExtension
}
)
最后是安装
pip install .