Lss-bev系列-2-部署插件IndexPut
总结
在导出该项目 ONNX 的时候,会产生不支持算子的报错。这里主要分析如下这段代码,并为它实现自定义插件以完成导出。这段代码做的操作是:按索引把 x 的每一行特征写入 final 中对应的位置。
python
final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x
上面的写法等价于下面的写法
python
# Equivalent per-point loop form of the advanced-indexing assignment above.
for i in range(len(geom_feats)):
    b = geom_feats[i, 3]   # batch index
    z = geom_feats[i, 2]   # Z index
    xi = geom_feats[i, 0]  # X index
    yi = geom_feats[i, 1]  # Y index
    final[b, :, z, xi, yi] = x[i]  # write the i-th point's features into the volume
自定义插件-pytorch
注意:在这部分我把 geom_feats 从 long(int64)类型转换成了 int32 类型,因为 TensorRT 8.x 版本的自定义插件不支持 long 类型,所以转换到 32 位上进行计算。
python
class IndexPutFunction(torch.autograd.Function):
    """Wraps the unsupported advanced-indexing assignment as a custom ONNX op."""

    @staticmethod
    def forward(ctx, x, geom_feats, final):
        # In-place scatter: indices come from geom_feats columns
        # (0 = X index, 1 = Y index, 2 = Z index, 3 = batch index).
        final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x
        return final

    @staticmethod
    def symbolic(g, x, geom_feats, final):
        # Emit a custom "xyz.onnx.contrib::IndexPut" node during ONNX export.
        output = g.op("xyz.onnx.contrib::IndexPut", x, geom_feats, final)
        # Explicitly mark the output type as identical to `final`.
        output.setType(final.type())
        return output
在pytorch中使用此计算规则,替换不可直接支持的计算语句
python
# final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x
final = IndexPutFunction.apply(x, geom_feats, final)
自定义插件-c++
cu文件
cpp
#include <cuda_runtime.h>
#include <stdio.h>
// ============================================================
// x = inputs[0] FLOAT [N, C]
// geom_feats = inputs[1] INT32 [N, 4] → (xi, yi, zi, bi)
// final = inputs[2] FLOAT [B,C,Z,X,Y]
// output = outputs[0] FLOAT [B,C,Z,X,Y]
// ============================================================
// Grid-stride element-wise copy: output[i] = final_in[i] for the whole
// [B, C, Z, X, Y] volume. The x / geom_feats parameters are unused here.
// NOTE(review): indexPutCUDA below performs this copy with cudaMemcpyAsync
// instead of launching this kernel, so this kernel appears to be unused.
__global__ void indexPutKernel(
    const float* x,
    const int32_t* geom_feats,
    const float* final_in,
    float* output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y)
{
    const int stride = blockDim.x * gridDim.x;
    const int total = B * C * Z * X * Y;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total; i += stride) {
        output[i] = final_in[i];
    }
}
// Scatter kernel: one thread per point. For point idx it writes
//   output[bi, c, zi, xi, yi] = x[idx, c]  for every channel c,
// where geom_feats rows are laid out as (xi, yi, zi, bi).
// Out-of-range coordinates are silently skipped. If two points map to the
// same voxel the stores race and the surviving value is nondeterministic
// (plain stores, no atomics) — same semantics as the original code.
// Fix: output offsets are now computed in size_t, because the original
// int32_t out_idx overflows once B*C*Z*X*Y exceeds INT32_MAX.
__global__ void indexPutScatterKernel(
    const float* x,
    const int32_t* geom_feats,
    float* output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N) return;

    // Voxel coordinates: geom_feats[idx] = (xi, yi, zi, bi)
    const int32_t xi = geom_feats[idx * 4 + 0];
    const int32_t yi = geom_feats[idx * 4 + 1];
    const int32_t zi = geom_feats[idx * 4 + 2];
    const int32_t bi = geom_feats[idx * 4 + 3];

    // Drop points that fall outside the voxel grid.
    if (xi < 0 || xi >= X ||
        yi < 0 || yi >= Y ||
        zi < 0 || zi >= Z ||
        bi < 0 || bi >= B) {
        return;
    }

    // 64-bit-safe indexing (all coordinates are non-negative here).
    const size_t planeZXY = (size_t)Z * X * Y;
    const size_t base = (size_t)bi * C * planeZXY
                      + (size_t)zi * X * Y
                      + (size_t)xi * Y
                      + (size_t)yi;
    for (int32_t c = 0; c < C; ++c) {
        // output[bi, c, zi, xi, yi] = x[idx, c]
        output[base + (size_t)c * planeZXY] = x[(size_t)idx * C + c];
    }
}
// ── 主入口函数 ───────────────────────────────────────────────
// ── Host entry point ─────────────────────────────────────────
// Copies final_in → output, then scatters the N rows of x into output at
// the voxel coordinates given by geom_feats. All pointers are device
// pointers; work is enqueued on `stream` without synchronizing.
// Fixes vs. original: the cudaMemcpyAsync return code is now checked, and
// totalSize is computed in size_t so B*C*Z*X*Y cannot overflow int.
extern "C" void indexPutCUDA(
    const float* x,
    const int32_t* geom_feats,
    const float* final_in,
    float* output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y,
    cudaStream_t stream)
{
    const size_t totalSize = (size_t)B * C * Z * X * Y;
    const int threadsPerBlock = 256; // 256 threads per block

    // ── Step1: copy final_in → output ───────────────────────
    cudaError_t err = cudaMemcpyAsync(
        output,
        final_in,
        sizeof(float) * totalSize,
        cudaMemcpyDeviceToDevice,
        stream
    );
    if (err != cudaSuccess) {
        printf("[indexPutCUDA] memcpy error: %s\n", cudaGetErrorString(err));
        return;
    }

    // ── Step2: scatter x → output ────
    if (N <= 0) return; // nothing to scatter
    const int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
    indexPutScatterKernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        x, geom_feats, output,
        N, C, B, Z, X, Y
    );

    // Surface launch-configuration errors immediately.
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("[indexPutCUDA] CUDA Error: %s\n", cudaGetErrorString(err));
    }
}
cpp文件
cpp
#include "IndexPutPlugin.h"
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>
using namespace nvinfer1;
using namespace nvinfer1::plugin;
namespace {
const char* PLUGIN_VERSION{"1"};
const char* PLUGIN_NAME{"IndexPut"};
} // namespace
// ── CUDA 核函数声明 ─────────────────────────────────────────
extern "C" void indexPutCUDA(
const float* x,
const int32_t* geom_feats,
const float* final_in,
float* output,
int32_t N, int32_t C,
int32_t B, int32_t Z,
int32_t X, int32_t Y,
cudaStream_t stream);
// ── 序列化工具 ───────────────────────────────────────────────
// Serialize a trivially-copyable value into the buffer and advance the cursor.
template<typename T>
static void writeBuf(char*& buf, const T& val) {
    const char* src = reinterpret_cast<const char*>(&val);
    std::memcpy(buf, src, sizeof(T));
    buf += sizeof(T);
}

