KV Cache in Large Model Inference

I. Contents

  1. What kv_cache is for
  2. Code comparison
  3. GPT-2 multi-head self-attention implementation + kv_cache

II. Implementation

  1. What kv_cache is for

    1. kv_cache is used only during model inference; training does not need it.
    2. It spares a generative model from re-encoding everything it has already produced: instead of concatenating the previously generated text back onto the prompt at every step, the intermediate results are cached and simply loaded during inference.
    3. Only the previous keys and values are saved; queries are not (see the short sketch after this list).
      Reference: https://zhuanlan.zhihu.com/p/630832593
      Rationale: query @ key^T computes, for the current position i, the weight given to each value (i.e. the share of each generated token). When computing position i, the queries of positions 1 to i-1 contribute nothing, so they can be dropped.
    4. The token produced at each step is fed directly as the next input, with no concatenation, because the earlier context is already stored in the kv cache.
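
    A minimal single-head sketch of one decoding step (illustrative only; the sizes and the stand-in weights W_q, W_k, W_v are made up for the example, this is not the GPT-2 source). It shows that the new position needs only its own query, while every earlier key and value comes straight from the cache:

      import torch

      d = 64                                   # head dimension (assumed)
      k_cache = torch.randn(7, d)              # keys of the 7 tokens generated so far
      v_cache = torch.randn(7, d)              # values of the 7 tokens generated so far
      x_new = torch.randn(1, d)                # hidden state of the newest token only
      W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))  # stand-in projection weights

      q_new = x_new @ W_q                      # one query; earlier queries are never reused
      k_cache = torch.cat([k_cache, x_new @ W_k], dim=0)   # cache grows by one key
      v_cache = torch.cat([v_cache, x_new @ W_v], dim=0)   # cache grows by one value

      attn = torch.softmax(q_new @ k_cache.T / d ** 0.5, dim=-1)  # (1, 8) weights over all positions
      out = attn @ v_cache                     # attention output for the new position only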
  2. Code comparison
    Inference without saving past_key_values

    import torch
    from transformers import BertTokenizer, GPT2LMHeadModel,TextGenerationPipeline
    model =GPT2LMHeadModel.from_pretrained("C:/Users/86188/Downloads/gpt2", torchscript=True).eval()

    # tokenizer

    tokenizer = BertTokenizer.from_pretrained("C:/Users/86188/Downloads/gpt2")
    in_text = "白日依山尽"
    in_tokens = torch.tensor(tokenizer.encode(in_text))

    # inference

    token_eos = torch.tensor([198]) # line break symbol
    out_token = None
    i = 0
    with torch.no_grad():
        while out_token != token_eos:
            logits, _ = model(in_tokens)
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = torch.cat((in_tokens, out_token), 0)
            text = tokenizer.decode(in_tokens)
            print(f'step {i} input: {text}', flush=True)
            i += 1

    out_text = tokenizer.decode(in_tokens)
    print(f' Input: {in_text}')
    print(f'Output: {out_text}')

Inference with past_key_values retained

import torch
from transformers import BertTokenizer, GPT2LMHeadModel,TextGenerationPipeline
model =GPT2LMHeadModel.from_pretrained("C:/Users/86188/Downloads/gpt2", torchscript=True).eval()
# tokenizer
tokenizer = BertTokenizer.from_pretrained("C:/Users/86188/Downloads/gpt2")
in_text = "白日依山尽"
in_tokens = torch.tensor(tokenizer.encode(in_text))

# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
kvcache = None
out_text = in_text
i = 0
with torch.no_grad():
    while out_token != token_eos:
        logits, kvcache = model(in_tokens, past_key_values=kvcache)  # pass the cached key/value states
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = out_token  # the output token alone is the next input; no concatenation
        text = tokenizer.decode(in_tokens)
        print(f'step {i} input: {text}', flush=True)
        i += 1
        out_text += text
print(f' Input: {in_text}')
print(f'Output: {out_text}')

Advantage: less redundant computation per step, hence faster inference.
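
The speedup is easy to check empirically. A rough timing sketch follows (not from the original post): it reuses the model and in_tokens loaded above; the generate helper and the 50-step count are made up for illustration, and absolute numbers depend on hardware and prompt length.

import time
import torch

def generate(model, tokens, n_steps, use_cache):
    # hypothetical helper: greedy generation with or without the kv cache
    tokens = tokens.clone()
    step_input = tokens
    kv = None
    with torch.no_grad():
        for _ in range(n_steps):
            if use_cache:
                logits, kv = model(step_input, past_key_values=kv)
            else:
                logits, _ = model(tokens)
            next_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            tokens = torch.cat((tokens, next_token), 0)
            step_input = next_token  # with the cache, only the new token is fed back
    return tokens

for use_cache in (False, True):
    start = time.time()
    generate(model, in_tokens, n_steps=50, use_cache=use_cache)
    print(f"use_cache={use_cache}: {time.time() - start:.2f}s")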

Underlying implementation (inside GPT2Attention.forward):

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
    past_key, past_value = layer_past
    key = torch.cat((past_key, key), dim=-2)           # concatenate cached keys
    value = torch.cat((past_value, value), dim=-2)     # concatenate cached values

if use_cache is True:
    present = (key, value)
else:
    present = None

if self.reorder_and_upcast_attn:
    attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:             # query: [batch, num_heads, 1, head_dim]   key, value: [batch, num_heads, seq_len, head_dim]
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
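
A toy shape check for the concatenation above (sizes assumed for illustration: batch 1, 12 heads, head_dim 64, 7 cached tokens). The torch.cat along dim=-2 is what makes the cache grow by one position for every generated token:

import torch

past_key = torch.randn(1, 12, 7, 64)   # keys cached for the 7 previous tokens
new_key = torch.randn(1, 12, 1, 64)    # key of the single new token in this step
key = torch.cat((past_key, new_key), dim=-2)
print(key.shape)                       # torch.Size([1, 12, 8, 64]) -- seq_len grew by one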
  3. GPT-2 multi-head self-attention implementation + kv_cache

    import torch
    import torch.nn as nn
    from typing import Optional,List,Tuple,Set,Union
    from torch.cuda.amp import autocast

    class Conv1D(nn.Module):
        """
        1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

        Basically works like a linear layer but the weights are transposed.

        Args:
            nf (`int`): The number of output features.
            nx (`int`): The number of input features.
        """

        def __init__(self, nf, nx):
            super().__init__()
            self.nf = nf
            self.weight = nn.Parameter(torch.empty(nx, nf))
            self.bias = nn.Parameter(torch.zeros(nf))
            nn.init.normal_(self.weight, std=0.02)

        def forward(self, x):
            size_out = x.size()[:-1] + (self.nf,)
            x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
            x = x.view(size_out)
            return x
    

    def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
        """
        Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the
        weights are transposed.

        Used to remove heads.

        Args:
            layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
            index (`torch.LongTensor`): The indices to keep in the layer.
            dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.

        Returns:
            [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
        """
        index = index.to(layer.weight.device)
        W = layer.weight.index_select(dim, index).clone().detach()
        if dim == 0:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
        new_size = list(layer.weight.size())
        new_size[dim] = len(index)
        new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
        new_layer.weight.requires_grad = False
        new_layer.weight.copy_(W.contiguous())
        new_layer.weight.requires_grad = True
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
        return new_layer
    

    def find_pruneable_heads_and_indices(
        heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
    ) -> Tuple[Set[int], torch.LongTensor]:
        """
        Finds the heads and their indices taking already_pruned_heads into account.

        Args:
            heads (`List[int]`): List of the indices of heads to prune.
            n_heads (`int`): The number of heads in the model.
            head_size (`int`): The size of each head.
            already_pruned_heads (`Set[int]`): A set of already pruned heads.

        Returns:
            `Tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
            into account and the indices of rows/columns to keep in the layer weight.
        """
        mask = torch.ones(n_heads, head_size)
        heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index: torch.LongTensor = torch.arange(len(mask))[mask].long()
        return heads, index
    

    class GPT2Attention(nn.Module):
      def __init__(self, config, is_cross_attention=False, layer_idx=None):
          super().__init__()

         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
             persistent=False,
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
    
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads})."
             )
    
         self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
    
         # Layer-wise attention scaling, reordering, and upcasting
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx
         self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
    
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
             self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
    
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
    
         self.pruned_heads = set()
    
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
    
         if self.scale_attn_weights:
             attn_weights = attn_weights / torch.full(
                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )
    
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
             attn_weights = attn_weights / float(self.layer_idx + 1)
    
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
             mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
    
         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
    
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
         attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
    
         # Mask heads if we want to
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
    
         attn_output = torch.matmul(attn_weights, value)
    
         return attn_output, attn_weights
    
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
         bsz, num_heads, q_seq_len, dk = query.size()
         _, _, k_seq_len, _ = key.size()
    
         # Preallocate attn_weights for `baddbmm`
         attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
    
         # Compute Scale Factor
         scale_factor = 1.0
         if self.scale_attn_weights:
             scale_factor /= float(value.size(-1)) ** 0.5
    
         if self.scale_attn_by_inverse_layer_idx:
             scale_factor /= float(self.layer_idx + 1)
    
         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
         with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
    
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
             mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights, mask_value)
    
         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
    
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
         if attn_weights.dtype != torch.float32:
             raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
         attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
    
         # Mask heads if we want to
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
    
         attn_output = torch.matmul(attn_weights, value)
    
         return attn_output, attn_weights
    
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
         Splits hidden_size dim into attn_head_size and num_heads
         """
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
         return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    
     def _merge_heads(self, tensor, num_heads, attn_head_size):
         """
         Merges attn_head_size dim and num_attn_heads dim into hidden_size
         """
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
         return tensor.view(new_shape)
    
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                 )
    
             query = self.q_attn(hidden_states)
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
          else:  # hidden_states: [batch, seq_len, hidden]
              query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

          # after _split_heads: [batch, num_heads, seq_len, head_dim]
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
    
         if layer_past is not None:
             past_key, past_value = layer_past
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
    
         if use_cache is True:
             present = (key, value)
         else:
             present = None
    
         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
         else:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
    
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
    
         outputs = (attn_output, present)
         if output_attentions:
             outputs += (attn_weights,)
    
         return outputs  # a, present, (attentions)
    

    if __name__ == '__main__':
        class Config:
            max_position_embeddings = 1024
            hidden_size = 768
            num_attention_heads = 12
            scale_attn_weights = True
            is_cross_attention = False
            layer_idx = 1
            scale_attn_by_inverse_layer_idx = False
            reorder_and_upcast_attn = False
            attn_pdrop = 0.1
            resid_pdrop = 0.1

        config = Config()
        hidden_states = torch.randn(size=(1, 1, 768))     # hidden state of the single new token
        key_past = torch.randn(size=(1, 12, 7, 64))       # cached keys for 7 previous tokens
        value_past = torch.randn(size=(1, 12, 7, 64))     # cached values for 7 previous tokens
        layer_past = (key_past, value_past)

        attention = GPT2Attention(config)
        attention(hidden_states, layer_past)
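
    As a quick sanity check (a sketch, not part of the original code), the following lines can be appended inside the __main__ block above. With use_cache=True the layer returns the attention output for the single new position together with a cache that now covers 8 positions:

        attn_output, present = attention(hidden_states, layer_past, use_cache=True)
        print(attn_output.shape)   # torch.Size([1, 1, 768])
        print(present[0].shape)    # torch.Size([1, 12, 8, 64]) -- cached keys now span 8 tokens
        print(present[1].shape)    # torch.Size([1, 12, 8, 64]) -- cached values now span 8 tokens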
    