Lss-bev系列-2-部署插件IndexPut

Lss-bev系列-2-部署插件IndexPut

总结

在导出该项目的 ONNX 时,会出现算子不支持的报错。这里主要分析下面这条语句,并为它编写自定义插件以完成导出。这条语句的作用是:按 geom_feats 给出的索引,把 x 的每一行写入 final 中对应的位置。

python 复制代码
# Scatter the rows of x into final: row i of x is written to final[b, :, z, xi, yi],
# where geom_feats[i] holds (xi, yi, z, b) in columns 0..3.
final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x

上面的写法等价下面的写法

python 复制代码
# Row-by-row equivalent of the vectorised assignment: every geom_feats row
# stores (X index, Y index, Z index, batch index) for the matching row of x.
for i, row in enumerate(geom_feats):
    xi, yi, z, b = row[0], row[1], row[2], row[3]
    final[b, :, z, xi, yi] = x[i]  # write the feature vector of point i

自定义插件-pytorch

注意:这里我把 geom_feats 从 long(int64)类型转换成了 int32 类型,因为 TensorRT 8.x 版本的自定义插件不支持 long 类型,所以统一转换成 32 位整数进行计算。

python 复制代码
class IndexPutFunction(torch.autograd.Function):
    """Scatter the rows of ``x`` into ``final`` and export the operation as a custom ONNX node.

    ``geom_feats[i]`` holds ``(xi, yi, z, b)``; row ``x[i]`` is written to
    ``final[b, :, z, xi, yi]``. No ``backward`` is defined, so the function is
    intended for inference / export only.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, final):
        """Write ``x`` into ``final`` in place and return ``final``."""
        final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x
        # `final` is an input that was modified in place; autograd must be told so
        # that its version/aliasing checks stay correct.
        ctx.mark_dirty(final)
        return final

    @staticmethod
    def symbolic(g, x, geom_feats, final):
        """Emit a single ``xyz.onnx.contrib::IndexPut`` node with inputs (x, geom_feats, final)."""
        output = g.op("xyz.onnx.contrib::IndexPut", x, geom_feats, final)
        # Explicitly give the output the same type as `final`.
        output.setType(final.type())
        return output

在pytorch中使用此计算规则,替换不可直接支持的计算语句

python 复制代码
# Original indexed assignment, kept for reference (this is the statement that fails ONNX export):
# final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x
# Replacement: route the same assignment through the custom autograd Function.
final = IndexPutFunction.apply(x, geom_feats, final)

自定义插件-c++

cu文件

cpp 复制代码
#include <cuda_runtime.h>
#include <stdio.h>

// ============================================================
//   x          = inputs[0]  FLOAT  [N, C]
//   geom_feats = inputs[1]  INT32  [N, 4]  → (xi, yi, zi, bi)
//   final      = inputs[2]  FLOAT  [B,C,Z,X,Y]
//   output     = outputs[0] FLOAT  [B,C,Z,X,Y]
// ============================================================

/**
 * Element-wise copy of final_in into output using a grid-stride loop.
 *
 * Launch layout: 1-D grid of 1-D blocks; any grid size is valid.
 * x, geom_feats and N are accepted for signature compatibility but not read.
 * The element count B*C*Z*X*Y and all indices are computed in 64-bit so that
 * large volumes do not overflow 32-bit arithmetic.
 */
__global__ void indexPutKernel(
    const float*   x,
    const int32_t* geom_feats,
    const float*   final_in,
    float*         output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y)
{
    (void)x;
    (void)geom_feats;
    (void)N;

    // A non-positive extent means there is nothing to copy.
    if (B <= 0 || C <= 0 || Z <= 0 || X <= 0 || Y <= 0) return;

    const size_t totalSize = (size_t)B * C * Z * X * Y;
    const size_t stride    = (size_t)blockDim.x * gridDim.x;
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < totalSize; i += stride) {
        output[i] = final_in[i];
    }
}

/**
 * Scatter kernel: output[bi, c, zi, xi, yi] = x[idx, c] for every channel c.
 *
 * Launch layout: 1-D grid of 1-D blocks with at least N threads in total;
 * thread idx handles point idx, surplus threads exit immediately.
 * geom_feats is read as N rows of 4 int32 values (xi, yi, zi, bi).
 * Points whose coordinates fall outside [0,X) x [0,Y) x [0,Z) x [0,B) are skipped.
 * Writes are plain stores: if two points map to the same voxel, which one
 * survives is not defined by this kernel.
 * All offsets are computed in 64-bit to avoid overflow for large N*C or B*C*Z*X*Y.
 */
__global__ void indexPutScatterKernel(
    const float*    x,
    const int32_t*  geom_feats,
    float*          output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y)
{
    // Global thread index == point index.
    const int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N) return;

    // Voxel coordinates of this point: (xi, yi, zi, bi).
    const int32_t* coord = geom_feats + idx * 4;
    const int32_t xi = coord[0];
    const int32_t yi = coord[1];
    const int32_t zi = coord[2];
    const int32_t bi = coord[3];

    // Bounds check.
    if (xi < 0 || xi >= X ||
        yi < 0 || yi >= Y ||
        zi < 0 || zi >= Z ||
        bi < 0 || bi >= B) {
        return;
    }

    // Offset of channel 0 for this voxel; consecutive channels are Z*X*Y apart.
    const int64_t channelStride = (int64_t)Z * X * Y;
    int64_t out_idx = (int64_t)bi * C * channelStride
                    + (int64_t)zi * X * Y
                    + (int64_t)xi * Y
                    + yi;

    const float* src = x + idx * C;
    for (int32_t c = 0; c < C; ++c, out_idx += channelStride) {
        output[out_idx] = src[c];
    }
}

// ── 主入口函数 ───────────────────────────────────────────────
/**
 * Host entry point: output = final_in, then scatter the rows of x into output.
 *
 * All pointers are device pointers. Both steps are enqueued on `stream`; the
 * function does not synchronise, so results are only valid once the stream has
 * progressed past this work.
 *
 * x          [N, C]          float
 * geom_feats [N, 4]          int32 (xi, yi, zi, bi)
 * final_in   [B, C, Z, X, Y] float
 * output     [B, C, Z, X, Y] float
 *
 * Errors are reported on stderr; the function returns early if the copy fails.
 */
extern "C" void indexPutCUDA(
    const float*    x,
    const int32_t*  geom_feats,
    const float*    final_in,
    float*          output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y,
    cudaStream_t    stream)
{
    // Nothing to copy or scatter into for an empty/invalid volume.
    if (B <= 0 || C <= 0 || Z <= 0 || X <= 0 || Y <= 0) return;

    // 64-bit element count: the product can exceed the int range.
    const size_t totalSize = (size_t)B * C * Z * X * Y;
    const int threadsPerBlock = 256;

    // Step 1: copy final_in -> output (skipped when both alias the same buffer).
    if (output != final_in) {
        cudaError_t copyErr = cudaMemcpyAsync(
            output,
            final_in,
            sizeof(float) * totalSize,
            cudaMemcpyDeviceToDevice,
            stream
        );
        if (copyErr != cudaSuccess) {
            fprintf(stderr, "[indexPutCUDA] CUDA Error: %s\n", cudaGetErrorString(copyErr));
            return;
        }
    }

    // Step 2: scatter x -> output.
    if (N <= 0) return;

    // Ceil-div written so that N close to INT_MAX does not overflow.
    const int blocksPerGrid = (N - 1) / threadsPerBlock + 1;
    indexPutScatterKernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        x, geom_feats, output,
        N, C, B, Z, X, Y
    );

    // Launch-configuration errors surface here.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "[indexPutCUDA] CUDA Error: %s\n", cudaGetErrorString(err));
    }
}

cpp文件

cpp 复制代码
#include "IndexPutPlugin.h"
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>

using namespace nvinfer1;
using namespace nvinfer1::plugin;

namespace {
    // Type name and version string reported by the plugin and its creator.
    constexpr const char* PLUGIN_VERSION = "1";
    constexpr const char* PLUGIN_NAME    = "IndexPut";
} // namespace

// ── CUDA entry point declaration ────────────────────────────
// Declared with C linkage; enqueue() passes it the three input buffers, the
// output buffer, the tensor extents and the stream supplied by TensorRT.
extern "C" void indexPutCUDA(
    const float*    x,
    const int32_t*  geom_feats,
    const float*    final_in,
    float*          output,
    int32_t N, int32_t C,
    int32_t B, int32_t Z,
    int32_t X, int32_t Y,
    cudaStream_t stream);

// ── 序列化工具 ───────────────────────────────────────────────
// Appends the raw bytes of `val` at `buf` and advances `buf` past them.
template<typename T>
static void writeBuf(char*& buf, const T& val) {
    const std::size_t nbytes = sizeof(T);
    std::memcpy(buf, &val, nbytes);
    buf += nbytes;
}
// Reads one T from the raw bytes at `buf`, advances `buf` past them and returns the value.
template<typename T>
static T readBuf(const char*& buf) {
    T out;
    const std::size_t nbytes = sizeof(T);
    std::memcpy(&out, buf, nbytes);
    buf += nbytes;
    return out;
}

// ── 构造函数 ────────────────────────────────────────────────

IndexPutPlugin::IndexPutPlugin()
    : mNamespace("")
    , mDataType(DataType::kFLOAT)
    , mN(0), mC(0), mB(0), mZ(0), mX(0), mY(0)
{}

// Deserialising constructor: restores the state written by serialize().
//
// data   - serialized blob (DataType followed by N, C, B, Z, X, Y as int32).
// length - size of the blob in bytes.
//
// The length is validated BEFORE any byte is read so that a truncated or null
// blob can never cause an out-of-bounds read. On mismatch an error is logged
// and the members keep their in-class default values.
IndexPutPlugin::IndexPutPlugin(const void* data, size_t length)
{
    const size_t expected = getSerializationSize();
    if (data == nullptr || length != expected) {
        std::cerr << "[IndexPutPlugin] deserialize size mismatch: "
                  << expected << " vs " << length << "\n";
        return;
    }

    const char* d = static_cast<const char*>(data);

    mDataType = readBuf<DataType>(d);
    mN  = readBuf<int32_t>(d);
    mC  = readBuf<int32_t>(d);
    mB  = readBuf<int32_t>(d);
    mZ  = readBuf<int32_t>(d);
    mX  = readBuf<int32_t>(d);
    mY  = readBuf<int32_t>(d);
}

// Destructor: forwards to terminate().
IndexPutPlugin::~IndexPutPlugin()
{
    terminate();
}

// Returns a heap-allocated copy of this plugin (same serialized state and
// namespace), produced by a serialize/deserialize round trip.
// The function is noexcept, so allocation failures are caught and reported by
// returning nullptr instead of letting the exception terminate the process.
IPluginV2DynamicExt* IndexPutPlugin::clone() const noexcept
{
    try {
        std::vector<char> blob(getSerializationSize());
        serialize(blob.data());

        auto* copy = new IndexPutPlugin(blob.data(), blob.size());
        copy->setPluginNamespace(mNamespace.c_str());
        return copy;
    } catch (const std::exception& e) {
        std::cerr << "[IndexPutPlugin::clone] " << e.what() << "\n";
        return nullptr;
    }
}

// The single output has exactly the shape of the third input.
//   inputs[0] = x          [N, C]
//   inputs[1] = geom_feats [N, 4]
//   inputs[2] = final      [B, C, Z, X, Y]
DimsExprs IndexPutPlugin::getOutputDimensions(
    int32_t outputIndex,
    const DimsExprs* inputs,
    int32_t nbInputs,
    IExprBuilder& exprBuilder) noexcept
{
    assert(nbInputs == 3 && outputIndex == 0);
    const DimsExprs& finalDims = inputs[2];
    return finalDims;
}

bool IndexPutPlugin::supportsFormatCombination(
    int32_t pos,
    const PluginTensorDesc* inOut,
    int32_t nbInputs,
    int32_t nbOutputs) noexcept
{
    //   pos 0 → x           [N, C]        FLOAT  LINEAR
    //   pos 1 → geom_feats  [N, 4]        INT32  LINEAR
    //   pos 2 → final       [B,C,Z,X,Y]   FLOAT  LINEAR
    //   pos 3 → output      [B,C,Z,X,Y]   FLOAT  LINEAR

    assert(nbInputs == 3 && nbOutputs == 1);
    const auto& desc = inOut[pos];

    // 所有张量必须是 LINEAR 格式
    if (desc.format != TensorFormat::kLINEAR) {
        return false;
    }

    switch (pos) {
        case 0:  // x: FLOAT
            return desc.type == DataType::kFLOAT;

        case 1:  // geom_feats: INT32
            return desc.type == DataType::kINT32;

        case 2:  // final: FLOAT
            return desc.type == DataType::kFLOAT;

        case 3:  // output: FLOAT
            return desc.type == DataType::kFLOAT;

        default:
            return false;
    }
}

// Caches the tensor extents and data type announced by TensorRT.
//   in[0] = x          [N, C]          -> mN, mC (N may be -1 while dynamic)
//   in[1] = geom_feats [N, 4]
//   in[2] = final      [B, C, Z, X, Y] -> mB, mZ, mX, mY
// Inconsistent channel counts and an unexpected output type/format are
// reported on stderr instead of being silently ignored.
void IndexPutPlugin::configurePlugin(
    const DynamicPluginTensorDesc* in,
    int32_t nbInputs,
    const DynamicPluginTensorDesc* out,
    int32_t nbOutputs) noexcept
{
    assert(nbInputs == 3 && nbOutputs == 1);

    // [N, C] from x.
    const Dims& xdims = in[0].desc.dims;
    if (xdims.nbDims >= 2) {
        mN = xdims.d[0];
        mC = xdims.d[1];
    }

    // [B, C, Z, X, Y] from final.
    const Dims& fd = in[2].desc.dims;
    if (fd.nbDims == 5) {
        mB = fd.d[0];
        mZ = fd.d[2];
        mX = fd.d[3];
        mY = fd.d[4];

        // Only compare channel counts once both are concrete (not -1).
        if (xdims.nbDims >= 2 && xdims.d[1] > 0 && fd.d[1] > 0 && xdims.d[1] != fd.d[1]) {
            std::cerr << "[IndexPutPlugin::configurePlugin] channel mismatch: x has "
                      << xdims.d[1] << ", final has " << fd.d[1] << "\n";
        }
    }

    mDataType = in[0].desc.type;

    // The kernel only handles FLOAT / LINEAR output.
    if (out[0].desc.type != DataType::kFLOAT ||
        out[0].desc.format != TensorFormat::kLINEAR) {
        std::cerr << "[IndexPutPlugin::configurePlugin] unexpected output type/format; "
                     "expected FLOAT LINEAR\n";
    }
}

// The plugin requests no scratch workspace.
size_t IndexPutPlugin::getWorkspaceSize(
    const PluginTensorDesc* /*inputs*/, int32_t /*nbInputs*/,
    const PluginTensorDesc* /*outputs*/, int32_t /*nbOutputs*/) const noexcept
{
    constexpr size_t kNoWorkspace = 0;
    return kNoWorkspace;
}

int32_t IndexPutPlugin::enqueue(
    const PluginTensorDesc* inputDesc,
    const PluginTensorDesc* outputDesc,
    const void* const* inputs,
    void* const* outputs,
    void* /*workspace*/,
    cudaStream_t stream) noexcept
{
    try {
        //   inputs[0] = x           FLOAT  [N, C]
        //   inputs[1] = geom_feats  INT32  [N, 4]
        //   inputs[2] = final       FLOAT  [B,C,Z,X,Y]
        
        const float*   x          = static_cast<const float*>  (inputs[0]);
        const int32_t* geom_feats = static_cast<const int32_t*>(inputs[1]);
        const float*   final_in   = static_cast<const float*>  (inputs[2]);
        float*         output     = static_cast<float*>        (outputs[0]);

        // 从 x (inputDesc[0]) 读取 [N, C]
        const int32_t N = inputDesc[0].dims.d[0];
        const int32_t C = inputDesc[0].dims.d[1];

        // 从 output (outputDesc[0]) 读取 [B, C, Z, X, Y]
        const Dims& od = outputDesc[0].dims;
        assert(od.nbDims == 5);
        const int32_t B = od.d[0];
        const int32_t Z = od.d[2];
        const int32_t X = od.d[3];
        const int32_t Y = od.d[4];

        // 调用 CUDA kernel
        indexPutCUDA(
            x, geom_feats, final_in, output,
            N, C, B, Z, X, Y,
            stream
        );

        return 0;
    } catch (const std::exception& e) {
        std::cerr << "[IndexPutPlugin::enqueue] " << e.what() << "\n";
        return -1;
    }
}

// ── IPluginV2 基础 ───────────────────────────────────────────

// Plugin type name.
const char* IndexPutPlugin::getPluginType() const noexcept
{
    return PLUGIN_NAME;
}

// Plugin version string.
const char* IndexPutPlugin::getPluginVersion() const noexcept
{
    return PLUGIN_VERSION;
}

// The plugin produces exactly one output tensor.
int32_t IndexPutPlugin::getNbOutputs() const noexcept
{
    return 1;
}

// Nothing to set up; always reports success.
int32_t IndexPutPlugin::initialize() noexcept
{
    return 0;
}

// Nothing to release.
void IndexPutPlugin::terminate() noexcept
{
}

// Self-deletes the plugin object.
void IndexPutPlugin::destroy() noexcept
{
    delete this;
}

// Stores the namespace; a null pointer is treated as the empty namespace.
void IndexPutPlugin::setPluginNamespace(const char* ns) noexcept {
    if (ns == nullptr) {
        mNamespace = "";
    } else {
        mNamespace = ns;
    }
}

// Returns the stored namespace as a C string.
const char* IndexPutPlugin::getPluginNamespace() const noexcept {
    return mNamespace.c_str();
}

// ── 序列化 ───────────────────────────────────────────────────

// Size of the serialized blob: one DataType followed by six int32 extents
// (N, C, B, Z, X, Y).
size_t IndexPutPlugin::getSerializationSize() const noexcept {
    constexpr size_t kExtentFields = 6;
    return kExtentFields * sizeof(int32_t) + sizeof(DataType);
}

// Writes the plugin state into `buffer` in the order the deserialising
// constructor reads it: data type first, then N, C, B, Z, X, Y.
void IndexPutPlugin::serialize(void* buffer) const noexcept {
    char* const begin = static_cast<char*>(buffer);
    char* cursor = begin;

    writeBuf(cursor, mDataType);

    const int32_t extents[] = {mN, mC, mB, mZ, mX, mY};
    for (const int32_t extent : extents) {
        writeBuf(cursor, extent);
    }

    // The number of bytes written must match the advertised size.
    assert(cursor == begin + getSerializationSize());
}

// ── IPluginV2Ext ─────────────────────────────────────────────

// The output is always FLOAT, regardless of the input types (one of which is INT32).
DataType IndexPutPlugin::getOutputDataType(
    int32_t index,
    const nvinfer1::DataType* /*inputTypes*/,
    int32_t /*nbInputs*/) const noexcept
{
    assert(index == 0);
    const DataType outputType = DataType::kFLOAT;
    return outputType;
}

// ============================================================
// IndexPutPluginCreator
// ============================================================

// Storage for the creator's static members.
PluginFieldCollection    IndexPutPluginCreator::mFC{};
std::vector<PluginField> IndexPutPluginCreator::mPluginAttributes;

// The plugin takes no creation-time attributes, so the advertised field
// collection is left empty.
IndexPutPluginCreator::IndexPutPluginCreator() {
    mPluginAttributes.clear();
    mFC.fields   = nullptr;
    mFC.nbFields = 0;
}

// Name of the plugin this creator produces.
const char* IndexPutPluginCreator::getPluginName() const noexcept
{
    return PLUGIN_NAME;
}

// Version of the plugin this creator produces.
const char* IndexPutPluginCreator::getPluginVersion() const noexcept
{
    return PLUGIN_VERSION;
}

// Field collection describing the creation-time attributes.
const PluginFieldCollection* IndexPutPluginCreator::getFieldNames() noexcept
{
    return &mFC;
}

// Creates a default-constructed plugin in the creator's namespace.
// `name` and `fc` are ignored: the plugin has no creation-time attributes.
// Returns nullptr if construction throws, since this function is noexcept
// and an escaping exception would terminate the process.
IPluginV2* IndexPutPluginCreator::createPlugin(
    const char* /*name*/,
    const PluginFieldCollection* /*fc*/) noexcept
{
    try {
        auto* p = new IndexPutPlugin();
        p->setPluginNamespace(mNamespace.c_str());
        return p;
    } catch (const std::exception& e) {
        std::cerr << "[IndexPutPluginCreator::createPlugin] " << e.what() << "\n";
        return nullptr;
    }
}

// Rebuilds a plugin from a serialized blob and places it in the creator's namespace.
// Returns nullptr for a null blob or if construction throws, since this
// function is noexcept and an escaping exception would terminate the process.
IPluginV2* IndexPutPluginCreator::deserializePlugin(
    const char* /*name*/,
    const void* serialData,
    size_t serialLength) noexcept
{
    if (serialData == nullptr) {
        std::cerr << "[IndexPutPluginCreator::deserializePlugin] null serialData\n";
        return nullptr;
    }
    try {
        auto* p = new IndexPutPlugin(serialData, serialLength);
        p->setPluginNamespace(mNamespace.c_str());
        return p;
    } catch (const std::exception& e) {
        std::cerr << "[IndexPutPluginCreator::deserializePlugin] " << e.what() << "\n";
        return nullptr;
    }
}

// Stores the namespace; a null pointer is treated as the empty namespace.
void IndexPutPluginCreator::setPluginNamespace(const char* ns) noexcept {
    if (ns == nullptr) {
        mNamespace = "";
    } else {
        mNamespace = ns;
    }
}

// Returns the stored namespace as a C string.
const char* IndexPutPluginCreator::getPluginNamespace() const noexcept {
    return mNamespace.c_str();
}

// Registers IndexPutPluginCreator with the TensorRT plugin registry via the
// TensorRT-provided macro, so the plugin can be looked up as "IndexPut" / "1".
REGISTER_TENSORRT_PLUGIN(IndexPutPluginCreator);

h文件

cpp 复制代码
#ifndef INDEX_PUT_PLUGIN_H
#define INDEX_PUT_PLUGIN_H

#include <NvInfer.h>
#include <vector>
#include <string>
#include <cuda_runtime.h>

namespace nvinfer1 {
namespace plugin {

// ============================================================
//   inputs[0]: x          [N, C]        kFLOAT  有效特征点
//   inputs[1]: geom_feats [N, 4]        kINT32  体素坐标(xi,yi,zi,bi)
//   inputs[2]: final      [B,C,Z,X,Y]   kFLOAT  零初始化体素
//
//   output[0]:            [B,C,Z,X,Y]   kFLOAT  填充后体素
// ============================================================

// Dynamic-shape TensorRT plugin that copies `final` to the output and then
// writes each row of `x` at the voxel addressed by the matching row of
// `geom_feats`.
class IndexPutPlugin : public IPluginV2DynamicExt {
public:
    // Creates an unconfigured plugin (all extents zero, FLOAT data type).
    IndexPutPlugin();
    // Restores a plugin from a blob produced by serialize().
    explicit IndexPutPlugin(const void* data, size_t length);
    ~IndexPutPlugin() override;

    // ── IPluginV2DynamicExt ──────────────────────────────────
    IPluginV2DynamicExt* clone() const noexcept override;

    // Output shape equals the shape of inputs[2] (final).
    DimsExprs getOutputDimensions(
        int32_t outputIndex,
        const DimsExprs* inputs,
        int32_t nbInputs,
        IExprBuilder& exprBuilder) noexcept override;

    // LINEAR only; FLOAT for x/final/output, INT32 for geom_feats.
    bool supportsFormatCombination(
        int32_t pos,
        const PluginTensorDesc* inOut,
        int32_t nbInputs,
        int32_t nbOutputs) noexcept override;

    // Caches the extents and data type reported by TensorRT.
    void configurePlugin(
        const DynamicPluginTensorDesc* in,
        int32_t nbInputs,
        const DynamicPluginTensorDesc* out,
        int32_t nbOutputs) noexcept override;

    // No scratch workspace is requested.
    size_t getWorkspaceSize(
        const PluginTensorDesc* inputs,
        int32_t nbInputs,
        const PluginTensorDesc* outputs,
        int32_t nbOutputs) const noexcept override;

    // Enqueues the copy + scatter on `stream`.
    int32_t enqueue(
        const PluginTensorDesc* inputDesc,
        const PluginTensorDesc* outputDesc,
        const void* const* inputs,
        void* const* outputs,
        void* workspace,
        cudaStream_t stream) noexcept override;

    // ── IPluginV2 ────────────────────────────────────────────
    const char*  getPluginType()    const noexcept override;
    const char*  getPluginVersion() const noexcept override;
    int32_t      getNbOutputs()     const noexcept override;
    int32_t      initialize()              noexcept override;
    void         terminate()               noexcept override;
    size_t       getSerializationSize() const noexcept override;
    void         serialize(void* buffer)  const noexcept override;
    void         destroy()                noexcept override;
    void         setPluginNamespace(const char* ns) noexcept override;
    const char*  getPluginNamespace()   const noexcept override;

    // ── IPluginV2Ext ─────────────────────────────────────────
    DataType getOutputDataType(
        int32_t index,
        const nvinfer1::DataType* inputTypes,
        int32_t nbInputs) const noexcept override;

private:
    std::string mNamespace;

    // Runtime extents (captured in configurePlugin)
    // geom_feats: [mN, 4]
    // x:          [mN, mC]
    // final:      [mB, mC, mZ, mX, mY]
    int32_t mN{0};   // number of valid points (dynamic)
    int32_t mC{0};   // channel count
    int32_t mB{0};   // batch size
    int32_t mZ{0};   // nx[2] ≈ 1
    int32_t mX{0};   // nx[0] ≈ 200
    int32_t mY{0};   // nx[1] ≈ 200

    // Element type of x as reported to configurePlugin.
    DataType mDataType{DataType::kFLOAT};
};

// ============================================================
// Factory for IndexPutPlugin: creates fresh instances and rebuilds them from
// serialized engine data.
class IndexPutPluginCreator : public IPluginCreator {
public:
    IndexPutPluginCreator();
    ~IndexPutPluginCreator() override = default;

    const char* getPluginName()    const noexcept override;
    const char* getPluginVersion() const noexcept override;
    // Creation-time attribute list (empty for this plugin).
    const PluginFieldCollection* getFieldNames() noexcept override;

    // Creates a default-constructed plugin; `name` and `fc` are not used.
    IPluginV2* createPlugin(
        const char* name,
        const PluginFieldCollection* fc) noexcept override;

    // Rebuilds a plugin from the bytes written by IndexPutPlugin::serialize().
    IPluginV2* deserializePlugin(
        const char* name,
        const void* serialData,
        size_t serialLength) noexcept override;

    void        setPluginNamespace(const char* ns) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    // Shared across all creator instances.
    static PluginFieldCollection    mFC;
    static std::vector<PluginField> mPluginAttributes;
    std::string mNamespace;
};

} // namespace plugin
} // namespace nvinfer1

#endif // INDEX_PUT_PLUGIN_H

测试

写好插件后,要确认它与 Python 实现的结果一致,最好的方式是写一个验证脚本:先把 PyTorch 推理过程中这一层的输入保存下来,然后构建一个只包含该插件层的 TensorRT 引擎,把同样的输入分别送入 PyTorch 代码和 TensorRT 引擎,通过对比两者的输出来判断插件是否工作正常。

保存的例子

python 复制代码
# np.save(f"{debug_dir}/x.npy", x.cpu().numpy())
# np.save(f"{debug_dir}/geom_feats.npy", geom_feats_int32.cpu().numpy())
# np.save(f"{debug_dir}/final.npy", final.cpu().numpy())

加载并对比插件

python 复制代码
# TensorRT IndexPut 插件验证
import torch
import numpy as np
import tensorrt as trt
import os
import sys
import ctypes
from typing import Tuple, Optional

# Absolute path of the plugin library, resolved relative to this script:
# <script dir>/../cpp/build/lib/liblss_trt_plugins.so
_SCRIPT_DIR = os.path.dirname(__file__)
PLUGIN_PATH = os.path.abspath(
    os.path.join(_SCRIPT_DIR, "..", "cpp", "build", "lib", "liblss_trt_plugins.so")
)


def load_trt_plugins():
    """Load the plugin shared library so its creators register with TensorRT.

    Raises:
        FileNotFoundError: if the library does not exist at ``PLUGIN_PATH``.
    """
    print(f"Loading plugin: {PLUGIN_PATH}")
    if not os.path.exists(PLUGIN_PATH):
        raise FileNotFoundError(f"Plugin not found: {PLUGIN_PATH}")

    # Loading the library through ctypes runs its static initialisers.
    ctypes.CDLL(PLUGIN_PATH)
    trt.init_libnvinfer_plugins(None, "")
    print("Plugin loaded successfully!")


def create_index_put_engine():
    """Build a TensorRT engine that contains only the IndexPut plugin layer.

    Network inputs (order matters, it must match the plugin):
        x          FLOAT [N, 64]
        geom_feats INT32 [N, 4]
        final      FLOAT [B, 64, 1, 200, 200]

    Returns:
        The built engine, or ``None`` if the plugin is not registered, cannot
        be instantiated, or the build fails.
    """
    logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(logger)

    # Explicit-batch network with the three plugin inputs.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    x_in = network.add_input("x", trt.DataType.FLOAT, [-1, 64])                      # [N, C]
    geom_in = network.add_input("geom_feats", trt.DataType.INT32, [-1, 4])           # [N, 4]
    final_in = network.add_input("final", trt.DataType.FLOAT, [-1, 64, 1, 200, 200])  # [B, C, Z, X, Y]

    # Look up the plugin creator.
    plugin_registry = trt.get_plugin_registry()
    plugin_creator = plugin_registry.get_plugin_creator("IndexPut", "1", "")
    if plugin_creator is None:
        print("ERROR: IndexPut plugin not found in registry!")
        print("Available plugins:")
        # The registry exposes its creators through `plugin_creator_list`.
        for creator in plugin_registry.plugin_creator_list:
            print(f"  - {creator.name} (version {creator.plugin_version})")
        return None

    # The plugin takes no creation-time attributes.
    fc = trt.PluginFieldCollection()
    plugin = plugin_creator.create_plugin("IndexPut", fc)
    if plugin is None:
        print("ERROR: Failed to create IndexPut plugin instance!")
        return None

    # Input order must be: x, geom_feats, final.
    index_put_layer = network.add_plugin_v2([x_in, geom_in, final_in], plugin)
    network.mark_output(index_put_layer.get_output(0))

    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 25  # 32MB

    # Optimisation profile: N is dynamic, everything else is fixed.
    profile = builder.create_optimization_profile()
    profile.set_shape("x", [1, 64], [10000, 64], [100000, 64])
    profile.set_shape("geom_feats", [1, 4], [10000, 4], [100000, 4])
    profile.set_shape("final", [1, 64, 1, 200, 200], [1, 64, 1, 200, 200], [1, 64, 1, 200, 200])
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)
    if engine is None:
        print("ERROR: Engine build failed!")
        return None

    print("Engine built successfully!")
    return engine


def run_trt_inference(engine, x, geom_feats, final) -> np.ndarray:
    """Run the single-layer engine once and return its output as a NumPy array.

    Args:
        engine: TensorRT engine built by ``create_index_put_engine``.
        x: features, tensor or array, converted to float32.
        geom_feats: voxel indices, tensor or array, converted to int32
            (the plugin only accepts INT32).
        final: destination volume, tensor or array, converted to float32.

    Returns:
        float32 array with the shape of ``final``.

    Raises:
        RuntimeError: if ``execute_v2`` reports failure.
    """
    context = engine.create_execution_context()

    def _to_numpy(value, dtype):
        # Accept both torch tensors and NumPy arrays; force dtype and C-contiguity
        # because the engine reads the raw buffers.
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.ascontiguousarray(value, dtype=dtype)

    x_input_np = _to_numpy(x, np.float32)
    geom_feats_np = _to_numpy(geom_feats, np.int32)
    final_np = _to_numpy(final, np.float32)

    print("x (input0)", x_input_np.shape, x_input_np.dtype)
    print("geom_feats (input1)", geom_feats_np.shape, geom_feats_np.dtype)
    print("final (input2)", final_np.shape, final_np.dtype)

    # Dynamic shapes must be set before execution.
    context.set_input_shape("x", x_input_np.shape)
    context.set_input_shape("geom_feats", geom_feats_np.shape)
    context.set_input_shape("final", final_np.shape)

    # Device buffers, in binding order: x, geom_feats, final, output.
    d_input_0 = torch.from_numpy(x_input_np).cuda()
    d_input_1 = torch.from_numpy(geom_feats_np).cuda()
    d_input_2 = torch.from_numpy(final_np).cuda()
    d_output = torch.empty(final_np.shape, dtype=torch.float32, device="cuda")

    ok = context.execute_v2(
        [d_input_0.data_ptr(), d_input_1.data_ptr(), d_input_2.data_ptr(), d_output.data_ptr()]
    )

    torch.cuda.synchronize()
    if not ok:
        raise RuntimeError("TensorRT execute_v2 failed")
    return d_output.cpu().numpy()


class IndexPutFunction(torch.autograd.Function):
    """Scatter the rows of ``x`` into ``final``; exported to ONNX as a custom ``IndexPut`` node.

    ``geom_feats[i]`` holds ``(xi, yi, z, b)``; row ``x[i]`` is written to
    ``final[b, :, z, xi, yi]``. ``final`` is modified in place.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, final):
        """Write ``x`` into ``final`` in place and return ``final``."""
        final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = x
        # `final` is an input that was modified in place; tell autograd so.
        ctx.mark_dirty(final)
        return final

    @staticmethod
    def symbolic(g, x, geom_feats, final):
        """ONNX symbolic: one ``xyz.onnx.contrib::IndexPut`` node with inputs (x, geom_feats, final)."""
        # The output has the same shape as the `final` input (in-place semantics).
        output = g.op("xyz.onnx.contrib::IndexPut", x, geom_feats, final)
        # Explicitly give the output the same type as `final`.
        output.setType(final.type())
        return output
