NVIDIA大模型推理框架:TensorRT-LLM软件流程(四)探究TensorRT LLM自定义算子调用流程

TensorRT LLM plugin背景

  • TensorRT LLM Runtime代表着我们之前三个博客介绍的内容,Runtime下面就对接了TensorRT Engine,但是原生的TensorRT都是适配传统模型的算子
  • TensorRT为了更好的适配TensorRT LLM处理大模型,特别设计Tensorrt_llm_module模块,其中特别封装了 LLM 推理所需的核心计算逻辑(比如注意力机制、前馈网络(FFN)、层归一化等),是 TensorRT-LLM 框架中 "将 LLM 架构转化为可执行计算图" 的关键载体。
  • Tensorrt_llm_module中发挥解析LLM能力的是通过TensorRT Plugin的形式存在的,往下是高效发挥GPU并行计算能力TensorRT LLM Kernel ,也就是CUDA kernel。
  • 这篇博客就是想探究整个链路软件架构和代码具体实现

TensorRT Plugin

Plugin register

  • TensorRT Plugin作为中间模块发挥着承上启下重要的作用,在TensorRT LLM代码仓库中有Plugin文件夹存放plugin module
  • 里面有很多plugin模块,我选取在调试过程中使用的Plugin GPTAttentionPlugin(我选取的模型文件是Qwen)作为example给大家分析
  • 先说下TensorRT Plugin的注册流程
cpp 复制代码
// TensorRT-LLM-main\cpp\tensorrt_llm\plugins\api\tllmPlugin.cpp
void initOnLoad()
{
    auto constexpr kLoadPlugins = "TRT_LLM_LOAD_PLUGINS";
    auto const loadPlugins = std::getenv(kLoadPlugins);
    if (loadPlugins && loadPlugins[0] == '1')
    {
        initTrtLlmPlugins(gLogger);
    }
}

bool initTrtLlmPlugins(void* logger, char const* libNamespace)
{
    if (pluginsInitialized)
    {
        return true;
    }

    if (logger)
    {
        gLogger = static_cast<nvinfer1::ILogger*>(logger);
    }
    setLoggerFinder(&gGlobalLoggerFinder);

    auto registry = getPluginRegistry();

    {
        std::int32_t nbCreators;
        auto creators = getPluginCreators(nbCreators);

        for (std::int32_t i = 0; i < nbCreators; ++i)
        {
            auto const creator = creators[i];
            creator->setPluginNamespace(libNamespace);
            registry->registerCreator(*creator, libNamespace);
            if (gLogger)
            {
                auto const msg = tc::fmtstr("Registered plugin creator %s version %s in namespace %s",
                    creator->getPluginName(), creator->getPluginVersion(), libNamespace);
                gLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, msg.c_str());
            }
        }
    }

    {
        std::int32_t nbCreators;
        auto creators = getCreators(nbCreators);

        for (std::int32_t i = 0; i < nbCreators; ++i)
        {
            auto const creator = creators[i];
            registry->registerCreator(*creator, libNamespace);
        }
    }

    pluginsInitialized = true;
    return true;
}
  • 调用initOnLoad之后开始加载 Plugin,getPluginCreators函数会获取所有创建出来的PluginCreator,registerCreator把creator注册到 TensorRT 注册表,plugin creator注册之后才能使用。
  • 为什么creator很重要呢?因为具体的Plugin use Creator create instance。
  • Plugin instance是TensorRT自动化调用
