Lss-bev系列-2-部署插件IndexPut

Lss-bev系列-2-部署插件IndexPut

总结

在导出该项目的 ONNX 模型时,会因为存在不支持的算子而报错。这里主要针对下面这行代码编写自定义插件来完成导出。这行代码的作用是:按照 geom_feats 给出的体素索引,把 x 的每一行(即一个点的全部通道特征)写入 final 的对应位置。

python 复制代码
# Write each row x[i] into final[b, :, z, xi, yi], where (xi, yi, z, b) = geom_feats[i].
final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x

上面的写法等价于下面的循环写法(若多个点落在同一个体素上,最终保留哪个点的值是未定义的):

python 复制代码
# Explicit per-point form of the indexed assignment above.
# Columns of geom_feats are (X index, Y index, Z index, batch index).
for i, (xi, yi, z, b) in enumerate(geom_feats):
    final[b, :, z, xi, yi] = x[i]  # write all channels of point i into its voxel

自定义插件-pytorch

注意:这里我把 geom_feats 从 long(int64)类型转换成了 int32 类型,因为 TensorRT 8.x 版本的自定义插件不支持 int64 类型的输入,所以转换为 32 位整数进行计算。

python 复制代码
class IndexPutFunction(torch.autograd.Function):
    """Scatter the rows of ``x`` into the voxel grid ``final`` (in place).

    Equivalent to
    ``final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x``,
    but exported to ONNX as a single custom node ``xyz.onnx.contrib::IndexPut``
    so TensorRT can execute it with the IndexPut plugin.

    Column order of ``geom_feats`` as used below: (x index, y index, z index, batch index).
    """

    @staticmethod
    def forward(ctx, x, geom_feats, final):
        # geom_feats may have been cast to int32 for the TensorRT plugin; some
        # PyTorch versions only accept int64 index tensors, so index with long.
        idx = geom_feats.long()
        final[idx[:, 3], :, idx[:, 2], idx[:, 0], idx[:, 1]] = x
        # `final` is modified in place and returned: autograd requires that to be declared.
        ctx.mark_dirty(final)
        return final

    @staticmethod
    def symbolic(g, x, geom_feats, final):
        # One custom ONNX node; inputs keep the order the plugin expects (x, geom_feats, final).
        output = g.op("xyz.onnx.contrib::IndexPut", x, geom_feats, final)
        # The output has the same type/shape as `final`.
        output.setType(final.type())
        return output

在 PyTorch 模型中,用这个自定义 Function 替换原来无法直接导出为 ONNX 的赋值语句:

python 复制代码
# Original statement, replaced by the custom Function so ONNX export emits one IndexPut node:
# final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x
final = IndexPutFunction.apply(x, geom_feats, final)

自定义插件-c++

cu文件

cpp 复制代码
#include <cuda_runtime.h>
#include <stdio.h>

// ============================================================
//   x          = inputs[0]  FLOAT  [N, C]
//   geom_feats = inputs[1]  INT32  [N, 4]  → (xi, yi, zi, bi)
//   final      = inputs[2]  FLOAT  [B,C,Z,X,Y]
//   output     = outputs[0] FLOAT  [B,C,Z,X,Y]
// ============================================================

// Element-wise copy final_in -> output over all B*C*Z*X*Y floats, written as a
// grid-stride loop so it is correct for any launch configuration.
// x, geom_feats and N are accepted but not used by this kernel.
// Note: indexPutCUDA performs this copy with cudaMemcpyAsync and does not
// launch this kernel.
__global__ void indexPutKernel(
    const float*   x,
    const int32_t* geom_feats,
    const float*   __restrict__ final_in,
    float*         __restrict__ output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y)
{
    // 64-bit arithmetic: both the element count and the running index can
    // exceed INT32_MAX for large grids.
    const size_t totalSize = static_cast<size_t>(B) * C * Z * X * Y;
    const size_t stride    = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < totalSize;
         i += stride) {
        output[i] = final_in[i];
    }
}

// Scatter kernel: one thread per point idx in [0, N).
//   geom_feats[idx] = (xi, yi, zi, bi)  (row-major [N, 4], INT32)
//   output[bi, c, zi, xi, yi] = x[idx, c]  for every channel c in [0, C)
// Points whose voxel coordinates fall outside [0,B)x[0,Z)x[0,X)x[0,Y) are skipped.
// Plain stores are used (no atomics): if several points map to the same voxel,
// which one's features end up in output is unspecified.
// Expects a 1-D launch with at least N threads in total.
__global__ void indexPutScatterKernel(
    const float*    __restrict__ x,
    const int32_t*  __restrict__ geom_feats,
    float*          __restrict__ output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y)
{
    // Global point index (64-bit so idx * 4 / idx * C cannot overflow).
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= N) return;

    // Voxel coordinates of this point: geom_feats[idx] = (xi, yi, zi, bi).
    const int32_t* coord = geom_feats + idx * 4;
    const int32_t xi = coord[0];
    const int32_t yi = coord[1];
    const int32_t zi = coord[2];
    const int32_t bi = coord[3];

    // Bounds check: drop points outside the grid.
    if (xi < 0 || xi >= X ||
        yi < 0 || yi >= Y ||
        zi < 0 || zi >= Z ||
        bi < 0 || bi >= B) {
        return;
    }

    // Offset of output[bi, 0, zi, xi, yi]; consecutive channels are `plane`
    // elements apart. All arithmetic is 64-bit because B*C*Z*X*Y may exceed INT32_MAX.
    const int64_t plane = static_cast<int64_t>(Z) * X * Y;
    int64_t out_idx = static_cast<int64_t>(bi) * C * plane
                    + static_cast<int64_t>(zi) * X * Y
                    + static_cast<int64_t>(xi) * Y
                    + yi;

    // Copy all channels of this point (plain assignment, not accumulation).
    const float* src = x + idx * C;
    for (int32_t c = 0; c < C; ++c, out_idx += plane) {
        output[out_idx] = src[c];
    }
}