def torch_index_put(x, geom_feats, final) -> torch.Tensor:
    """PyTorch reference for the IndexPut plugin (baseline).

    Delegates to ``IndexPutFunction``, which writes into ``final`` in place.
    """
    reference = IndexPutFunction.apply(x, geom_feats, final)
    return reference


def verify_with_npy(data_path: str) -> Tuple[bool, dict]:
    """
    Verify the plugin against inputs stored as .npy files.

    Args:
        data_path: directory containing ``x.npy``, ``geom_feats.npy`` and
            ``final.npy`` (a trailing slash is optional).

    Returns:
        (passed, statistics dict)
    """

    # Load the saved layer inputs; os.path.join works with or without a trailing separator.
    x = np.load(os.path.join(data_path, "x.npy"))
    geom_feats = np.load(os.path.join(data_path, "geom_feats.npy"))
    final = np.load(os.path.join(data_path, "final.npy"))
    print(f"形状x         : {x.shape}, dtype: {x.dtype}")
    print(f"形状geom_feats: {geom_feats.shape}, dtype: {geom_feats.dtype}")
    print(f"形状final     : {final.shape}, dtype: {final.dtype}")

    # Convert to tensors with the dtypes the PyTorch reference expects.
    x = torch.from_numpy(x).float()
    geom_feats = torch.from_numpy(geom_feats).long()
    final = torch.from_numpy(final).float()

    return verify_with_tensor(x, geom_feats, final)