cpp 复制代码
// cpp\tensorrt_llm\plugins\gptAttentionPlugin\gptAttentionPlugin.cpp
//删除了很多参数
IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    PluginFieldParser p{fc->nbFields, fc->fields};
    try
    {
        auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
            p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("vision_start").value(),
            p.getScalar<int32_t>("vision_length").value(), p.getScalar<int32_t>("num_kv_heads").value(),
            static_cast<bool>(p.getScalar<int8_t>("remove_input_padding").value()),
            static_cast<AttentionMaskType>(p.getScalar<int32_t>("mask_type").value()),
            BlockSparseParams{p.getScalar<int32_t>("block_sparse_block_size").value(),
                static_cast<bool>(p.getScalar<int8_t>("block_sparse_homo_head_pattern").value()),
                p.getScalar<int32_t>("block_sparse_num_local_blocks").value(),
                p.getScalar<int32_t>("block_sparse_vertical_stride").value()},
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
  • getCreators会获取EaglePrepareDrafterInputsPluginCreator、DoraPluginCreator两个特殊的PluginCreator
  • getPluginCreators函数是 TensorRT LLM 插件库的 "启动开关"------ 只有调用它,所有 LLM 专用优化插件(注意力、量化、Mamba、分布式等)才会被 TensorRT 识别
  • 所有plugin creator加载完之后会通过logger打印出来

Plugin Create

  • TensorRT Plugin create 相当于是自定义Plugin,NVIDIA 给出一套非常标准的TensorRT Plugin添加流程,最核心是实现「Plugin 计算类」和「Plugin 工厂类(Creator)」,重写几个重要的接口。
cpp 复制代码
// cpp\tensorrt_llm\plugins\gptAttentionPlugin\gptAttentionPlugin.h
class GPTAttentionPlugin : public GPTAttentionPluginCommon{
    int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
    // IPluginV2 Methods
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;
    GPTAttentionPlugin* clone() const noexcept override;
}
class GPTAttentionPluginCreator : public GPTAttentionPluginCreatorCommon{
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
}

//cpp\tensorrt_llm\plugins\gptAttentionCommon\gptAttentionCommon.h
class GPTAttentionPluginCommon : public BasePlugin{}
class GPTAttentionPluginCreatorCommon : public BaseCreator
{
public:
    GPTAttentionPluginCreatorCommon();
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
    template <typename T>
    T* deserializePluginImpl(char const* name, void const* serialData, size_t serialLength) noexcept;
protected:
    std::vector<nvinfer1::PluginField> mPluginAttributes;
    nvinfer1::PluginFieldCollection mFC{};
};

//cpp\tensorrt_llm\plugins\common\plugin.h
class BasePlugin : public nvinfer1::IPluginV2DynamicExt
{
public:
    void setPluginNamespace(char const* libNamespace) noexcept override
    {
        mNamespace = libNamespace;
    }

    [[nodiscard]] char const* getPluginNamespace() const noexcept override
    {
        return mNamespace.c_str();
    }

protected:
    std::string mNamespace{api::kDefaultNamespace};
};

class BaseCreator : public nvinfer1::IPluginCreator
{
public:
    void setPluginNamespace(char const* libNamespace) noexcept override
    {
        mNamespace = libNamespace;
    }

    [[nodiscard]] char const* getPluginNamespace() const noexcept override
    {
        return mNamespace.c_str();
    }

protected:
    std::string mNamespace{api::kDefaultNamespace};
};
  • 为了大家更加清楚GPTAttentionPlugin 继承关系这里直接给出类图
  • 现在NVIDIA已经升级到第三代plugin:IPluginV3、IPluginCreatorV3One,但是GPTAttention功能比较老还是用的第二代plugin
  • GPTAttentionPluginCreator在通过initTrtLlmPlugins registerCreator之后,在推理过程中会自动调用GPTAttentionPluginCreator::createPlugin new GPTAttentionPlugin,
  • 然后使用enqueue作为接口被调用,调用点是第三篇博客提到的:enqueueV3 interface

GPTAttentionPlugin调用到CUDA

  • GPTAttentionPlugin 是专为 GPT 类大模型(Qwen、LLaMA、GPT-3 等自回归模型) 设计的注意力层优化插件,核心作用是替代原生 TensorRT 算子,通过硬件加速、计算融合、内存优化等手段,实现低延迟、高吞吐、省显存的注意力计算。
  • Qwen 模型的注意力层与 GPT 类模型高度兼容,所以复用成熟的 GPTAttentionPlugin 来优化 QWEN 的注意力计算
  • 流程图:
  • enqueueV3被调用之后会调用GPTAttentionPlugin::enqueue interface,会根据量化的nvinfer1::DataType类型选择路径
  • enqueueImpl准备一下Requests Context。enqueueSome根据各种优化技术设置参数,都设置好之后
cpp 复制代码
// GPTAttentionPlugin::enqueueSome
if (is_context) // context stage
{enqueueContext<T, KVCacheBuffer>(enqueue_params, stream);}
else // generation stage
{enqueueGeneration<T, KVCacheBuffer>(enqueue_params, stream);}
  • enqueueGeneration是在AttentionOp 负责具体的注意力层向前计算逻辑,operation确实写得很复杂。。。
  • 当使用了useKVCache的时候,会调用invokeShiftKCache对key缓存进行处理。实际的处理是在GPU 上运行shiftKCache CUDA函数执行
cpp 复制代码
// cpp\tensorrt_llm\kernels\unfusedAttentionKernels.cu

template <typename T, typename KVCacheBuffer>
void invokeShiftKCache(KVCacheBuffer const& kvCacheBuffer, KVLinearBuffer const& shiftKCacheBuffer,
    const KvCacheDataType cache_type, int const sizePerHead, int const timestep, int const batch_beam,
    int const kv_head_num, int const beam_width, int const maxKCacheLen, int const sinkTokenLen,
    float const* kScaleQuantOrig, int const* sequence_lengths, int const* input_lengths, int const rotary_embedding_dim,
    float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale,
    int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, cudaStream_t stream)
{
    // Block handles K tile.
    int const token_num_in_k = (timestep <= maxKCacheLen) ? timestep : maxKCacheLen;
    int const vec_size = 16u / sizeof(T);
    dim3 block((sizePerHead / vec_size + 31) / 32 * 32);
    dim3 grid(token_num_in_k, kv_head_num, batch_beam);
    size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
                || position_embedding_type == PositionEmbeddingType::kLONG_ROPE
                || position_embedding_type == PositionEmbeddingType::kROPE_M
            ? 2 * rotary_embedding_dim * sizeof(T)
            : 0);

    if (cache_type == KvCacheDataType::INT8)
    {
        shiftKCache<T, int8_t, KVCacheBuffer><<<grid, block, smem_size, stream>>>(kvCacheBuffer, shiftKCacheBuffer,
            sizePerHead, timestep, beam_width, maxKCacheLen, sinkTokenLen, kScaleQuantOrig, sequence_lengths,
            input_lengths, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
            rotary_embedding_max_positions, position_embedding_type);
    }
#ifdef ENABLE_FP8
    else if (cache_type == KvCacheDataType::FP8)
    {
        shiftKCache<T, __nv_fp8_e4m3, KVCacheBuffer><<<grid, block, smem_size, stream>>>(kvCacheBuffer,
            shiftKCacheBuffer, sizePerHead, timestep, beam_width, maxKCacheLen, sinkTokenLen, kScaleQuantOrig,
            sequence_lengths, input_lengths, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
            rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type);
    }
#endif // ENABLE_FP8
    else
    {
        shiftKCache<T, T, KVCacheBuffer><<<grid, block, smem_size, stream>>>(kvCacheBuffer, shiftKCacheBuffer,
            sizePerHead, timestep, beam_width, maxKCacheLen, sinkTokenLen, kScaleQuantOrig, sequence_lengths,
            input_lengths, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
            rotary_embedding_max_positions, position_embedding_type);
    }
}