// Deserialize a trivially-copyable value from the buffer and advance the cursor.
template<typename T>
static T readBuf(const char*& buf) {
    T val{};
    std::memcpy(&val, buf, sizeof(T));
    buf += sizeof(T);
    return val;
}
// ── 构造函数 ────────────────────────────────────────────────
// Default construction: every member already carries an in-class default
// initializer in IndexPutPlugin.h (empty namespace string, kFLOAT data type,
// all cached dimensions zero), so the explicit init-list was redundant.
IndexPutPlugin::IndexPutPlugin() = default;
// Deserializing constructor: restores the fields written by serialize().
// Fix: validate `length` up front — the original read 28 bytes regardless,
// which is an out-of-bounds read when the buffer is shorter than expected.
IndexPutPlugin::IndexPutPlugin(const void* data, size_t length)
{
    constexpr size_t kExpected = sizeof(DataType) + 6 * sizeof(int32_t);
    if (length < kExpected) {
        std::cerr << "[IndexPutPlugin] deserialize buffer too small: "
                  << length << " vs expected " << kExpected << "\n";
        return; // keep in-class default member values instead of reading garbage
    }
    const char* d = static_cast<const char*>(data);
    const char* const start = d;
    // Read order must match serialize(): DataType, N, C, B, Z, X, Y.
    mDataType = readBuf<DataType>(d);
    mN = readBuf<int32_t>(d);
    mC = readBuf<int32_t>(d);
    mB = readBuf<int32_t>(d);
    mZ = readBuf<int32_t>(d);
    mX = readBuf<int32_t>(d);
    mY = readBuf<int32_t>(d);
    if (static_cast<size_t>(d - start) != length) {
        std::cerr << "[IndexPutPlugin] deserialize size mismatch: "
                  << (d - start) << " vs " << length << "\n";
    }
}
IndexPutPlugin::~IndexPutPlugin() { terminate(); } // terminate() is currently a no-op (see below)
// Clone by round-tripping through the serialization path, so the copy picks
// up exactly the same persisted state as a plugin loaded from an engine.
IPluginV2DynamicExt* IndexPutPlugin::clone() const noexcept
{
    std::vector<char> blob(getSerializationSize());
    serialize(blob.data());
    auto* copy = new IndexPutPlugin(blob.data(), blob.size());
    copy->setPluginNamespace(mNamespace.c_str());
    return copy;
}
// The single output has exactly the shape of inputs[2] ("final", [B,C,Z,X,Y]);
// inputs[0] is x [N, C] and inputs[1] is geom_feats [N, 4].
DimsExprs IndexPutPlugin::getOutputDimensions(
    int32_t outputIndex,
    const DimsExprs* inputs,
    int32_t nbInputs,
    IExprBuilder& exprBuilder) noexcept
{
    (void)exprBuilder; // no expression building needed — shape is passed through
    assert(outputIndex == 0 && nbInputs == 3);
    return inputs[2];
}
// Accepted tensor layout (inputs then outputs), all kLINEAR:
//   pos 0: x          [N, C]       FLOAT
//   pos 1: geom_feats [N, 4]       INT32
//   pos 2: final      [B,C,Z,X,Y]  FLOAT
//   pos 3: output     [B,C,Z,X,Y]  FLOAT
bool IndexPutPlugin::supportsFormatCombination(
    int32_t pos,
    const PluginTensorDesc* inOut,
    int32_t nbInputs,
    int32_t nbOutputs) noexcept
{
    assert(nbInputs == 3 && nbOutputs == 1);
    if (pos < 0 || pos > 3) {
        return false;
    }
    const PluginTensorDesc& desc = inOut[pos];
    // Only geom_feats (pos 1) is INT32; everything else is FLOAT.
    const DataType expected = (pos == 1) ? DataType::kINT32 : DataType::kFLOAT;
    return desc.format == TensorFormat::kLINEAR && desc.type == expected;
}
// Caches tensor dimensions for serialization/debugging.
//   in[0] = x          [N, C]          FLOAT → N, C
//   in[1] = geom_feats [N, 4]          INT32
//   in[2] = final      [B, C, Z, X, Y] FLOAT → B, Z, X, Y
// Fix: the trailing output type/format check had an empty body in the
// original (a dead check) — it now logs a warning.
void IndexPutPlugin::configurePlugin(
    const DynamicPluginTensorDesc* in,
    int32_t nbInputs,
    const DynamicPluginTensorDesc* out,
    int32_t nbOutputs) noexcept
{
    assert(nbInputs == 3 && nbOutputs == 1);
    // Extract [N, C] from x (inputs[0]).
    const Dims& xdims = in[0].desc.dims;
    if (xdims.nbDims >= 2) {
        mN = xdims.d[0]; // number of valid points (may be -1 when dynamic)
        mC = xdims.d[1]; // channel count
    }
    // Extract [B, C, Z, X, Y] from final (inputs[2]).
    const Dims& fd = in[2].desc.dims;
    if (fd.nbDims == 5) {
        mB = fd.d[0]; // B
        mZ = fd.d[2]; // Z ≈ 1
        mX = fd.d[3]; // X ≈ 200
        mY = fd.d[4]; // Y ≈ 200
    }
    mDataType = in[0].desc.type;
    // Surface an unexpected output type/format instead of silently ignoring it.
    if (out[0].desc.type != DataType::kFLOAT ||
        out[0].desc.format != TensorFormat::kLINEAR) {
        std::cerr << "[IndexPutPlugin] unexpected output type/format in configurePlugin\n";
    }
}
// No scratch workspace is required: enqueue copies `final` straight into the
// output buffer and the scatter kernel writes into it in place.
size_t IndexPutPlugin::getWorkspaceSize(
    const PluginTensorDesc*, int32_t,
    const PluginTensorDesc*, int32_t) const noexcept
{
    return 0;
}
int32_t IndexPutPlugin::enqueue(
const PluginTensorDesc* inputDesc,
const PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* /*workspace*/,
cudaStream_t stream) noexcept
{
try {
// inputs[0] = x FLOAT [N, C]
// inputs[1] = geom_feats INT32 [N, 4]
// inputs[2] = final FLOAT [B,C,Z,X,Y]
const float* x = static_cast<const float*> (inputs[0]);
const int32_t* geom_feats = static_cast<const int32_t*>(inputs[1]);
const float* final_in = static_cast<const float*> (inputs[2]);
float* output = static_cast<float*> (outputs[0]);
// 从 x (inputDesc[0]) 读取 [N, C]
const int32_t N = inputDesc[0].dims.d[0];
const int32_t C = inputDesc[0].dims.d[1];
// 从 output (outputDesc[0]) 读取 [B, C, Z, X, Y]
const Dims& od = outputDesc[0].dims;
assert(od.nbDims == 5);
const int32_t B = od.d[0];
const int32_t Z = od.d[2];
const int32_t X = od.d[3];
const int32_t Y = od.d[4];
// 调用 CUDA kernel
indexPutCUDA(
x, geom_feats, final_in, output,
N, C, B, Z, X, Y,
stream
);
return 0;
} catch (const std::exception& e) {
std::cerr << "[IndexPutPlugin::enqueue] " << e.what() << "\n";
return -1;
}
}
// ── IPluginV2 基础 ───────────────────────────────────────────
// Plugin identity reported to the registry (must match the creator's values).
const char* IndexPutPlugin::getPluginType() const noexcept { return PLUGIN_NAME; }
const char* IndexPutPlugin::getPluginVersion() const noexcept { return PLUGIN_VERSION; }
// Single output tensor: the filled voxel volume.
int32_t IndexPutPlugin::getNbOutputs() const noexcept { return 1; }
// No resources to acquire or release — initialize/terminate are no-ops.
int32_t IndexPutPlugin::initialize() noexcept { return 0; }
void IndexPutPlugin::terminate() noexcept {}
void IndexPutPlugin::destroy() noexcept { delete this; }
void IndexPutPlugin::setPluginNamespace(const char* ns) noexcept {
    // Guard against null so mNamespace.c_str() always stays valid.
    mNamespace = (ns != nullptr) ? ns : "";
}
const char* IndexPutPlugin::getPluginNamespace() const noexcept {
    return mNamespace.c_str();
}
// ── 序列化 ───────────────────────────────────────────────────
// Serialized layout: DataType(4) followed by N, C, B, Z, X, Y as int32
// (28 bytes total).
size_t IndexPutPlugin::getSerializationSize() const noexcept {
    return sizeof(DataType) + 6 * sizeof(int32_t);
}
// Writes the fields in the exact order the deserializing constructor reads them.
void IndexPutPlugin::serialize(void* buffer) const noexcept {
    char* cursor = static_cast<char*>(buffer);
    const char* const begin = cursor;
    writeBuf(cursor, mDataType);
    const int32_t dims[6] = {mN, mC, mB, mZ, mX, mY};
    for (int32_t v : dims) {
        writeBuf(cursor, v);
    }
    assert(cursor == begin + getSerializationSize());
}
// ── IPluginV2Ext ─────────────────────────────────────────────
// The single output is always FLOAT, regardless of the INT32 geom_feats input.
DataType IndexPutPlugin::getOutputDataType(
    int32_t index,
    const nvinfer1::DataType* /*inputTypes*/,
    int32_t /*nbInputs*/) const noexcept
{
    assert(index == 0);
    return DataType::kFLOAT;
}
// ============================================================
// IndexPutPluginCreator
// ============================================================
// Static field collection: this plugin takes no creation-time attributes.
PluginFieldCollection IndexPutPluginCreator::mFC{};
std::vector<PluginField> IndexPutPluginCreator::mPluginAttributes;
IndexPutPluginCreator::IndexPutPluginCreator() {
    mPluginAttributes.clear();
    mFC.nbFields = 0;
    mFC.fields = nullptr;
}
const char* IndexPutPluginCreator::getPluginName() const noexcept { return PLUGIN_NAME; }
const char* IndexPutPluginCreator::getPluginVersion() const noexcept { return PLUGIN_VERSION; }
const PluginFieldCollection* IndexPutPluginCreator::getFieldNames() noexcept { return &mFC; }
// Creates a default plugin; `fc` is ignored because there are no attributes.
IPluginV2* IndexPutPluginCreator::createPlugin(
    const char* /*name*/,
    const PluginFieldCollection* /*fc*/) noexcept
{
    auto* p = new IndexPutPlugin();
    p->setPluginNamespace(mNamespace.c_str());
    return p;
}
// Reconstructs a plugin from the serialized blob stored inside an engine.
IPluginV2* IndexPutPluginCreator::deserializePlugin(
    const char* /*name*/,
    const void* serialData,
    size_t serialLength) noexcept
{
    auto* p = new IndexPutPlugin(serialData, serialLength);
    p->setPluginNamespace(mNamespace.c_str());
    return p;
}
void IndexPutPluginCreator::setPluginNamespace(const char* ns) noexcept {
    mNamespace = (ns != nullptr) ? ns : "";
}
const char* IndexPutPluginCreator::getPluginNamespace() const noexcept {
    return mNamespace.c_str();
}
// Registers the creator with TensorRT's global plugin registry at library load.
REGISTER_TENSORRT_PLUGIN(IndexPutPluginCreator);
h文件
cpp
#ifndef INDEX_PUT_PLUGIN_H
#define INDEX_PUT_PLUGIN_H
#include <NvInfer.h>
#include <vector>
#include <string>
#include <cuda_runtime.h>
namespace nvinfer1 {
namespace plugin {
// ============================================================
// inputs[0]: x [N, C] kFLOAT 有效特征点
// inputs[1]: geom_feats [N, 4] kINT32 体素坐标(xi,yi,zi,bi)
// inputs[2]: final [B,C,Z,X,Y] kFLOAT 零初始化体素
//
// output[0]: [B,C,Z,X,Y] kFLOAT 填充后体素
// ============================================================
// TensorRT dynamic plugin implementing the LSS voxel scatter:
//   output[bi, :, zi, xi, yi] = x[i]  for each row i of geom_feats.
class IndexPutPlugin : public IPluginV2DynamicExt {
public:
    // Default plugin; dimensions are filled in later by configurePlugin().
    IndexPutPlugin();
    // Deserializing constructor (used when loading a serialized engine).
    explicit IndexPutPlugin(const void* data, size_t length);
    ~IndexPutPlugin() override;
    // ── IPluginV2DynamicExt ──────────────────────────────────
    IPluginV2DynamicExt* clone() const noexcept override;
    // Output 0 has the same shape as inputs[2] ("final").
    DimsExprs getOutputDimensions(
        int32_t outputIndex,
        const DimsExprs* inputs,
        int32_t nbInputs,
        IExprBuilder& exprBuilder) noexcept override;
    // Accepts kLINEAR tensors only: FLOAT everywhere except INT32 geom_feats.
    bool supportsFormatCombination(
        int32_t pos,
        const PluginTensorDesc* inOut,
        int32_t nbInputs,
        int32_t nbOutputs) noexcept override;
    // Caches N, C, B, Z, X, Y from the input descriptors.
    void configurePlugin(
        const DynamicPluginTensorDesc* in,
        int32_t nbInputs,
        const DynamicPluginTensorDesc* out,
        int32_t nbOutputs) noexcept override;
    // No scratch memory needed (always returns 0).
    size_t getWorkspaceSize(
        const PluginTensorDesc* inputs,
        int32_t nbInputs,
        const PluginTensorDesc* outputs,
        int32_t nbOutputs) const noexcept override;
    // Copies `final` to the output, then scatters x rows on `stream`.
    int32_t enqueue(
        const PluginTensorDesc* inputDesc,
        const PluginTensorDesc* outputDesc,
        const void* const* inputs,
        void* const* outputs,
        void* workspace,
        cudaStream_t stream) noexcept override;
    // ── IPluginV2 ────────────────────────────────────────────
    const char* getPluginType() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    int32_t getNbOutputs() const noexcept override;
    int32_t initialize() noexcept override;
    void terminate() noexcept override;
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(const char* ns) noexcept override;
    const char* getPluginNamespace() const noexcept override;
    // ── IPluginV2Ext ─────────────────────────────────────────
    // Always reports kFLOAT for the single output.
    DataType getOutputDataType(
        int32_t index,
        const nvinfer1::DataType* inputTypes,
        int32_t nbInputs) const noexcept override;
private:
    std::string mNamespace;
    // Runtime dimensions cached by configurePlugin():
    //   geom_feats: [mN, 4]
    //   x:          [mN, mC]
    //   final:      [mB, mC, mZ, mX, mY]
    int32_t mN{0}; // number of valid points (dynamic)
    int32_t mC{0}; // channel count
    int32_t mB{0}; // batch size
    int32_t mZ{0}; // nx[2] ≈ 1
    int32_t mX{0}; // nx[0] ≈ 200
    int32_t mY{0}; // nx[1] ≈ 200
    DataType mDataType{DataType::kFLOAT};
};
// ============================================================
// Factory registered with TensorRT; creates and deserializes IndexPutPlugin.
class IndexPutPluginCreator : public IPluginCreator {
public:
    IndexPutPluginCreator();
    ~IndexPutPluginCreator() override = default;
    const char* getPluginName() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    const PluginFieldCollection* getFieldNames() noexcept override;
    // Builds a fresh plugin; `fc` is ignored (the plugin has no attributes).
    IPluginV2* createPlugin(
        const char* name,
        const PluginFieldCollection* fc) noexcept override;
    // Restores a plugin from a serialized engine blob.
    IPluginV2* deserializePlugin(
        const char* name,
        const void* serialData,
        size_t serialLength) noexcept override;
    void setPluginNamespace(const char* ns) noexcept override;
    const char* getPluginNamespace() const noexcept override;
private:
    static PluginFieldCollection mFC;                  // empty field collection
    static std::vector<PluginField> mPluginAttributes; // unused (no attributes)
    std::string mNamespace;
};
} // namespace plugin
} // namespace nvinfer1
#endif // INDEX_PUT_PLUGIN_H
测试
写好插件后,为了判定其与 Python 实现的结果一致,最好的验证方式是写一个验证脚本:先把插件的输入从 PyTorch 的推理过程中保存下来,然后构建一个只含插件这一层的 TRT 引擎,将同样的输入分别给到 PyTorch 代码和 TRT 代码,通过对比两者的输出来判断插件是否工作正确。
保存的例子
python
# np.save(f"{debug_dir}/x.npy", x.cpu().numpy())
# np.save(f"{debug_dir}/geom_feats.npy", geom_feats_int32.cpu().numpy())
# np.save(f"{debug_dir}/final.npy", final.cpu().numpy())
加载并对比插件
python
# TensorRT IndexPut 插件验证
import torch
import numpy as np
import tensorrt as trt
import os
import sys
import ctypes
from typing import Tuple, Optional
# 插件路径
PLUGIN_PATH = os.path.join(os.path.dirname(__file__), "..", "cpp", "build", "lib", "liblss_trt_plugins.so")
PLUGIN_PATH = os.path.abspath(PLUGIN_PATH)
def load_trt_plugins():
    """Load the TensorRT plugin shared library.

    ctypes.CDLL runs the library's static initializers, which execute
    REGISTER_TENSORRT_PLUGIN and register the IndexPut creator.
    """
    print(f"Loading plugin: {PLUGIN_PATH}")
    if not os.path.exists(PLUGIN_PATH):
        raise FileNotFoundError(f"Plugin not found: {PLUGIN_PATH}")
    ctypes.CDLL(PLUGIN_PATH)
    trt.init_libnvinfer_plugins(None, "")
    print("Plugin loaded successfully!")