// ── Host entry point ─────────────────────────────────────────
// Enqueues two steps on `stream`:
//   1. output = final_in   (device-to-device copy; skipped when the buffers alias)
//   2. output[bi, :, zi, xi, yi] = x[i, :] for every in-range point i
// All pointers are device pointers. Errors are printed to stdout; the function
// returns nothing, so callers cannot observe them.
extern "C" void indexPutCUDA(
    const float*    x,
    const int32_t*  geom_feats,
    const float*    final_in,
    float*          output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y,
    cudaStream_t    stream)
{
    // size_t element count: B*C*Z*X*Y (and its byte size) can exceed INT32_MAX.
    const size_t totalSize = static_cast<size_t>(B) * C * Z * X * Y;
    const int threadsPerBlock = 256;  // 256 threads per block

    // ── Step1: copy final_in → output ───────────────────────
    // cudaMemcpy on overlapping ranges is undefined, so skip an in-place call.
    if (totalSize > 0 && output != final_in) {
        const cudaError_t copyErr = cudaMemcpyAsync(
            output,
            final_in,
            sizeof(float) * totalSize,
            cudaMemcpyDeviceToDevice,
            stream
        );
        if (copyErr != cudaSuccess) {
            printf("[indexPutCUDA] CUDA Error: %s\n", cudaGetErrorString(copyErr));
            return;  // do not scatter into a buffer that was never initialised
        }
    }

    // ── Step2: Scatter x → output ────
    // Nothing to write with no points, no channels or an empty grid.
    if (N <= 0 || C <= 0 || totalSize == 0) return;

    // ceil(N / threadsPerBlock) without the N + threadsPerBlock - 1 overflow.
    const int blocksPerGrid = (N - 1) / threadsPerBlock + 1;
    indexPutScatterKernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        x, geom_feats, output,
        N, C, B, Z, X, Y
    );

    // Catch launch-configuration errors (execution errors surface at the next sync).
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("[indexPutCUDA] CUDA Error: %s\n", cudaGetErrorString(err));
    }
}

cpp文件

cpp 复制代码
#include "IndexPutPlugin.h"
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>

using namespace nvinfer1;
using namespace nvinfer1::plugin;

namespace {
// Name/version this plugin registers under. They match the op type emitted by
// the ONNX export ("IndexPut") and the (name, version) used for registry lookup.
constexpr char PLUGIN_VERSION[] = "1";
constexpr char PLUGIN_NAME[]    = "IndexPut";
} // namespace

// ── CUDA launcher declaration ───────────────────────────────
// Defined in the .cu file. This declaration must match that definition exactly
// (extern "C" linkage, parameter types and order). The launcher copies final_in
// into output and then scatters x into output, all enqueued on `stream`.
extern "C" void indexPutCUDA(
    const float*    x,
    const int32_t*  geom_feats,
    const float*    final_in,
    float*          output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y,
    cudaStream_t stream);

// ── Serialization helpers ───────────────────────────────────
// Append the raw bytes of `value` at `cursor`, then move the cursor past them.
template<typename T>
static void writeBuf(char*& cursor, const T& value) {
    cursor = static_cast<char*>(std::memcpy(cursor, &value, sizeof value)) + sizeof value;
}
// Decode a T from the raw bytes at `cursor`, then move the cursor past them.
template<typename T>
static T readBuf(const char*& cursor) {
    T result;
    std::memcpy(&result, cursor, sizeof result);
    cursor += sizeof result;
    return result;
}

// ── 构造函数 ────────────────────────────────────────────────

// Default state, used by IndexPutPluginCreator::createPlugin: empty namespace,
// FLOAT data type, and all cached dimensions zero until configurePlugin runs.
// Initializers are listed in member-declaration order.
IndexPutPlugin::IndexPutPlugin()
    : mNamespace()
    , mN(0)
    , mC(0)
    , mB(0)
    , mZ(0)
    , mX(0)
    , mY(0)
    , mDataType(DataType::kFLOAT)
{
}