def verify_with_tensor(x, geom_feats, final) -> Tuple[bool, dict]:
    """
    Verify the plugin against the PyTorch reference on in-memory tensors.

    Returns:
        (passed, stats) where stats holds ``total_elements``, ``max_error``,
        ``mean_error`` and ``passed`` (max_error < 1e-5).
    """
    # Load the plugin and build the single-layer engine.
    load_trt_plugins()
    engine = create_index_put_engine()

    if engine is None:
        print("ERROR: Failed to create TensorRT engine!")
        return False, {"total_elements": 0, "max_error": float('inf'), "mean_error": float('inf'), "passed": False}

    # PyTorch reference. The reference writes into its `final` argument in place,
    # so give it a private copy: otherwise TensorRT would receive an already
    # filled volume and the comparison would pass even if the plugin did nothing.
    pytorch_results = torch_index_put(x, geom_feats, final.clone())

    # TensorRT result, computed from the untouched `final`.
    trt_results = run_trt_inference(engine, x, geom_feats, final)
    trt_results = torch.from_numpy(trt_results)

    # Element-wise absolute error.
    diff = torch.abs(pytorch_results - trt_results)
    max_error = float(torch.max(diff))
    mean_error = float(torch.mean(diff))

    stats = {
        "total_elements": int(pytorch_results.numel()),
        "max_error": max_error,
        "mean_error": mean_error,
        "passed": max_error < 1e-5
    }

    return stats["passed"], stats


