解析 Transformers 的 KV 缓存机制
本文由丘山子翻译,原文链接:medium.com/@joaolages/... ,原文作者:João Lages。
如何通过缓存键和值状态来提升 Transformers 的速度
缓存生成型 Transformers 的键(K)和值(V)状态这个机制早已存在一段时间了,但也许你需要确切了解它到底是什么,以及它所带来的极大推理速度提升。
如下图所示,键(K)和值(V)状态主要用于计算缩放点积注意力(scaled dot-product attention)。
缩放点积注意力及其在 Transformers 架构中的应用。(图片来源:lilianweng.github.io/posts/2018-...)
KV缓存(KV caching)发生在多个词元(tokens)的生成步骤中,仅在解码器(decoder)中发生(例如,在仅有解码器的模型(如GPT)中,或者在编码器-解码器模型(ncoder-decoder models)(如T5)中的解码器部分)。像BERT这样的模型不是生成模型,因此没有KV缓存。
解码器按照自回归(auto-regressive)的方式工作,就像这个GPT-2文本生成示例所示。
在解码器的自回归生成过程中,给定输入后,模型会预测下一个词元(token),然后在下一步中综合输入进行下一步预测。(图片来源:jalammar.github.io/illustrated...)。
这种自回归行为会重复一些操作,我们可以通过放大解码器中计算的经过掩码的缩放点积注意力机制来更好地理解这一点。(译者注:在解码器中,为了避免模型在生成每个令牌时看到未来的信息,需要对注意力机制进行掩码操作。这意味着在计算注意力权重时,将未来的位置掩盖起来,使得模型只能关注当前及之前的位置。)
解码器中缩放点积注意力机制的逐步可视化展示。emb_size表示嵌入大小。图片由原文作者创建。
由于是因果解码器(即,一个词元(token)的注意力只取决于其前面的词元),因此在每个生成步骤中,我们都要重新计算先前词元的注意力,而实际上我们只想计算新词元的注意力。
这就是KV(键值)缓存发挥作用的地方。通过缓存先前的键矩阵和值矩阵,我们可以只专注于计算新词元的注意力。
在每个生成步骤中,我们将先前的键和值存储在KV缓存中。当需要计算新词元的注意力时,我们只需要使用新词元的查询矩阵与KV缓存中的键矩阵进行计算,而无需重新计算先前词元的注意力。这样可以大大减少计算量,提高生成效率。
通过使用KV缓存,我们可以在解码器的自回归生成过程中更高效地计算注意力,只关注于新词元的注意力计算,而不必重复计算先前词元的注意力。
有 KV 缓存和无 KV 缓存的缩放点积注意力机制比较。 emb_size 表示嵌入大小。图片由原文作者创建。
为什么这一优化非常重要?如上图所示,使用 KV 缓存获得的矩阵更小,从而加快了矩阵乘法的速度。唯一的缺点是需要更多的 GPU VRAM(或 CPU RAM,如果不使用 GPU的话)来缓存键和值状态。
让我们使用 Transformers🤗来比较有 KV 缓存和没有 KV 缓存的 GPT-2 的生成速度。
python
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
for use_cache in (True, False):
times = []
for _ in range(10): # measuring 10 generations
start = time.time()
model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)
times.append(time.time() - start)
print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")
在使用 Tesla T4 GPU 的 Google Colab notebook上,生成 1000 个新词元的平均时间和标准差如下所示:
- 使用 KV 缓存:11.885 ± 0.272 秒
- 不使用 KV 缓存:56.197 ± 1.855 秒
推理速度的差异非常大,而 GPU VRAM 的使用几乎可以忽略不计,如此报告所述。因此请确保在您的 Transformer 模型中使用 KV 缓存!
不过值得庆幸的是,这是 transformers 🤗 的默认装备。