// Rebuild a plugin from the bytes written by serialize().
// Layout: DataType, then int32 N, C, B, Z, X, Y.
// A buffer shorter than that layout is rejected (logged) and the members keep
// their default values, instead of reading past the end of the buffer.
IndexPutPlugin::IndexPutPlugin(const void* data, size_t length)
{
    const size_t expected = getSerializationSize();
    if (data == nullptr || length < expected) {
        std::cerr << "[IndexPutPlugin] deserialize buffer too small: "
                  << length << " < " << expected << "\n";
        return;
    }

    const char* d = static_cast<const char*>(data);
    const char* const start = d;

    mDataType = readBuf<DataType>(d);
    mN  = readBuf<int32_t>(d);
    mC  = readBuf<int32_t>(d);
    mB  = readBuf<int32_t>(d);
    mZ  = readBuf<int32_t>(d);
    mX  = readBuf<int32_t>(d);
    mY  = readBuf<int32_t>(d);

    // Trailing bytes are tolerated but reported: they suggest a format mismatch.
    if (static_cast<size_t>(d - start) != length) {
        std::cerr << "[IndexPutPlugin] deserialize size mismatch: "
                  << (d - start) << " vs " << length << "\n";
    }
}

IndexPutPlugin::~IndexPutPlugin()
{
    // Release whatever initialize() acquired (terminate() is currently a no-op).
    terminate();
}

// Deep copy via a serialize/deserialize round trip, so the copy carries exactly
// the serialized state; the namespace is copied separately.
IPluginV2DynamicExt* IndexPutPlugin::clone() const noexcept
{
    std::vector<char> blob(getSerializationSize());
    serialize(blob.data());

    IndexPutPlugin* copy = new IndexPutPlugin(blob.data(), blob.size());
    copy->setPluginNamespace(getPluginNamespace());
    return copy;
}

// Single output with the same shape as `final` (inputs[2], [B, C, Z, X, Y]):
// the op writes into a copy of final. Inputs: x [N, C], geom_feats [N, 4], final.
// No shape arithmetic is needed, so exprBuilder is unused.
DimsExprs IndexPutPlugin::getOutputDimensions(
    int32_t outputIndex,
    const DimsExprs* inputs,
    int32_t nbInputs,
    IExprBuilder& exprBuilder) noexcept
{
    assert(outputIndex == 0 && nbInputs == 3);
    const DimsExprs& finalDims = inputs[2];
    return finalDims;
}

// Accepted I/O: every tensor LINEAR;
//   pos 0 x [N, C] FLOAT, pos 1 geom_feats [N, 4] INT32,
//   pos 2 final [B,C,Z,X,Y] FLOAT, pos 3 output [B,C,Z,X,Y] FLOAT.
bool IndexPutPlugin::supportsFormatCombination(
    int32_t pos,
    const PluginTensorDesc* inOut,
    int32_t nbInputs,
    int32_t nbOutputs) noexcept
{
    assert(nbInputs == 3 && nbOutputs == 1);
    const PluginTensorDesc& tensor = inOut[pos];

    const bool isLinear = tensor.format == TensorFormat::kLINEAR;
    if (!isLinear) {
        return false;
    }
    if (pos == 1) {
        return tensor.type == DataType::kINT32;
    }
    const bool isFloatSlot = (pos == 0 || pos == 2 || pos == 3);
    return isFloatSlot && tensor.type == DataType::kFLOAT;
}

// Cache the build-time shapes and type. The cached values are only used for
// serialization; enqueue() reads the actual runtime shapes from its descriptors.
//   in[0] = x          [N, C]       FLOAT  -> N, C
//   in[1] = geom_feats [N, 4]       INT32
//   in[2] = final      [B,C,Z,X,Y]  FLOAT  -> B, Z, X, Y
void IndexPutPlugin::configurePlugin(
    const DynamicPluginTensorDesc* in,
    int32_t nbInputs,
    const DynamicPluginTensorDesc* out,
    int32_t nbOutputs) noexcept
{
    assert(nbInputs == 3 && nbOutputs == 1);

    const Dims& xdims = in[0].desc.dims;
    if (xdims.nbDims >= 2) {
        mN = xdims.d[0];   // number of points; may be -1 (dynamic) at build time
        mC = xdims.d[1];   // channels
    }

    const Dims& fd = in[2].desc.dims;
    if (fd.nbDims == 5) {
        mB = fd.d[0];
        mZ = fd.d[2];
        mX = fd.d[3];
        mY = fd.d[4];
    }

    mDataType = in[0].desc.type;

    // supportsFormatCombination() only admits a FLOAT/LINEAR output; confirm
    // that in debug builds.
    assert(out[0].desc.type == DataType::kFLOAT &&
           out[0].desc.format == TensorFormat::kLINEAR);
    (void)out;  // only referenced by the assert above
}

// The copy and scatter operate directly on the input/output buffers, so no
// scratch workspace is requested.
size_t IndexPutPlugin::getWorkspaceSize(
    const PluginTensorDesc* /*inputs*/, int32_t /*nbInputs*/,
    const PluginTensorDesc* /*outputs*/, int32_t /*nbOutputs*/) const noexcept
{
    constexpr size_t kNoWorkspace = 0;
    return kNoWorkspace;
}