#python3 ./py/test_index_put_plugin.py ./data/

if __name__ == "__main__":
    # Expect the data directory as the only argument.
    if len(sys.argv) < 2:
        print(f"Usage: python3 {sys.argv[0]} <data_dir>")
        sys.exit(2)

    print("=" * 50)
    print("TensorRT IndexPut Plugin Verification")
    print("=" * 50)

    passed, stats = verify_with_npy(sys.argv[1])

    print(f"\n Test Results:")
    print(f"   Total Elements: {stats['total_elements']}")
    print(f"   Max Error: {stats['max_error']:.2e}")
    print(f"   Mean Error: {stats['mean_error']:.2e}")
    print(f"\n{'PASSED' if passed else 'FAILED'}")
    print("=" * 50)

    # Make the outcome visible to shells / CI.
    sys.exit(0 if passed else 1)

结果

bash 复制代码
: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
[03/28/2026-06:37:43] [TRT] [V] CUDA lazy loading is enabled.
x (input0) (8597, 64) float32
geom_feats (input1) (8597, 4) int32
final (input2) (1, 64, 1, 200, 200) float32
/workspace/current_workspace/./py/test_index_put_plugin.py:112: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:78.)
  d_input_0 = torch.cuda.FloatTensor(x_input_np).cuda()      # x: FLOAT

 Test Results:
   Total Elements: 2560000
   Max Error: 0.00e+00
   Mean Error: 0.00e+00