//global CUDA 函数实际执行在GPU
template <typename T, typename T_cache, typename KVCacheBuffer>
__global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCacheBuffer, int const sizePerHead,
    int const timestep, int const beam_width, int const maxKCacheLen, int const sinkTokenLen,
    float const* kScaleQuantOrig, int const* sequence_lengths, int const* input_lengths, int const rotary_embedding_dim,
    float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale,
    int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type)

TensorRT Plugin 创建

  • 从上面的代码可以看到createPlugin new GPTAttentionPlugin,那createPlugin是在什么时候调用的呢?
  • 是在trtllm-build的阶段,这个阶段会逐 Layer 构建网络时实时解析和适配判断,最终会覆盖 Qwen 模型的每一层,决定是否用自定义plugin 替代原生 TensorRT 层
  • 当build的过程使用GPTAttentionPlugin替换原生TensorRT之后,在实际运行过程中,如果creator注册之后就可以使用register Plugin去实现推理,提升大模型运行效率

Summary

  • 通过以上的流程分析,使用自定义TensorRT Plugin嵌入到TensorRT LLM框架中, 在build转换模型阶段create Plugin、推理框架运行加载模型阶段register creator,实际模型接收Request消息阶段使用Plugin 实际代码处理。
  • 这样高效合理的框架保证大模型可以使用自定义的Plugin、自定义的operation、自定义的CUDA算子,对每个大模型针对性优化的高效运行
相关推荐
love530love13 小时前
突破 ComfyUI 环境枷锁:RTX 3090 强行开启 comfy-kitchen 官方全后端加速库实战
人工智能·windows·python·cuda·comfyui·triton·comfy-kitchen
心 爱心 爱1 天前
pip 隔离环境内 安装 cuda 113 不覆盖原有的全局 cuda 115
pip·cuda·隔离环境
小烤箱1 天前
CUDA 编程完全理解系列(第二篇):从 Block 生命周期理解调度
自动驾驶·cuda·并行计算·感知算法
KIDGINBROOK1 天前
Blackwell架构学习
gpu·cuda·blackwell
REDcker1 天前
Nvidia英伟达显卡型号发布史与架构演进详解
架构·gpu·显卡·nvidia·cuda·英伟达·演进
小烤箱3 天前
CUDA 编程完全理解系列(第一篇):GPU 的设计哲学与硬件架构基础
自动驾驶·硬件架构·cuda·并行计算·感知算法
英雄各有见3 天前
Chapter 5.1.1: 编写你的第一个GPU kernel——Cuda Basics
c++·gpu·cuda·hpc
scott1985124 天前
NVIDIA GPU内部结构:高性能矩阵乘法内核剖析
线性代数·矩阵·gpu·nvidia·cuda
小烤箱5 天前
Autoware Universe 感知模块详解 | 第十二节 CUDA 编程基础——CUDA执行模型
自动驾驶·cuda·感知