int32_t IndexPutPlugin::enqueue(
    const PluginTensorDesc* inputDesc,
    const PluginTensorDesc* outputDesc,
    const void* const* inputs,
    void* const* outputs,
    void* /*workspace*/,
    cudaStream_t stream) noexcept
{
    try {
        //   inputs[0] = x           FLOAT  [N, C]
        //   inputs[1] = geom_feats  INT32  [N, 4]
        //   inputs[2] = final       FLOAT  [B,C,Z,X,Y]
        
        const float*   x          = static_cast<const float*>  (inputs[0]);
        const int32_t* geom_feats = static_cast<const int32_t*>(inputs[1]);
        const float*   final_in   = static_cast<const float*>  (inputs[2]);
        float*         output     = static_cast<float*>        (outputs[0]);

        // 从 x (inputDesc[0]) 读取 [N, C]
        const int32_t N = inputDesc[0].dims.d[0];
        const int32_t C = inputDesc[0].dims.d[1];

        // 从 output (outputDesc[0]) 读取 [B, C, Z, X, Y]
        const Dims& od = outputDesc[0].dims;
        assert(od.nbDims == 5);
        const int32_t B = od.d[0];
        const int32_t Z = od.d[2];
        const int32_t X = od.d[3];
        const int32_t Y = od.d[4];

        // 调用 CUDA kernel
        indexPutCUDA(
            x, geom_feats, final_in, output,
            N, C, B, Z, X, Y,
            stream
        );

        return 0;
    } catch (const std::exception& e) {
        std::cerr << "[IndexPutPlugin::enqueue] " << e.what() << "\n";
        return -1;
    }
}

// ── IPluginV2 基础 ───────────────────────────────────────────

// Plugin type: the same name the creator reports in getPluginName().
const char* IndexPutPlugin::getPluginType() const noexcept
{
    return PLUGIN_NAME;
}

const char* IndexPutPlugin::getPluginVersion() const noexcept
{
    return PLUGIN_VERSION;
}

// A single output: the filled voxel grid.
int32_t IndexPutPlugin::getNbOutputs() const noexcept
{
    return 1;
}

// No per-instance resources to allocate.
int32_t IndexPutPlugin::initialize() noexcept
{
    return 0;
}

// No per-instance resources to release.
void IndexPutPlugin::terminate() noexcept
{
}

void IndexPutPlugin::destroy() noexcept
{
    delete this;
}

// Store the namespace; a null pointer is treated as the empty namespace.
void IndexPutPlugin::setPluginNamespace(const char* ns) noexcept {
    if (ns == nullptr) {
        mNamespace.clear();
    } else {
        mNamespace = ns;
    }
}
// Pointer remains valid until the namespace is changed or the plugin destroyed.
const char* IndexPutPlugin::getPluginNamespace() const noexcept {
    const std::string& ns = mNamespace;
    return ns.c_str();
}

// ── 序列化 ───────────────────────────────────────────────────

size_t IndexPutPlugin::getSerializationSize() const noexcept {
    // Must agree with serialize(): one DataType followed by the six int32
    // dimensions N, C, B, Z, X, Y.
    constexpr size_t kNumDims = 6;
    return sizeof(DataType) + kNumDims * sizeof(int32_t);
}

void IndexPutPlugin::serialize(void* buffer) const noexcept {
    // Byte layout (mirrored by the deserializing constructor and by
    // getSerializationSize()): DataType, then int32 N, C, B, Z, X, Y.
    char* d = static_cast<char*>(buffer);
    const char* const start = d;

    writeBuf(d, mDataType);
    const int32_t dims[] = {mN, mC, mB, mZ, mX, mY};
    for (const int32_t v : dims) {
        writeBuf(d, v);
    }

    assert(d == start + getSerializationSize());
}

// ── IPluginV2Ext ─────────────────────────────────────────────

// The only output is the FLOAT voxel grid; the INT32 geom_feats input does
// not influence the output type.
DataType IndexPutPlugin::getOutputDataType(
    int32_t index,
    const nvinfer1::DataType* /*inputTypes*/,
    int32_t /*nbInputs*/) const noexcept
{
    assert(index == 0);
    constexpr DataType kOutputType = DataType::kFLOAT;
    return kOutputType;
}

// ============================================================
// IndexPutPluginCreator
// ============================================================

// Storage for the creator's static attribute description. IndexPut takes no
// attributes, so both stay empty (the constructor resets them).
PluginFieldCollection         IndexPutPluginCreator::mFC{};
std::vector<PluginField>      IndexPutPluginCreator::mPluginAttributes;

// The plugin has no creation attributes, so publish an empty field collection.
IndexPutPluginCreator::IndexPutPluginCreator() {
    mFC.fields   = nullptr;
    mFC.nbFields = 0;
    mPluginAttributes.clear();
}

// Name/version used for registry lookup; they match IndexPutPlugin's type/version.
const char* IndexPutPluginCreator::getPluginName() const noexcept
{
    return PLUGIN_NAME;
}

