KV Cache for Large Model Inference

I. Contents

  1. What the kv_cache is for
  2. Code comparison
  3. GPT-2 multi-head self-attention implementation + kv_cache

II. Implementation

  1. What the kv_cache is for

    1. The kv_cache is used only during model inference; training does not need it.
    2. It spares a generative model from re-concatenating all previously generated text onto the prompt at every decoding step: the intermediate results are saved once and simply loaded again at each subsequent step.
    3. Only the previous keys and values are cached; the queries are not.
      Reference: https://zhuanlan.zhihu.com/p/630832593
      Principle: query @ key^T computes, for the current position i, the weights over the values (i.e. how much each generated token contributes). When computing position i, the queries of positions 0..i-1 are never used again, so they can be discarded; see the sketch after this list.
    4. The token produced at each step is fed directly as the next input, with no concatenation, because the earlier context is already stored in the kv_cache.
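      A minimal sketch of the principle in point 3 (toy head dimension, and the q/k/v projections collapsed to the identity purely for illustration): the newest position's attention output needs only its own query row together with the cached keys/values, so past queries never have to be stored.

        import torch
        import torch.nn.functional as F

        d = 8                                   # toy head dimension (assumption for illustration)
        K_cache = torch.randn(5, d)             # keys of the 5 tokens generated so far
        V_cache = torch.randn(5, d)             # values of the 5 tokens generated so far

        x_new = torch.randn(d)                  # hidden state of the current token only
        q_new = k_new = v_new = x_new           # identity "projections", just for the sketch

        K = torch.cat([K_cache, k_new[None]], dim=0)   # cache grows to (6, d)
        V = torch.cat([V_cache, v_new[None]], dim=0)

        scores = q_new @ K.T / d ** 0.5         # (6,) -- only the current query row is needed
        out = F.softmax(scores, dim=-1) @ V     # (d,) -- same as the last row of a full forward pass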
  2. Code comparison
    Inference without caching past_key_values

    import torch
    from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
    model = GPT2LMHeadModel.from_pretrained("C:/Users/86188/Downloads/gpt2", torchscript=True).eval()

    # tokenizer
    tokenizer = BertTokenizer.from_pretrained("C:/Users/86188/Downloads/gpt2")
    in_text = "白日依山尽"
    in_tokens = torch.tensor(tokenizer.encode(in_text))

    # inference
    token_eos = torch.tensor([198])  # line break symbol
    out_token = None
    i = 0
    with torch.no_grad():
        while out_token != token_eos:
            logits, _ = model(in_tokens)                       # the whole sequence is re-encoded at every step
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = torch.cat((in_tokens, out_token), 0)   # generated token is appended to the input
            text = tokenizer.decode(in_tokens)
            print(f'step {i} input: {text}', flush=True)
            i += 1

    out_text = tokenizer.decode(in_tokens)
    print(f' Input: {in_text}')
    print(f'Output: {out_text}')

    Inference with past_key_values cached

    import torch
    from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
    model = GPT2LMHeadModel.from_pretrained("C:/Users/86188/Downloads/gpt2", torchscript=True).eval()

    # tokenizer
    tokenizer = BertTokenizer.from_pretrained("C:/Users/86188/Downloads/gpt2")
    in_text = "白日依山尽"
    in_tokens = torch.tensor(tokenizer.encode(in_text))

    # inference
    token_eos = torch.tensor([198])  # line break symbol
    out_token = None
    kvcache = None
    out_text = in_text
    i = 0
    with torch.no_grad():
        while out_token != token_eos:
            logits, kvcache = model(in_tokens, past_key_values=kvcache)  # pass the cache back in as past_key_values
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = out_token  # the newly generated token alone is the next input; no concatenation
            text = tokenizer.decode(in_tokens)
            print(f'step {i} input: {text}', flush=True)
            i += 1
            out_text += text

    print(f' Input: {in_text}')
    print(f'Output: {out_text}')

Advantage: each step only attends the new token's query against the cached keys/values instead of recomputing attention over the whole sequence, which reduces computation and speeds up inference.
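A quick way to see the effect on wall-clock time is to compare `generate()` with and without the cache. This is only a rough sketch: it assumes the standard `gpt2` checkpoint from the Hugging Face Hub (swap in the local path and BertTokenizer used above if you prefer), and the absolute numbers will vary by machine.

    import time
    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("hello world", return_tensors="pt").input_ids

    with torch.no_grad():
        for use_cache in (False, True):
            start = time.time()
            model.generate(input_ids, max_new_tokens=64, do_sample=False, use_cache=use_cache)
            print(f"use_cache={use_cache}: {time.time() - start:.2f}s")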

Under the hood (excerpt from the GPT2Attention forward pass, shown in full in the next section):

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
    past_key, past_value = layer_past
    key = torch.cat((past_key, key), dim=-2)           # append the new key to the cached keys
    value = torch.cat((past_value, value), dim=-2)     # append the new value to the cached values

if use_cache is True:
    present = (key, value)
else:
    present = None

if self.reorder_and_upcast_attn:
    attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:  # query: [batch, num_heads, 1, head_dim]; key, value: [batch, num_heads, seq_len, head_dim]
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
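The shapes noted in the comment above can be checked with a small standalone sketch (sizes chosen to match the demo at the end of this post): with a cache of t past tokens, only the current token's key/value are computed and appended along `dim=-2`, while the query stays at length 1.

    import torch

    batch, num_heads, t, head_dim = 1, 12, 7, 64
    past_key = torch.randn(batch, num_heads, t, head_dim)
    past_value = torch.randn(batch, num_heads, t, head_dim)

    key_new = torch.randn(batch, num_heads, 1, head_dim)     # current step only
    value_new = torch.randn(batch, num_heads, 1, head_dim)
    query = torch.randn(batch, num_heads, 1, head_dim)       # the query is never cached

    key = torch.cat((past_key, key_new), dim=-2)             # -> (1, 12, 8, 64)
    value = torch.cat((past_value, value_new), dim=-2)       # -> (1, 12, 8, 64)
    print(query.shape, key.shape, value.shape)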
  3. GPT-2 multi-head self-attention implementation + kv_cache

    import torch
    import torch.nn as nn
    from typing import Optional, List, Tuple, Set, Union
    from torch.cuda.amp import autocast

    class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

     Basically works like a linear layer but the weights are transposed.
    
     Args:
         nf (`int`): The number of output features.
         nx (`int`): The number of input features.
     """
    
     def __init__(self, nf, nx):
         super().__init__()
         self.nf = nf
         self.weight = nn.Parameter(torch.empty(nx, nf))
         self.bias = nn.Parameter(torch.zeros(nf))
         nn.init.normal_(self.weight, std=0.02)
    
     def forward(self, x):
         size_out = x.size()[:-1] + (self.nf,)
         x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
         x = x.view(size_out)
         return x
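    # Note: Conv1D(nf, nx) behaves like nn.Linear(nx, nf) with the weight stored transposed,
    # i.e. forward computes y = x @ W + b with W of shape (nx, nf). GPT-2 uses it for the
    # attention projections, e.g. Conv1D(3 * 768, 768) maps a (batch, seq_len, 768) hidden
    # state to the concatenated query/key/value of shape (batch, seq_len, 3 * 768).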
    

    def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
    """
    Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
    are transposed.

     Used to remove heads.
    
     Args:
         layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
         index (`torch.LongTensor`): The indices to keep in the layer.
         dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
    
     Returns:
         [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
     """
     index = index.to(layer.weight.device)
     W = layer.weight.index_select(dim, index).clone().detach()
     if dim == 0:
         b = layer.bias.clone().detach()
     else:
         b = layer.bias[index].clone().detach()
     new_size = list(layer.weight.size())
     new_size[dim] = len(index)
     new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
     new_layer.weight.requires_grad = False
     new_layer.weight.copy_(W.contiguous())
     new_layer.weight.requires_grad = True
     new_layer.bias.requires_grad = False
     new_layer.bias.copy_(b.contiguous())
     new_layer.bias.requires_grad = True
     return new_layer
    

    def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
    ) -> Tuple[Set[int], torch.LongTensor]:
    """
    Finds the heads and their indices taking already_pruned_heads into account.

     Args:
         heads (`List[int]`): List of the indices of heads to prune.
         n_heads (`int`): The number of heads in the model.
         head_size (`int`): The size of each head.
         already_pruned_heads (`Set[int]`): A set of already pruned heads.
    
     Returns:
         `Tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
         into account and the indices of rows/columns to keep in the layer weight.
     """
     mask = torch.ones(n_heads, head_size)
     heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
     for head in heads:
         # Compute how many pruned heads are before the head and move the index accordingly
         head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
         mask[head] = 0
     mask = mask.view(-1).contiguous().eq(1)
     index: torch.LongTensor = torch.arange(len(mask))[mask].long()
     return heads, index
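     # Example (hypothetical numbers): with n_heads=12, head_size=64 and nothing pruned yet,
     # find_pruneable_heads_and_indices([0, 2], 12, 64, set()) returns the head set {0, 2}
     # plus the 10 * 64 = 640 surviving weight indices, which can then be passed to
     # prune_conv1d_layer.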
    

    class GPT2Attention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()

         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
             persistent=False,
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
    
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads})."
             )
    
         self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
    
         # Layer-wise attention scaling, reordering, and upcasting
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx
         self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
    
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
             self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
    
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
    
         self.pruned_heads = set()
    
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
    
         if self.scale_attn_weights:
             attn_weights = attn_weights / torch.full(
                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )
    
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
             attn_weights = attn_weights / float(self.layer_idx + 1)
    
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
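             # With a kv_cache present, query_length is 1 while key_length is the full cached
             # length, so this slice keeps only the last row(s) of the causal mask -- the new
             # token is allowed to attend to every cached position plus itself.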
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
             mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
    
         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
    
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
         attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
    
         # Mask heads if we want to
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
    
         attn_output = torch.matmul(attn_weights, value)
    
         return attn_output, attn_weights
    
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
         bsz, num_heads, q_seq_len, dk = query.size()
         _, _, k_seq_len, _ = key.size()
    
         # Preallocate attn_weights for `baddbmm`
         attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
    
         # Compute Scale Factor
         scale_factor = 1.0
         if self.scale_attn_weights:
             scale_factor /= float(value.size(-1)) ** 0.5
    
         if self.scale_attn_by_inverse_layer_idx:
             scale_factor /= float(self.layer_idx + 1)
    
         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
         with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
    
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
             mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights, mask_value)
    
         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
    
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
         if attn_weights.dtype != torch.float32:
             raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
         attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
    
         # Mask heads if we want to
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
    
         attn_output = torch.matmul(attn_weights, value)
    
         return attn_output, attn_weights
    
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
         Splits hidden_size dim into attn_head_size and num_heads
         """
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
         return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    
     def _merge_heads(self, tensor, num_heads, attn_head_size):
         """
         Merges attn_head_size dim and num_attn_heads dim into hidden_size
         """
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
         return tensor.view(new_shape)
    
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                 )
    
             query = self.q_attn(hidden_states)
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
         else:  # [batch, seq_len, hidden_size]
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
         
         # [batch, num_heads, seq_len, head_dim]
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
    
         if layer_past is not None:
             past_key, past_value = layer_past
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
    
         if use_cache is True:
             present = (key, value)
         else:
             present = None
    
         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
         else:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
    
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
    
         outputs = (attn_output, present)
         if output_attentions:
             outputs += (attn_weights,)
    
         return outputs  # a, present, (attentions)
    

    if __name__ == '__main__':
        class Config:
            max_position_embeddings = 1024
            hidden_size = 768
            num_attention_heads = 12
            scale_attn_weights = True
            is_cross_attention = False
            layer_idx = 1
            scale_attn_by_inverse_layer_idx = False
            reorder_and_upcast_attn = False
            attn_pdrop = 0.1
            resid_pdrop = 0.1

        config = Config()
        hidden_states = torch.randn(size=(1, 1, 768))    # current token only: seq_len = 1
        key_past = torch.randn(size=(1, 12, 7, 64))      # cached keys of 7 previous tokens
        value_past = torch.randn(size=(1, 12, 7, 64))    # cached values of 7 previous tokens
        layer_past = (key_past, value_past)

        attention = GPT2Attention(config)
        attention(hidden_states, layer_past=layer_past)
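        # With use_cache=True the layer also returns `present`, the updated (key, value)
        # cache, whose sequence length has grown from the 7 cached tokens to 8:
        attn_output, present = attention(hidden_states, layer_past=layer_past, use_cache=True)
        print(attn_output.shape)   # torch.Size([1, 1, 768])
        print(present[0].shape)    # torch.Size([1, 12, 8, 64])  keys
        print(present[1].shape)    # torch.Size([1, 12, 8, 64])  values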
    