PASSED
==================================================
相关推荐
小七-七牛开发者6 小时前
周一上线 | SpaceX 收购 Cursor、支付宝进入 AI 时代、DeepSeek 完成 500 亿元融资
ai·agent·token·glm·智谱·claudecode·ai coding·周一上线
doiito1 天前
【Agent Harness】为什么我把 JSON‑LD “编译成 DAG” 后,整个 Agent 平台立刻聪明了
ai·rust·架构设计·系统设计·ai agent
xiezhr1 天前
折腾半小时,终于让AI 能直接帮我写飞书文档了
ai·飞书·ai agent·飞书cli·飞书文档
岳小哥AI1 天前
Claude Fable和Claude Mythos 5同时发布:注意力机制下愈加强大的AI大模型
ai·ai基础
Artech1 天前
[MAF预定义的AIContextProvider-04]Mem0Provider——长期记忆基于的云端解决方案
ai·agent·maf·aicontextprovider·chathistorymemoryprovider·mem0provider
哥不是小萝莉2 天前
一文读懂 OpenAI Codex 源码的原理、架构与未来
ai
AlfredZhao2 天前
AI 编程工作总结:从体验问题到模块能力建设
ai·codex
cup113 天前
[技术复盘] Windows Python 打包实战:Nuitka 环境踩坑总结与 CI 自动化构建全指南
python·ai·环境变量·ci·nuitka·skill
IT王师傅3 天前
从 豆包 到 Codex CLI:一名普通开发者的 AI 工具进化路线
ai·codex cli·openclaw
岳小哥AI3 天前
Siri要接入AI了,苹果手机上一句话让GPT写文案、DeepSeek写代码的时刻来了
ai·ai基础