const char* IndexPutPluginCreator::getPluginVersion() const noexcept
{
    return PLUGIN_VERSION;
}

// Empty collection: the plugin accepts no fields.
const PluginFieldCollection* IndexPutPluginCreator::getFieldNames() noexcept
{
    return &mFC;
}

// No attributes to parse: every new instance starts in the default state and
// inherits this creator's namespace.
IPluginV2* IndexPutPluginCreator::createPlugin(
    const char* /*name*/,
    const PluginFieldCollection* /*fc*/) noexcept
{
    IndexPutPlugin* plugin = new IndexPutPlugin();
    plugin->setPluginNamespace(getPluginNamespace());
    return plugin;
}

// Rebuild a plugin from bytes produced by IndexPutPlugin::serialize() and tag it
// with this creator's namespace.
IPluginV2* IndexPutPluginCreator::deserializePlugin(
    const char* /*name*/,
    const void* serialData,
    size_t serialLength) noexcept
{
    IndexPutPlugin* plugin = new IndexPutPlugin(serialData, serialLength);
    plugin->setPluginNamespace(getPluginNamespace());
    return plugin;
}

// Store the namespace; a null pointer is treated as the empty namespace.
void IndexPutPluginCreator::setPluginNamespace(const char* ns) noexcept {
    if (ns == nullptr) {
        mNamespace.clear();
    } else {
        mNamespace = ns;
    }
}
const char* IndexPutPluginCreator::getPluginNamespace() const noexcept {
    const std::string& ns = mNamespace;
    return ns.c_str();
}

// Register the creator with TensorRT's plugin registry when this library is
// loaded, so "IndexPut" version "1" can be looked up by name (as the Python
// test does after ctypes-loading the .so).
REGISTER_TENSORRT_PLUGIN(IndexPutPluginCreator);

h文件

cpp 复制代码
#ifndef INDEX_PUT_PLUGIN_H
#define INDEX_PUT_PLUGIN_H

#include <NvInfer.h>
#include <vector>
#include <string>
#include <cuda_runtime.h>

namespace nvinfer1 {
namespace plugin {

// ============================================================
//   inputs[0]: x          [N, C]        kFLOAT  有效特征点
//   inputs[1]: geom_feats [N, 4]        kINT32  体素坐标(xi,yi,zi,bi)
//   inputs[2]: final      [B,C,Z,X,Y]   kFLOAT  零初始化体素
//
//   output[0]:            [B,C,Z,X,Y]   kFLOAT  填充后体素
// ============================================================

class IndexPutPlugin : public IPluginV2DynamicExt {
public:
    // ---- Construction -------------------------------------------------
    IndexPutPlugin();                                          // default state
    explicit IndexPutPlugin(const void* data, size_t length);  // from serialize() bytes
    ~IndexPutPlugin() override;

    // ---- IPluginV2DynamicExt ------------------------------------------
    IPluginV2DynamicExt* clone() const noexcept override;

    // Output shape equals the shape of inputs[2] (final).
    DimsExprs getOutputDimensions(int32_t outputIndex,
                                  const DimsExprs* inputs,
                                  int32_t nbInputs,
                                  IExprBuilder& exprBuilder) noexcept override;

    bool supportsFormatCombination(int32_t pos,
                                   const PluginTensorDesc* inOut,
                                   int32_t nbInputs,
                                   int32_t nbOutputs) noexcept override;

    void configurePlugin(const DynamicPluginTensorDesc* in,
                         int32_t nbInputs,
                         const DynamicPluginTensorDesc* out,
                         int32_t nbOutputs) noexcept override;

    size_t getWorkspaceSize(const PluginTensorDesc* inputs,
                            int32_t nbInputs,
                            const PluginTensorDesc* outputs,
                            int32_t nbOutputs) const noexcept override;

    int32_t enqueue(const PluginTensorDesc* inputDesc,
                    const PluginTensorDesc* outputDesc,
                    const void* const* inputs,
                    void* const* outputs,
                    void* workspace,
                    cudaStream_t stream) noexcept override;

    // ---- IPluginV2 ----------------------------------------------------
    const char* getPluginType() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    int32_t     getNbOutputs() const noexcept override;
    int32_t     initialize() noexcept override;
    void        terminate() noexcept override;
    size_t      getSerializationSize() const noexcept override;
    void        serialize(void* buffer) const noexcept override;
    void        destroy() noexcept override;
    void        setPluginNamespace(const char* ns) noexcept override;
    const char* getPluginNamespace() const noexcept override;

    // ---- IPluginV2Ext -------------------------------------------------
    DataType getOutputDataType(int32_t index,
                               const nvinfer1::DataType* inputTypes,
                               int32_t nbInputs) const noexcept override;

private:
    std::string mNamespace;

    // Shapes cached by configurePlugin() and written by serialize():
    //   geom_feats [mN, 4], x [mN, mC], final [mB, mC, mZ, mX, mY]
    int32_t mN{0};   // number of points (dynamic; may be -1 at build time)
    int32_t mC{0};   // channels
    int32_t mB{0};   // batch size
    int32_t mZ{0};   // grid depth (nx[2] in the original notes)
    int32_t mX{0};   // grid size along X (nx[0])
    int32_t mY{0};   // grid size along Y (nx[1])

