KV Cache for Large-Model Inference

I. Contents

  1. Purpose of kv_cache
  2. Code comparison
  3. GPT-2 multi-head self-attention implementation + kv_cache

II. Implementation

  1. Purpose of kv_cache

    1. kv_cache is used during model inference; training does not need it.
    2. It exists so that a generative model does not have to re-append all previously generated text to the prompt at every inference step: the information computed for earlier tokens is saved once and simply loaded at each later step.
    3. Only the earlier keys and values are cached; the queries are not.
      Reference: https://zhuanlan.zhihu.com/p/630832593
      Rationale: query @ key^T computes, at the current position i, the weight given to each value (i.e. how much each previously generated token contributes). When computing position i, the queries of positions before i are never used again, so they can be dropped (see the sketch after this list).
    4. The token produced at each step is fed directly as the next input, with no concatenation, because the earlier context is already stored in the kv cache.
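
    As a sanity check for point 3, here is a minimal sketch (an illustration with made-up tensor names such as Wq, x_prev, x_new; not code from the article) that computes attention for the newest position twice: once from the full query/key/value matrices, and once from only the latest query row plus cached keys and values. The two results match, which is why only K and V need to be cached:

    import torch

    torch.manual_seed(0)
    seq_len, d = 5, 8
    x_prev = torch.randn(seq_len - 1, d)   # tokens generated so far
    x_new = torch.randn(1, d)              # the newly generated token
    Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

    # Full recomputation: build Q, K, V for the whole sequence.
    x = torch.cat([x_prev, x_new], dim=0)
    Q, K, V = x @ Wq, x @ Wk, x @ Wv
    attn = torch.softmax((Q @ K.T) / d ** 0.5, dim=-1)  # causal mask omitted: the last row is unmasked anyway
    full_out = (attn @ V)[-1]

    # Cached version: keep K, V from earlier steps; compute only the new row of Q, K, V.
    K_cache, V_cache = x_prev @ Wk, x_prev @ Wv          # what kv_cache stores
    q_new = x_new @ Wq                                    # queries of earlier positions are never needed
    K_cat = torch.cat([K_cache, x_new @ Wk], dim=0)
    V_cat = torch.cat([V_cache, x_new @ Wv], dim=0)
    w = torch.softmax((q_new @ K_cat.T) / d ** 0.5, dim=-1)
    cached_out = (w @ V_cat)[0]

    print(torch.allclose(full_out, cached_out, atol=1e-5))  # True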
  2. Code comparison
    Inference without caching past_key_value

    import torch
    from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
    model = GPT2LMHeadModel.from_pretrained("C:/Users/86188/Downloads/gpt2", torchscript=True).eval()

    # tokenizer
    tokenizer = BertTokenizer.from_pretrained("C:/Users/86188/Downloads/gpt2")
    in_text = "白日依山尽"
    in_tokens = torch.tensor(tokenizer.encode(in_text))

    # inference
    token_eos = torch.tensor([198])  # line break symbol
    out_token = None
    i = 0
    with torch.no_grad():
        while out_token != token_eos:
            logits, _ = model(in_tokens)                      # recomputes attention over the full sequence every step
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = torch.cat((in_tokens, out_token), 0)  # previously generated tokens are re-fed as input
            text = tokenizer.decode(in_tokens)
            print(f'step {i} input: {text}', flush=True)
            i += 1
    out_text = tokenizer.decode(in_tokens)
    print(f' Input: {in_text}')
    print(f'Output: {out_text}')

Inference with past_key_value retained

import torch
from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
model = GPT2LMHeadModel.from_pretrained("C:/Users/86188/Downloads/gpt2", torchscript=True).eval()
# tokenizer
tokenizer = BertTokenizer.from_pretrained("C:/Users/86188/Downloads/gpt2")
in_text = "白日依山尽"
in_tokens = torch.tensor(tokenizer.encode(in_text))

# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
kvcache = None
out_text = in_text
i = 0
with torch.no_grad():
    while out_token != token_eos:
        logits, kvcache = model(in_tokens, past_key_values=kvcache)  # pass the cache back in via the past_key_values argument
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = out_token  # the output token alone becomes the next input; no concatenation needed
        text = tokenizer.decode(in_tokens)
        print(f'step {i} input: {text}', flush=True)
        i += 1
        out_text += text
print(f' Input: {in_text}')
print(f'Output: {out_text}')

Advantage: each decoding step only runs attention for the single new query against the cached keys/values instead of recomputing the whole prefix, which reduces computation and speeds up inference.
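
To make the speedup concrete, a rough benchmark sketch like the one below can be used. It is not from the original article: it reuses the same local checkpoint path as above, loads the model without torchscript=True so that generate() can be called directly, and toggles use_cache, which should force full recomputation at every step when disabled. The absolute timings depend on hardware, but the gap grows with the number of generated tokens.

import time
import torch
from transformers import BertTokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("C:/Users/86188/Downloads/gpt2").eval()
tokenizer = BertTokenizer.from_pretrained("C:/Users/86188/Downloads/gpt2")
input_ids = torch.tensor([tokenizer.encode("白日依山尽")])

def timed_generate(use_cache):
    # Greedy decoding; use_cache toggles the kv cache inside the model.
    start = time.time()
    with torch.no_grad():
        model.generate(input_ids, max_new_tokens=100, do_sample=False, use_cache=use_cache)
    return time.time() - start

print(f"with kv_cache:    {timed_generate(True):.2f}s")
print(f"without kv_cache: {timed_generate(False):.2f}s")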

Underlying implementation (the cache handling inside GPT2Attention.forward):

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
    past_key, past_value = layer_past
    key = torch.cat((past_key, key), dim=-2)           # concatenate the cached past_key
    value = torch.cat((past_value, value), dim=-2)     # concatenate the cached past_value

if use_cache is True:
    present = (key, value)
else:
    present = None

if self.reorder_and_upcast_attn:
    attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:  # query: [batch, heads, 1, head_dim]; key, value: [batch, heads, seq_len, head_dim]
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
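
The shape comment above can be traced with a toy sketch (standalone tensors with made-up sizes, not the real module): during cached decoding the query keeps shape [batch, heads, 1, head_dim], while key and value grow by one position along dim=-2 at every step.

import torch

batch, heads, head_dim = 1, 12, 64
prompt_len = 7

# Prefill: keys/values for the whole prompt are computed once and become the cache.
key = torch.randn(batch, heads, prompt_len, head_dim)
value = torch.randn(batch, heads, prompt_len, head_dim)

for step in range(3):
    # Decode step: only the newly generated token enters the layer.
    query = torch.randn(batch, heads, 1, head_dim)
    new_key = torch.randn(batch, heads, 1, head_dim)
    new_value = torch.randn(batch, heads, 1, head_dim)

    # Same concatenation as the layer_past handling above.
    key = torch.cat((key, new_key), dim=-2)
    value = torch.cat((value, new_value), dim=-2)

    attn = torch.softmax(query @ key.transpose(-1, -2) / head_dim ** 0.5, dim=-1)
    out = attn @ value
    print(f"step {step}: query {tuple(query.shape)} key {tuple(key.shape)} out {tuple(out.shape)}")
    # step 0: query (1, 12, 1, 64) key (1, 12, 8, 64) out (1, 12, 1, 64)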
  3. GPT-2 multi-head self-attention implementation + kv_cache

    import torch
    import torch.nn as nn
    from typing import Optional, List, Tuple, Set, Union
    from torch.cuda.amp import autocast

    class Conv1D(nn.Module):
        """
        1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

        Basically works like a linear layer but the weights are transposed.

        Args:
            nf (`int`): The number of output features.
            nx (`int`): The number of input features.
        """

        def __init__(self, nf, nx):
            super().__init__()
            self.nf = nf
            self.weight = nn.Parameter(torch.empty(nx, nf))
            self.bias = nn.Parameter(torch.zeros(nf))
            nn.init.normal_(self.weight, std=0.02)

        def forward(self, x):
            size_out = x.size()[:-1] + (self.nf,)
            x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
            x = x.view(size_out)
            return x

    def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
        """
        Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the
        weights are transposed.

        Used to remove heads.

        Args:
            layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
            index (`torch.LongTensor`): The indices to keep in the layer.
            dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.

        Returns:
            [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
        """
        index = index.to(layer.weight.device)
        W = layer.weight.index_select(dim, index).clone().detach()
        if dim == 0:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
        new_size = list(layer.weight.size())
        new_size[dim] = len(index)
        new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
        new_layer.weight.requires_grad = False
        new_layer.weight.copy_(W.contiguous())
        new_layer.weight.requires_grad = True
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
        return new_layer

    def find_pruneable_heads_and_indices(
        heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
    ) -> Tuple[Set[int], torch.LongTensor]:
        """
        Finds the heads and their indices taking `already_pruned_heads` into account.

        Args:
            heads (`List[int]`): List of the indices of heads to prune.
            n_heads (`int`): The number of heads in the model.
            head_size (`int`): The size of each head.
            already_pruned_heads (`Set[int]`): A set of already pruned heads.

        Returns:
            `Tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking
            `already_pruned_heads` into account and the indices of rows/columns to keep in the layer weight.
        """
        mask = torch.ones(n_heads, head_size)
        heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index: torch.LongTensor = torch.arange(len(mask))[mask].long()
        return heads, index

    class GPT2Attention(nn.Module):
        def __init__(self, config, is_cross_attention=False, layer_idx=None):
            super().__init__()

            max_positions = config.max_position_embeddings
            self.register_buffer(
                "bias",
                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                    1, 1, max_positions, max_positions
                ),
                persistent=False,
            )
            self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

            self.embed_dim = config.hidden_size
            self.num_heads = config.num_attention_heads
            self.head_dim = self.embed_dim // self.num_heads
            self.split_size = self.embed_dim
            if self.head_dim * self.num_heads != self.embed_dim:
                raise ValueError(
                    f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                    f" {self.num_heads})."
                )

            self.scale_attn_weights = config.scale_attn_weights
            self.is_cross_attention = is_cross_attention

            # Layer-wise attention scaling, reordering, and upcasting
            self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
            self.layer_idx = layer_idx
            self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

            if self.is_cross_attention:
                self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
                self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
            else:
                self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
            self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

            self.attn_dropout = nn.Dropout(config.attn_pdrop)
            self.resid_dropout = nn.Dropout(config.resid_pdrop)

            self.pruned_heads = set()

        def _attn(self, query, key, value, attention_mask=None, head_mask=None):
            attn_weights = torch.matmul(query, key.transpose(-1, -2))

            if self.scale_attn_weights:
                attn_weights = attn_weights / torch.full(
                    [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
                )

            # Layer-wise attention scaling
            if self.scale_attn_by_inverse_layer_idx:
                attn_weights = attn_weights / float(self.layer_idx + 1)

            if not self.is_cross_attention:
                # if only "normal" attention layer implements causal mask
                query_length, key_length = query.size(-2), key.size(-2)
                causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
                mask_value = torch.finfo(attn_weights.dtype).min
                # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
                # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
                mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
                attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

            if attention_mask is not None:
                # Apply the attention mask
                attn_weights = attn_weights + attention_mask

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

            # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
            attn_weights = attn_weights.type(value.dtype)
            attn_weights = self.attn_dropout(attn_weights)

            # Mask heads if we want to
            if head_mask is not None:
                attn_weights = attn_weights * head_mask

            attn_output = torch.matmul(attn_weights, value)

            return attn_output, attn_weights

        def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
            # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
            bsz, num_heads, q_seq_len, dk = query.size()
            _, _, k_seq_len, _ = key.size()

            # Preallocate attn_weights for `baddbmm`
            attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

            # Compute Scale Factor
            scale_factor = 1.0
            if self.scale_attn_weights:
                scale_factor /= float(value.size(-1)) ** 0.5

            if self.scale_attn_by_inverse_layer_idx:
                scale_factor /= float(self.layer_idx + 1)

            # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
            with autocast(enabled=False):
                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

            if not self.is_cross_attention:
                # if only "normal" attention layer implements causal mask
                query_length, key_length = query.size(-2), key.size(-2)
                causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
                mask_value = torch.finfo(attn_weights.dtype).min
                # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
                # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
                mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
                attn_weights = torch.where(causal_mask, attn_weights, mask_value)

            if attention_mask is not None:
                # Apply the attention mask
                attn_weights = attn_weights + attention_mask

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

            # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
            if attn_weights.dtype != torch.float32:
                raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
            attn_weights = attn_weights.type(value.dtype)
            attn_weights = self.attn_dropout(attn_weights)

            # Mask heads if we want to
            if head_mask is not None:
                attn_weights = attn_weights * head_mask

            attn_output = torch.matmul(attn_weights, value)

            return attn_output, attn_weights

        def _split_heads(self, tensor, num_heads, attn_head_size):
            """
            Splits hidden_size dim into attn_head_size and num_heads
            """
            new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
            tensor = tensor.view(new_shape)
            return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

        def _merge_heads(self, tensor, num_heads, attn_head_size):
            """
            Merges attn_head_size dim and num_attn_heads dim into hidden_size
            """
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
            new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
            return tensor.view(new_shape)

        def forward(
            self,
            hidden_states: Optional[Tuple[torch.FloatTensor]],
            layer_past: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = False,
            output_attentions: Optional[bool] = False,
        ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
            if encoder_hidden_states is not None:
                if not hasattr(self, "q_attn"):
                    raise ValueError(
                        "If class is used as cross attention, the weights `q_attn` have to be defined. "
                        "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                    )

                query = self.q_attn(hidden_states)
                key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
                attention_mask = encoder_attention_mask
            else:  # hidden_states: [batch, seq_len, hidden]
                query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

            # [batch, heads, seq_len, head_dim]
            query = self._split_heads(query, self.num_heads, self.head_dim)
            key = self._split_heads(key, self.num_heads, self.head_dim)
            value = self._split_heads(value, self.num_heads, self.head_dim)

            if layer_past is not None:
                # kv_cache: prepend the cached keys/values so attention covers the whole prefix
                past_key, past_value = layer_past
                key = torch.cat((past_key, key), dim=-2)
                value = torch.cat((past_value, value), dim=-2)

            if use_cache is True:
                present = (key, value)
            else:
                present = None

            if self.reorder_and_upcast_attn:
                attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
            else:
                attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
            attn_output = self.c_proj(attn_output)
            attn_output = self.resid_dropout(attn_output)

            outputs = (attn_output, present)
            if output_attentions:
                outputs += (attn_weights,)

            return outputs  # a, present, (attentions)

    if __name__ == '__main__':
        class Config:
            max_position_embeddings = 1024
            hidden_size = 768
            num_attention_heads = 12
            scale_attn_weights = True
            is_cross_attention = False
            layer_idx = 1
            scale_attn_by_inverse_layer_idx = False
            reorder_and_upcast_attn = False
            attn_pdrop = 0.1
            resid_pdrop = 0.1

        config = Config()
        hidden_states = torch.randn(size=(1, 1, 768))    # a single new token
        key_past = torch.randn(size=(1, 12, 7, 64))      # cached keys for 7 earlier positions
        value_past = torch.randn(size=(1, 12, 7, 64))    # cached values for 7 earlier positions
        layer_past = (key_past, value_past)

        attention = GPT2Attention(config)
        attention(hidden_states, layer_past)
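
    In this smoke test, hidden_states carries a single new token (sequence length 1) while key_past/value_past cover 7 cached positions, so attention at this step runs over 8 positions in total; passing use_cache=True in the call would additionally return the updated (key, value) pair as present for the next step.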
    