从FasterTransformer源码解读开始了解大模型(2.1)代码通读03

从FasterTransformer源码解读开始了解大模型(2.2)代码解读03-forward函数

写在前面的话

本篇的内容继续解读forward函数,从650行开始进行解读

零、输出Context_embeddings和context_cum_log_probs的参数和逻辑

从653行开始,会从输入的请求tensors中读取一个配置,如果请求中配置了is_return_context_embeddings参数并设置为true时,则会在返回参数中增加一个context_embeddings的tensor,这个tensor中包含的数据是输入经过了ContextDecoder过程的所有的层之后的logits,并对其进行求和。可能有些类似于强化学习(RLHF)之类的场景会用到这里的输出,所以在这里做了一层准备。

跳转到1093行,可以看见,这里调用了invokeSumLengthDimension这个kennels,来将context_decoder_output_buf这块buf中的显存数据拷贝到输出tensor的context_embeddings中,可以简单看一下这个kernel的实现,在src/fastertransformer/kernels/gpt_kernels.cu中

c++ 复制代码
template<typename T>
void invokeSumLengthDimension(float*       out_buf,
                              const T*     in_buf,
                              const size_t batch_size,
                              const size_t input_length,
                              const size_t hidden_dim,
                              cudaStream_t stream)
{
    dim3 gridSize(batch_size);
    dim3 blockSize(256);

    sum_length_dimension<<<gridSize, blockSize, 0, stream>>>(out_buf, in_buf, batch_size, input_length, hidden_dim);
}


template<typename T>
__global__ void sum_length_dimension(
    float* out_buf, const T* in_buf, const size_t batch_size, const size_t input_length, const size_t hidden_dim)
{
    const int bidx = blockIdx.x;

    for (int hidx = threadIdx.x; hidx < hidden_dim; hidx += blockDim.x) {
        float accum = 0.0f;
        for (int step = 0; step < input_length; step++) {
            accum += static_cast<float>(in_buf[(bidx * input_length + step) * hidden_dim + hidx]);
        }
        out_buf[bidx * hidden_dim + hidx] = accum;
    }
}

从kernel中可以看出,这个kernel任务划分为按照batch_size维度进行了grid task分配,并为每个任务划分了256个线程。在kernel内部则是将每个batch内所有输入的logits按照输入length维度进行了累加,并拷贝到输出buf中。

类似的一个设置和处理环节,在783行还处理了is_return_context_cum_log_probs,这里会将contextDecoder完成之后的输出进行cum log计算。其中主要的处理逻辑函数是403行的computeContextCumLogProbs函数,进入这个函数后,在425到502行的处理逻辑是,首先将context_decoder的输出进行layernorm,然后使用cublas矩阵乘,将词表维度的logits计算出来。在此之后,在504行,则会使用invokeLogProbFromLogits这个kernel.

c++ 复制代码
template<typename T>
void invokeLogProbFromLogits(float*       cum_log_probs,
                             const T*     logits,
                             const int*   input_ids,
                             const int*   input_lengths,
                             const size_t max_input_length,
                             const size_t batch_size,
                             const size_t vocab_size,
                             const size_t vocab_size_padded,
                             void*        workspace,
                             const size_t workspace_size,
                             cudaStream_t stream,
                             const bool   batch_first)
{
    // A batched version of log prob computation.
    //
    // cum_log_probs: [batch_size]
    // logits: [max_input_length, batch_size, vocab_size] or [batch_size, max_input_length, vocab_size]
    // input_ids: [max_input_length, batch_size] or [max_input_length, batch_size]
    // input_lengths: [batch_size]
    // workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.

    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    // block_size should be multiple of 32 to use warpReduceMax.
    const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;
    assert(block_size % 32 == 0);
    assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);
    assert(vocab_size <= vocab_size_padded);

    float* log_probs = reinterpret_cast<float*>(workspace);
    int    gx        = batch_first ? batch_size : max_input_length - 1;
    int    gy        = batch_first ? max_input_length - 1 : batch_size;
    dim3   grid(gx, gy);
    log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,
                                                         logits,
                                                         input_ids,
                                                         input_lengths,
                                                         max_input_length,
                                                         batch_size,
                                                         vocab_size,
                                                         vocab_size_padded,
                                                         batch_first);
    accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(
        cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}

__global__ void accumulate_log_probs(float*       cum_log_probs,
                                     const float* log_probs,
                                     const int*   lengths,
                                     const size_t max_input_length,
                                     const size_t batch_size,
                                     const bool   batch_first)
{
    // Accumulate the log probability along with the sequence dimension.
    //   cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]]
    //
    // cum_log_probs: [batch_size], cumulative log probability
    // log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],
    //   log probability of each token
    // lengths: [batch_size], sequence lengths
    // batch_size: [1], batch_size. in case of beam > 1, batch x beam.

    int bidx = blockIdx.x;   // batch dim
    int tidx = threadIdx.x;  // step dim

    if (bidx < batch_size) {
        int length = lengths[bidx];
        // reposition logits to data for the current batch.
        log_probs += batch_first ? bidx * (max_input_length - 1) : bidx;
        int   stride      = batch_first ? 1 : batch_size;  // stride along with seq dim.
        float local_accum = 0.0f;
        for (int step = tidx; step < length - 1; step += blockDim.x) {
            local_accum += static_cast<float>(log_probs[step * stride]);
        }
        float accum = blockDim.x <= 32 ? warpReduceSum(local_accum) : blockReduceSum<float>(local_accum);
        if (tidx == 0) {
            cum_log_probs[bidx] = accum;
        }
    }
}