    DataType mDataType{DataType::kFLOAT};
};

// ============================================================
class IndexPutPluginCreator : public IPluginCreator {
public:
    IndexPutPluginCreator();
    ~IndexPutPluginCreator() override = default;

    // Identification used by the plugin registry.
    const char* getPluginName() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    const PluginFieldCollection* getFieldNames() noexcept override;

    // Factories: a fresh instance, or one rebuilt from serialized bytes.
    IPluginV2* createPlugin(const char* name,
                            const PluginFieldCollection* fc) noexcept override;
    IPluginV2* deserializePlugin(const char* name,
                                 const void* serialData,
                                 size_t serialLength) noexcept override;

    void        setPluginNamespace(const char* ns) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    static PluginFieldCollection    mFC;               // empty: no attributes
    static std::vector<PluginField> mPluginAttributes;
    std::string mNamespace;
};

} // namespace plugin
} // namespace nvinfer1

#endif // INDEX_PUT_PLUGIN_H

测试

写好插件后,为了确认它与 Python 实现的结果一致,最好的办法是写一个验证脚本:从 PyTorch 推理过程中把该层的输入保存下来,然后构建一个只包含该插件层的 TRT 引擎,把同样的输入分别送入 PyTorch 代码和 TRT 引擎,通过对比两者的输出来判断插件是否正确工作。

保存的例子

python 复制代码
# np.save(f"{debug_dir}/x.npy", x.cpu().numpy())
# np.save(f"{debug_dir}/geom_feats.npy", geom_feats_int32.cpu().numpy())
# np.save(f"{debug_dir}/final.npy", final.cpu().numpy())

加载并对比插件

python 复制代码
# TensorRT IndexPut 插件验证
import torch
import numpy as np
import tensorrt as trt
import os
import sys
import ctypes
from typing import Tuple, Optional

# 插件路径
# Absolute path of the compiled plugin library, resolved relative to this script.
PLUGIN_PATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "cpp", "build", "lib", "liblss_trt_plugins.so")
)


def load_trt_plugins():
    """Load the plugin shared library, then initialise TensorRT's built-in plugins.

    Loading the library with ctypes runs its static initializers (the plugin
    registration); the handle itself is not needed afterwards.

    Raises:
        FileNotFoundError: if PLUGIN_PATH does not exist.
    """
    print(f"Loading plugin: {PLUGIN_PATH}")
    if not os.path.exists(PLUGIN_PATH):
        raise FileNotFoundError(f"Plugin not found: {PLUGIN_PATH}")

    ctypes.CDLL(PLUGIN_PATH)
    trt.init_libnvinfer_plugins(None, "")
    print("Plugin loaded successfully!")


def create_index_put_engine():
    """Build a TensorRT engine containing only the IndexPut plugin layer.

    Network inputs (order must match what the plugin expects):
        x          FLOAT [N, 64]               N dynamic, 1 .. 100000
        geom_feats INT32 [N, 4]                rows are (xi, yi, zi, bi)
        final      FLOAT [1, 64, 1, 200, 200]  fixed by the optimization profile

    Returns:
        The built engine, or None if the plugin is missing, cannot be
        created, or the build fails.
    """
    logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(logger)

    # Explicit-batch network (needed for dynamic shapes).
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    # pos 0: x (FLOAT) [N, C]
    # pos 1: geom_feats (INT32) [N, 4]
    # pos 2: final (FLOAT) [B, C, Z, X, Y]
    input_tensor_0 = network.add_input("x", trt.DataType.FLOAT, [-1, 64])
    input_tensor_1 = network.add_input("geom_feats", trt.DataType.INT32, [-1, 4])
    input_tensor_2 = network.add_input("final", trt.DataType.FLOAT, [-1, 64, 1, 200, 200])

    # Look up the creator registered by the plugin library.
    plugin_registry = trt.get_plugin_registry()
    plugin_creator = plugin_registry.get_plugin_creator("IndexPut", "1", "")
    if plugin_creator is None:
        print("ERROR: IndexPut plugin not found in registry!")
        print("Available plugins:")
        # The registry exposes its registered creators via plugin_creator_list.
        for creator in plugin_registry.plugin_creator_list:
            print(f"  - {creator.name} (version {creator.plugin_version}, "
                  f"namespace '{creator.plugin_namespace}')")
        return None

    # The plugin takes no attributes, so an empty field collection suffices.
    fc = trt.PluginFieldCollection()
    plugin = plugin_creator.create_plugin("IndexPut", fc)
    if plugin is None:
        print("ERROR: IndexPut plugin creation failed!")
        return None

    # Input order must be x, geom_feats, final.
    index_put_layer = network.add_plugin_v2([input_tensor_0, input_tensor_1, input_tensor_2], plugin)
    network.mark_output(index_put_layer.get_output(0))

    config = builder.create_builder_config()
    workspace_bytes = 1 << 25  # 32 MiB
    if hasattr(config, "set_memory_pool_limit"):
        # Newer TensorRT: max_workspace_size is deprecated in favour of memory pools.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes

    # Optimization profile: N is dynamic, C = 64, final is fully fixed.
    profile = builder.create_optimization_profile()
    profile.set_shape("x", [1, 64], [10000, 64], [100000, 64])
    profile.set_shape("geom_feats", [1, 4], [10000, 4], [100000, 4])
    profile.set_shape("final", [1, 64, 1, 200, 200], [1, 64, 1, 200, 200], [1, 64, 1, 200, 200])
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)
    if engine is None:
        print("ERROR: Engine build failed!")
        return None

    print("Engine built successfully!")
    return engine