def create_index_put_engine():
    """Build a single-layer TensorRT engine containing only the IndexPut plugin."""
    logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(logger)
    # Explicit-batch network definition.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    # pos 0: x (FLOAT) [N, C]
    # pos 1: geom_feats (INT32) [N, 4]
    # pos 2: final (FLOAT) [B, C, Z, X, Y]
    input_tensor_0 = network.add_input("x", trt.DataType.FLOAT, [-1, 64])  # x: [N, C] FLOAT
    input_tensor_1 = network.add_input("geom_feats", trt.DataType.INT32, [-1, 4])  # geom_feats: [N, 4] INT32
    input_tensor_2 = network.add_input("final", trt.DataType.FLOAT, [-1, 64, 1, 200, 200])  # final: [B, C, Z, X, Y]
    # Look up the creator registered by the shared library.
    plugin_registry = trt.get_plugin_registry()
    plugin_creator = plugin_registry.get_plugin_creator("IndexPut", "1", "")
    if plugin_creator is None:
        print("ERROR: IndexPut plugin not found in registry!")
        print("Available plugins:")
        for plugin in plugin_registry.plugin_namespace_list:
            print(f" - {plugin}")
        return None
    # Empty PluginFieldCollection — the plugin takes no attributes.
    fc = trt.PluginFieldCollection()
    # Instantiate the plugin.
    plugin = plugin_creator.create_plugin("IndexPut", fc)
    # Add the plugin layer — input order must match the plugin: x, geom_feats, final.
    inverse_layer = network.add_plugin_v2([input_tensor_0, input_tensor_1, input_tensor_2], plugin)
    # Mark the single output.
    network.mark_output(inverse_layer.get_output(0))
    # Build configuration.
    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 25  # 32MB
    # Optimization profile: N is dynamic, everything else fixed.
    profile = builder.create_optimization_profile()
    # input0: x [N, C], C fixed at 64
    profile.set_shape("x", [1, 64], [10000, 64], [100000, 64])
    # input1: geom_feats [N, 4], N is dynamic
    profile.set_shape("geom_feats", [1, 4], [10000, 4], [100000, 4])
    # input2: final [B, C, Z, X, Y], fully fixed at [1, 64, 1, 200, 200]
    profile.set_shape("final", [1, 64, 1, 200, 200], [1, 64, 1, 200, 200], [1, 64, 1, 200, 200])
    config.add_optimization_profile(profile)
    # Build the engine.
    engine = builder.build_engine(network, config)
    if engine is None:
        print("ERROR: Engine build failed!")
        return None
    print("Engine built successfully!")
    return engine