在任务划分阶段,grid的x维度是batch维度,而y维度则是输入长度,进入kernel后,则是按照batch维度对任务的log_probs进行了ReduceSum,并写入到cum_log_probs中。

一、跳过prompt和prefix阶段

回到653行,接下来是ft中处理其自身特殊的prompt和prefix的逻辑,但我们的源码解读可以先跳过这一段。之所以要跳过这一段是在当前主要的大模型处理逻辑中,并不会需要在推理引擎这一层过多地关注prompt和prefix的逻辑。在对话的大模型agent中,对于第n次对话的输入/输出,其真正要进行forward的输入,往往是由前n-1轮模型的输出和用户的输入共同组成的。从推理引擎的角度来看,prefill阶段在真正端到端时延的占比并不算高,所以专门设计一个处理prompt的逻辑(还需要常驻的显存空间)多少显得有些得不偿失了

二、正常的逻辑起始和输入展开

让我们直接来到865行,考虑以下一种场景来简化我们的源码解读的结构:beam_width=1,tp=1,pp=1,use_shared_contexts不启用,也就是说,没有beam search,在单机单卡上进行一次简单的,不共享contexts的推理服务的进行。

从866行开始,由于是计算起始,所以先对一些计算中的buff进行清空。在870行,由于不进行beam search,所以也不需要对beam search用到的buf进行清空。

在877到887行,如果词表大小不对齐的话,需要将词表计算的权重进行对齐拷贝。

在889到905行,不使用shared context的话,这一段的逻辑也不需要进行处理。在914行处理prompt的逻辑也可以跳过,直接进入955行的处理逻辑。在955行,由于我们的输入是带batch的,所以需要将batch进行展开(tiled),这里使用的是kernel invokeTileGptPromptInputs

void invokeTileGptPromptInputs(int*         tiled_input_ids,
                               int*         tiled_input_lengths,
                               int*         tiled_prompt_lengths,
                               const int*   input_ids,
                               const int*   input_lengths,
                               const int*   prefix_prompt_lengths,
                               const int    batch_size,
                               const int    beam_width,
                               const int    max_input_length,
                               cudaStream_t stream)
{
    dim3 grid(batch_size, beam_width);
    dim3 block(min(1024, max_input_length));
    if (prefix_prompt_lengths != nullptr) {
        tileGptPromptInputs<true><<<grid, block, 0, stream>>>(tiled_input_ids,
                                                              tiled_input_lengths,
                                                              tiled_prompt_lengths,
                                                              input_ids,
                                                              input_lengths,
                                                              prefix_prompt_lengths,
                                                              max_input_length);
    }
    else {
        tileGptPromptInputs<false><<<grid, block, 0, stream>>>(tiled_input_ids,
                                                               tiled_input_lengths,
                                                               tiled_prompt_lengths,
                                                               input_ids,
                                                               input_lengths,
                                                               prefix_prompt_lengths,
                                                               max_input_length);
    }
}

template<bool PREFIX_PROMPT>
__global__ void tileGptPromptInputs(int*       tiled_input_ids,
                                    int*       tiled_input_lengths,
                                    int*       tiled_prompt_lengths,
                                    const int* input_ids,
                                    const int* input_lengths,
                                    const int* prefix_prompt_lengths,
                                    const int  max_input_length)
{
    if (threadIdx.x == 0) {
        tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x];
        if (PREFIX_PROMPT) {
            tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x];
        }
    }
    for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) {
        tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] =
            input_ids[blockIdx.x * max_input_length + index];
    }
}

在任务划分阶段,按照batch维度进行划分,每个任务起了至少1024个线程来进行拷贝。由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再处理

在kernel中,由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再按照batch维度进行处理,将一个batch*max_input_length的数据进行展开,并拷贝到对应的buf,这有利于我们进行接下来的后续embedding和MHA计算

下一回预告

下一回继续讲解forward函数中的处理逻辑,会简单讲解embedding, pre_layernorm,以及进入attention_layer之后的初步讲解

相关推荐
2401_8904167115 分钟前
Recaptcha2 图像怎么识别
人工智能·python·django
机器之心39 分钟前
贾佳亚团队联合Adobe提出GenProp,物体追踪移除特效样样在行
人工智能
一叶_障目1 小时前
机器学习之决策树(DecisionTree——C4.5)
人工智能·决策树·机器学习
思码逸研发效能1 小时前
在 DevOps 实践中,如何构建自动化的持续集成和持续交付(CI/CD)管道,以提高开发和测试效率?
运维·人工智能·ci/cd·自动化·研发效能·devops·效能度量
AI量化投资实验室2 小时前
deap系统重构,再新增一个新的因子,年化39.1%,卡玛提升至2.76(附python代码)
大数据·人工智能·重构
张登杰踩2 小时前
如何快速下载Huggingface上的超大模型,不用梯子,以Deepseek-R1为例子
人工智能
AIGC大时代2 小时前
分享14分数据分析相关ChatGPT提示词
人工智能·chatgpt·数据分析
TMT星球3 小时前
生数科技携手央视新闻《文博日历》,推动AI视频技术的创新应用
大数据·人工智能·科技
AI视觉网奇3 小时前
图生3d算法学习笔记
人工智能
小锋学长生活大爆炸3 小时前
【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式
人工智能·pytorch·深度学习·图神经网络·gnn·dgl