def run_trt_inference(engine, x, geom_feats, final) -> np.ndarray:
    """Run the single-layer IndexPut engine and return its output on the host.

    Args:
        engine: engine from create_index_put_engine (bindings assumed to be
            x, geom_feats, final, output in that order).
        x: [N, C] features (torch.Tensor or np.ndarray), sent as float32.
        geom_feats: [N, 4] voxel indices, sent as int32.
        final: [B, C, Z, X, Y] grid, sent as float32.

    Returns:
        float32 np.ndarray with the shape of ``final``.

    Raises:
        RuntimeError: if TensorRT reports that execution failed.
    """
    context = engine.create_execution_context()

    # Host inputs in plugin order: x (FLOAT), geom_feats (INT32), final (FLOAT).
    x_input_np = x.float().cpu().numpy() if isinstance(x, torch.Tensor) else x.astype(np.float32)
    geom_feats_np = geom_feats.long().cpu().numpy() if isinstance(geom_feats, torch.Tensor) else geom_feats.astype(np.int32)
    # The plugin expects int32 indices.
    if geom_feats_np.dtype == np.int64:
        geom_feats_np = geom_feats_np.astype(np.int32)
    final_np = final.float().cpu().numpy() if isinstance(final, torch.Tensor) else final.astype(np.float32)

    print("x (input0)", x_input_np.shape, x_input_np.dtype)
    print("geom_feats (input1)", geom_feats_np.shape, geom_feats_np.dtype)
    print("final (input2)", final_np.shape, final_np.dtype)

    # Dynamic shapes must be set before execution.
    context.set_input_shape("x", x_input_np.shape)
    context.set_input_shape("geom_feats", geom_feats_np.shape)
    context.set_input_shape("final", final_np.shape)

    # TensorRT consumes raw device pointers, so every buffer must be contiguous.
    d_input_0 = torch.from_numpy(np.ascontiguousarray(x_input_np)).cuda()
    d_input_1 = torch.from_numpy(np.ascontiguousarray(geom_feats_np)).cuda()
    d_input_2 = torch.from_numpy(np.ascontiguousarray(final_np)).cuda()
    d_output = torch.empty(final_np.shape, dtype=torch.float32, device="cuda")

    ok = context.execute_v2([d_input_0.data_ptr(), d_input_1.data_ptr(),
                             d_input_2.data_ptr(), d_output.data_ptr()])
    torch.cuda.synchronize()
    if not ok:
        raise RuntimeError("TensorRT execute_v2 failed")
    return d_output.cpu().numpy()