def run_trt_inference(engine, x, geom_feats, final) -> np.ndarray:
    """Run the IndexPut engine once and return the output as a numpy array.

    Bindings follow plugin order: x (FLOAT), geom_feats (INT32), final (FLOAT),
    output (FLOAT). Fix: the deprecated torch.cuda.*Tensor constructors (which
    emitted a UserWarning) are replaced by torch.from_numpy(...).cuda().
    """
    context = engine.create_execution_context()
    # Normalize host inputs to the dtypes the plugin bindings expect.
    x_input_np = x.float().cpu().numpy() if isinstance(x, torch.Tensor) else x.astype(np.float32)
    geom_feats_np = geom_feats.cpu().numpy() if isinstance(geom_feats, torch.Tensor) else geom_feats
    # The plugin requires INT32 (TensorRT custom plugins do not accept int64).
    geom_feats_np = geom_feats_np.astype(np.int32, copy=False)
    final_np = final.float().cpu().numpy() if isinstance(final, torch.Tensor) else final.astype(np.float32)
    print("x (input0)", x_input_np.shape, x_input_np.dtype)
    print("geom_feats (input1)", geom_feats_np.shape, geom_feats_np.dtype)
    print("final (input2)", final_np.shape, final_np.dtype)
    # Dynamic shapes must be set before execution.
    context.set_input_shape("x", x_input_np.shape)
    context.set_input_shape("geom_feats", geom_feats_np.shape)
    context.set_input_shape("final", final_np.shape)
    # Device buffers: x, geom_feats, final inputs plus the output buffer.
    d_input_0 = torch.from_numpy(np.ascontiguousarray(x_input_np)).cuda()       # x: FLOAT
    d_input_1 = torch.from_numpy(np.ascontiguousarray(geom_feats_np)).cuda()    # geom_feats: INT32
    d_input_2 = torch.from_numpy(np.ascontiguousarray(final_np)).cuda()         # final: FLOAT
    d_output = torch.empty(tuple(final_np.shape), dtype=torch.float32, device="cuda")
    # Execute with raw device pointers in binding order.
    context.execute_v2([d_input_0.data_ptr(), d_input_1.data_ptr(), d_input_2.data_ptr(), d_output.data_ptr()])
    # Wait for completion before reading back.
    torch.cuda.synchronize()
    return d_output.cpu().numpy()
