从FasterTransformer源码解读开始了解大模型(2.2)代码解读03-forward函数
写在前面的话
本篇的内容继续解读forward函数,从650行开始进行解读
零、输出Context_embeddings和context_cum_log_probs的参数和逻辑
从653行开始,会从输入的请求tensors中读取一个配置,如果请求中配置了is_return_context_embeddings参数并设置为true时,则会在返回参数中增加一个context_embeddings的tensor,这个tensor中包含的数据是输入经过了ContextDecoder过程的所有的层之后的logits,并对其进行求和。可能有些类似于强化学习(RLHF)之类的场景会用到这里的输出,所以在这里做了一层准备。
跳转到1093行,可以看见,这里调用了invokeSumLengthDimension这个kennels,来将context_decoder_output_buf这块buf中的显存数据拷贝到输出tensor的context_embeddings中,可以简单看一下这个kernel的实现,在src/fastertransformer/kernels/gpt_kernels.cu中
c++
template<typename T>
void invokeSumLengthDimension(float* out_buf,
const T* in_buf,
const size_t batch_size,
const size_t input_length,
const size_t hidden_dim,
cudaStream_t stream)
{
dim3 gridSize(batch_size);
dim3 blockSize(256);
sum_length_dimension<<<gridSize, blockSize, 0, stream>>>(out_buf, in_buf, batch_size, input_length, hidden_dim);
}
template<typename T>
__global__ void sum_length_dimension(
float* out_buf, const T* in_buf, const size_t batch_size, const size_t input_length, const size_t hidden_dim)
{
const int bidx = blockIdx.x;
for (int hidx = threadIdx.x; hidx < hidden_dim; hidx += blockDim.x) {
float accum = 0.0f;
for (int step = 0; step < input_length; step++) {
accum += static_cast<float>(in_buf[(bidx * input_length + step) * hidden_dim + hidx]);
}
out_buf[bidx * hidden_dim + hidx] = accum;
}
}
从kernel中可以看出,这个kernel任务划分为按照batch_size维度进行了grid task分配,并为每个任务划分了256个线程。在kernel内部则是将每个batch内所有输入的logits按照输入length维度进行了累加,并拷贝到输出buf中。
类似的一个设置和处理环节,在783行还处理了is_return_context_cum_log_probs,这里会将contextDecoder完成之后的输出进行cum log计算。其中主要的处理逻辑函数是403行的computeContextCumLogProbs函数,进入这个函数后,在425到502行的处理逻辑是,首先将context_decoder的输出进行layernorm,然后使用cublas矩阵乘,将词表维度的logits计算出来。在此之后,在504行,则会使用invokeLogProbFromLogits这个kernel.
c++
template<typename T>
void invokeLogProbFromLogits(float* cum_log_probs,
const T* logits,
const int* input_ids,
const int* input_lengths,
const size_t max_input_length,
const size_t batch_size,
const size_t vocab_size,
const size_t vocab_size_padded,
void* workspace,
const size_t workspace_size,
cudaStream_t stream,
const bool batch_first)
{
// A batched version of log prob computation.
//
// cum_log_probs: [batch_size]
// logits: [max_input_length, batch_size, vocab_size] or [batch_size, max_input_length, vocab_size]
// input_ids: [max_input_length, batch_size] or [max_input_length, batch_size]
// input_lengths: [batch_size]
// workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
// block_size should be multiple of 32 to use warpReduceMax.
const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;
assert(block_size % 32 == 0);
assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);
assert(vocab_size <= vocab_size_padded);
float* log_probs = reinterpret_cast<float*>(workspace);
int gx = batch_first ? batch_size : max_input_length - 1;
int gy = batch_first ? max_input_length - 1 : batch_size;
dim3 grid(gx, gy);
log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,
logits,
input_ids,
input_lengths,
max_input_length,
batch_size,
vocab_size,
vocab_size_padded,
batch_first);
accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(
cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}
__global__ void accumulate_log_probs(float* cum_log_probs,
const float* log_probs,
const int* lengths,
const size_t max_input_length,
const size_t batch_size,
const bool batch_first)
{
// Accumulate the log probability along with the sequence dimension.
// cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]]
//
// cum_log_probs: [batch_size], cumulative log probability
// log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],
// log probability of each token
// lengths: [batch_size], sequence lengths
// batch_size: [1], batch_size. in case of beam > 1, batch x beam.
int bidx = blockIdx.x; // batch dim
int tidx = threadIdx.x; // step dim
if (bidx < batch_size) {
int length = lengths[bidx];
// reposition logits to data for the current batch.
log_probs += batch_first ? bidx * (max_input_length - 1) : bidx;
int stride = batch_first ? 1 : batch_size; // stride along with seq dim.
float local_accum = 0.0f;
for (int step = tidx; step < length - 1; step += blockDim.x) {
local_accum += static_cast<float>(log_probs[step * stride]);
}
float accum = blockDim.x <= 32 ? warpReduceSum(local_accum) : blockReduceSum<float>(local_accum);
if (tidx == 0) {
cum_log_probs[bidx] = accum;
}
}
}
在任务划分阶段,grid的x维度是batch维度,而y维度则是输入长度,进入kernel后,则是按照batch维度对任务的log_probs进行了ReduceSum,并写入到cum_log_probs中。
一、跳过prompt和prefix阶段
回到653行,接下来是ft中处理其自身特殊的prompt和prefix的逻辑,但我们的源码解读可以先跳过这一段。之所以要跳过这一段是在当前主要的大模型处理逻辑中,并不会需要在推理引擎这一层过多地关注prompt和prefix的逻辑。在对话的大模型agent中,对于第n次对话的输入/输出,其真正要进行forward的输入,往往是由前n-1轮模型的输出和用户的输入共同组成的。从推理引擎的角度来看,prefill阶段在真正端到端时延的占比并不算高,所以专门设计一个处理prompt的逻辑(还需要常驻的显存空间)多少显得有些得不偿失了
二、正常的逻辑起始和输入展开
让我们直接来到865行,考虑以下一种场景来简化我们的源码解读的结构:beam_width=1,tp=1,pp=1,use_shared_contexts不启用,也就是说,没有beam search,在单机单卡上进行一次简单的,不共享contexts的推理服务的进行。
从866行开始,由于是计算起始,所以先对一些计算中的buff进行清空。在870行,由于不进行beam search,所以也不需要对beam search用到的buf进行清空。
在877到887行,如果词表大小不对齐的话,需要将词表计算的权重进行对齐拷贝。
在889到905行,不使用shared context的话,这一段的逻辑也不需要进行处理。在914行处理prompt的逻辑也可以跳过,直接进入955行的处理逻辑。在955行,由于我们的输入是带batch的,所以需要将batch进行展开(tiled),这里使用的是kernel invokeTileGptPromptInputs
void invokeTileGptPromptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
int* tiled_prompt_lengths,
const int* input_ids,
const int* input_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream)
{
dim3 grid(batch_size, beam_width);
dim3 block(min(1024, max_input_length));
if (prefix_prompt_lengths != nullptr) {
tileGptPromptInputs<true><<<grid, block, 0, stream>>>(tiled_input_ids,
tiled_input_lengths,
tiled_prompt_lengths,
input_ids,
input_lengths,
prefix_prompt_lengths,
max_input_length);
}
else {
tileGptPromptInputs<false><<<grid, block, 0, stream>>>(tiled_input_ids,
tiled_input_lengths,
tiled_prompt_lengths,
input_ids,
input_lengths,
prefix_prompt_lengths,
max_input_length);
}
}
template<bool PREFIX_PROMPT>
__global__ void tileGptPromptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
int* tiled_prompt_lengths,
const int* input_ids,
const int* input_lengths,
const int* prefix_prompt_lengths,
const int max_input_length)
{
if (threadIdx.x == 0) {
tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x];
if (PREFIX_PROMPT) {
tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x];
}
}
for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) {
tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] =
input_ids[blockIdx.x * max_input_length + index];
}
}
在任务划分阶段,按照batch维度进行划分,每个任务起了至少1024个线程来进行拷贝。由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再处理
在kernel中,由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再按照batch维度进行处理,将一个batch*max_input_length的数据进行展开,并拷贝到对应的buf,这有利于我们进行接下来的后续embedding和MHA计算
下一回预告
下一回继续讲解forward函数中的处理逻辑,会简单讲解embedding, pre_layernorm,以及进入attention_layer之后的初步讲解