class IndexPutFunction(torch.autograd.Function):
    """Scatter the rows of ``x`` into ``final`` in place (PyTorch reference).

    ``final[b, :, z, xi, yi] = x[i]`` with ``(xi, yi, z, b) = geom_feats[i]``;
    exported to ONNX as the custom node ``xyz.onnx.contrib::IndexPut``.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, final):
        # geom_feats may be int32 (as fed to TensorRT); index with int64, which
        # every PyTorch version accepts.
        idx = geom_feats.long()
        final[idx[:, 3], :, idx[:, 2], idx[:, 0], idx[:, 1]] = x
        # `final` is modified in place and returned: declare it to autograd.
        ctx.mark_dirty(final)
        return final

    @staticmethod
    def symbolic(g, x, geom_feats, final):
        # ONNX symbolic: one custom node; the output has the same shape as
        # `final` (in-place semantics).
        output = g.op("xyz.onnx.contrib::IndexPut", x, geom_feats, final)
        # Give the output the same type as final.
        output.setType(final.type())
        return output
def torch_index_put(x, geom_feats, final) -> torch.Tensor:
    """PyTorch IndexPut scatter (reference result for the TensorRT plugin).

    Note: ``final`` is modified in place and also returned; pass a copy if
    the original grid is still needed.
    """
    return IndexPutFunction.apply(x, geom_feats, final)


def verify_with_npy(data_path: str) -> Tuple[bool, dict]:
    """
    Verify the plugin using tensors dumped to .npy files.

    Args:
        data_path: directory holding x.npy, geom_feats.npy and final.npy
            (trailing separator optional), or a filename prefix to which
            those three names are appended.

    Returns:
        (passed, statistics dict) as returned by verify_with_tensor.
    """
    def _npy(name: str) -> str:
        # os.path.join makes "./data" and "./data/" equivalent; plain
        # concatenation is kept for prefix-style paths.
        return os.path.join(data_path, name) if os.path.isdir(data_path) else data_path + name

    x = np.load(_npy("x.npy"))
    geom_feats = np.load(_npy("geom_feats.npy"))
    final = np.load(_npy("final.npy"))
    print(f"形状x         : {x.shape}, dtype: {x.dtype}")
    print(f"形状geom_feats: {geom_feats.shape}, dtype: {geom_feats.dtype}")
    print(f"形状final     : {final.shape}, dtype: {final.dtype}")

    # To tensors (int64 indices for the PyTorch reference).
    x = torch.from_numpy(x).float()
    geom_feats = torch.from_numpy(geom_feats).long()
    final = torch.from_numpy(final).float()

    return verify_with_tensor(x, geom_feats, final)


def verify_with_tensor(x, geom_feats, final) -> Tuple[bool, dict]:
    """
    Verify the plugin against the PyTorch reference using in-memory tensors.

    Args:
        x: [N, C] float tensor.
        geom_feats: [N, 4] index tensor, rows (xi, yi, zi, bi).
        final: [B, C, Z, X, Y] float tensor; left unmodified by this function.

    Returns:
        (passed, stats) where stats holds total_elements, max_error,
        mean_error and passed (max_error < 1e-5).
    """
    load_trt_plugins()
    engine = create_index_put_engine()

    if engine is None:
        print("ERROR: Failed to create TensorRT engine!")
        return False, {"total_elements": 0, "max_error": float('inf'), "mean_error": float('inf'), "passed": False}

    # PyTorch reference. torch_index_put writes into its `final` argument in
    # place, so hand it a copy: TensorRT must receive the *original* grid,
    # otherwise the comparison passes even if the plugin's scatter does nothing.
    pytorch_results = torch_index_put(x, geom_feats, final.clone())

    # TensorRT result on the untouched `final`.
    trt_results = run_trt_inference(engine, x, geom_feats, final)
    trt_results = torch.from_numpy(trt_results)

    # Element-wise absolute error.
    diff = torch.abs(pytorch_results - trt_results)
    max_error = float(torch.max(diff))
    mean_error = float(torch.mean(diff))

    stats = {
        "total_elements": int(pytorch_results.numel()),
        "max_error": max_error,
        "mean_error": mean_error,
        "passed": max_error < 1e-5
    }

    return stats["passed"], stats


#python3 ./py/test_index_put_plugin.py ./data/

if __name__ == "__main__":
    # Usage: python3 ./py/test_index_put_plugin.py ./data/
    if len(sys.argv) < 2:
        print(f"Usage: python3 {sys.argv[0]} <data_dir>")
        sys.exit(2)

    print("=" * 50)
    print("TensorRT IndexPut Plugin Verification")
    print("=" * 50)

    passed, stats = verify_with_npy(sys.argv[1])

    print(f"\n Test Results:")
    print(f"   Total Elements: {stats['total_elements']}")
    print(f"   Max Error: {stats['max_error']:.2e}")
    print(f"   Mean Error: {stats['mean_error']:.2e}")
    print(f"\n{'PASSED' if passed else 'FAILED'}")
    print("=" * 50)

    # Non-zero exit status on failure so the check can gate scripts / CI.
    sys.exit(0 if passed else 1)

结果

bash 复制代码
: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
[03/28/2026-06:37:43] [TRT] [V] CUDA lazy loading is enabled.
x (input0) (8597, 64) float32
geom_feats (input1) (8597, 4) int32
final (input2) (1, 64, 1, 200, 200) float32
/workspace/current_workspace/./py/test_index_put_plugin.py:112: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:78.)
  d_input_0 = torch.cuda.FloatTensor(x_input_np).cuda()      # x: FLOAT

 Test Results:
   Total Elements: 2560000
   Max Error: 0.00e+00
   Mean Error: 0.00e+00

PASSED
==================================================
相关推荐
程序阿北3 小时前
飞书 CLI 昨天开源,我用 Claude Code 打通了公众号写作全流程
经验分享·ai·飞书
ai超级个体3 小时前
别再吹牛了,100% Vibe Coding 存在无法自洽的逻辑漏洞!
前端·ai·ai编程·vibe coding
tzy2333 小时前
Skill 为什么“淘汰”了 MCP?
ai·agent·function call·skill·mcp
tkevinjd3 小时前
hello-agents-chapter1-初识智能体
人工智能·ai·agent
亓才孓4 小时前
【提示词五要素】
python·ai·prompt
bingyu98754 小时前
OpenClaw 在 WSL 中的完整安装与配置指南
ai·openclaw
VIP_CQCRE8 小时前
Flux 图像生成 API 集成指南
ai
会飞的大可8 小时前
企业级文档自动化处理实战:合同/财报/标书智能解析系统搭建指南
ai
JavaGuide9 小时前
万字拆解 LLM 运行机制:Token、上下文与采样参数
ai·llm·prompt·ai编程·token