class IndexPutFunction(torch.autograd.Function):
    """Custom op wrapping the in-place scatter `final[b, :, z, x, y] = feats`
    so ONNX export emits one `xyz.onnx.contrib::IndexPut` node."""

    @staticmethod
    def forward(ctx, x, geom_feats, final):
        # geom_feats columns: 0 = X index, 1 = Y index, 2 = Z index, 3 = batch.
        b = geom_feats[:, 3]
        z = geom_feats[:, 2]
        xi = geom_feats[:, 0]
        yi = geom_feats[:, 1]
        final[b, :, z, xi, yi] = x  # in-place scatter of per-point features
        return final

    @staticmethod
    def symbolic(g, x, geom_feats, final):
        # Emit the custom node; the put is in-place, so the output has the
        # same shape and type as `final` — mark the type explicitly.
        output = g.op("xyz.onnx.contrib::IndexPut", x, geom_feats, final)
        output.setType(final.type())
        return output
def torch_index_put(x, geom_feats, final) -> torch.Tensor:
    """PyTorch reference index-put (baseline for the TRT plugin).

    Note: the original docstring said "matrix inversion" — a copy-paste error.
    This mutates `final` in place via IndexPutFunction.forward and returns it.
    """
    return IndexPutFunction.apply(x, geom_feats, final)
def verify_with_npy(data_path: str) -> Tuple[bool, dict]:
    """
    Verify plugin correctness from saved .npy inputs.

    Args:
        data_path: directory containing x.npy, geom_feats.npy and final.npy
                   (with or without a trailing path separator)
    Returns:
        (passed, stats)
    """
    # Fix: the original used `data_path + "x.npy"`, which silently broke when
    # the caller omitted the trailing "/"; os.path.join handles both cases.
    x = np.load(os.path.join(data_path, "x.npy"))
    geom_feats = np.load(os.path.join(data_path, "geom_feats.npy"))
    final = np.load(os.path.join(data_path, "final.npy"))
    print(f"形状x : {x.shape}, dtype: {x.dtype}")
    print(f"形状geom_feats: {geom_feats.shape}, dtype: {geom_feats.dtype}")
    print(f"形状final : {final.shape}, dtype: {final.dtype}")
    # Convert to tensors with the dtypes the reference implementation expects.
    x = torch.from_numpy(x).float()
    geom_feats = torch.from_numpy(geom_feats).long()
    final = torch.from_numpy(final).float()
    return verify_with_tensor(x, geom_feats, final)
def verify_with_tensor(x, geom_feats, final) -> Tuple[bool, dict]:
    """
    Verify the TensorRT plugin against the PyTorch reference on tensor inputs.

    Returns:
        (passed, stats) where stats has element count and error metrics.
    """
    # Load the plugin library and build the single-layer engine.
    load_trt_plugins()
    engine = create_index_put_engine()
    if engine is None:
        print("ERROR: Failed to create TensorRT engine!")
        return False, {"total_elements": 0, "max_error": float('inf'), "mean_error": float('inf'), "passed": False}
    # PyTorch baseline. Fix: torch_index_put mutates `final` in place, so the
    # original passed an already-filled `final` to TensorRT — the comparison
    # could pass even if the plugin only copied its input. Run the baseline on
    # a clone so TRT sees the pristine tensor.
    pytorch_results = torch_index_put(x, geom_feats, final.clone())
    # TensorRT result on the unmodified input.
    trt_results = run_trt_inference(engine, x, geom_feats, final)
    trt_results = torch.from_numpy(trt_results)
    # Element-wise absolute error.
    diff = torch.abs(pytorch_results - trt_results)
    max_error = float(torch.max(diff))
    mean_error = float(torch.mean(diff))
    # Summary statistics.
    stats = {
        "total_elements": int(pytorch_results.numel()),
        "max_error": max_error,
        "mean_error": mean_error,
        "passed": max_error < 1e-5
    }
    return stats["passed"], stats
# Usage: python3 ./py/test_index_put_plugin.py ./data/
if __name__ == "__main__":
    print("=" * 50)
    print("TensorRT IndexPut Plugin Verification")
    print("=" * 50)
    # Fix: fail fast with a usage hint instead of an IndexError on sys.argv[1].
    if len(sys.argv) < 2:
        print("Usage: python3 test_index_put_plugin.py <data_dir>")
        sys.exit(1)
    passed, stats = verify_with_npy(sys.argv[1])
    print(f"\n Test Results:")
    print(f" Total Elements: {stats['total_elements']}")
    print(f" Max Error: {stats['max_error']:.2e}")
    print(f" Mean Error: {stats['mean_error']:.2e}")
    print(f"\n{'PASSED' if passed else 'FAILED'}")
    print("=" * 50)
结果
bash
: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
[03/28/2026-06:37:43] [TRT] [V] CUDA lazy loading is enabled.
x (input0) (8597, 64) float32
geom_feats (input1) (8597, 4) int32
final (input2) (1, 64, 1, 200, 200) float32
/workspace/current_workspace/./py/test_index_put_plugin.py:112: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:78.)
d_input_0 = torch.cuda.FloatTensor(x_input_np).cuda() # x: FLOAT
Test Results:
Total Elements: 2560000
Max Error: 0.00e+00
Mean Error: 0.00e+00
PASSED
==================================================