Hugging Face Transformers Cache System: An In-Depth Analysis
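Analyzed version: Transformers v5.8.0.dev0
Scope: cache_utils.py, the complete KV cache architecture, including the layer abstraction, the registry, and the dynamic, static, quantized, and linear-attention caches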
Contents
- Module Overview
- CacheLayerMixin: Cache Layer Abstract Base Class
- LAYER_TYPE_CACHE_MAPPING: Layer Type Registry
- Dynamic Cache Layers
- Static Cache Layers (Compilable)
- Quantized Cache Layers
- Linear Attention Cache Layers
- Cache: Cache Container Base Class
- DynamicCache: Dynamic Cache
- StaticCache: Static Cache
- QuantizedCache: Quantized Cache
- EncoderDecoderCache: Encoder-Decoder Cache
- Design Principles and Architecture Summary
- Relationship to Other Modules
Cache system architecture overview (diagram): the registry (LAYER_TYPE_CACHE_MAPPING); the cache layer hierarchy (CacheLayerMixin ABC with DynamicLayer, DynamicSlidingWindowLayer, QuantizedLayer, QuantoQuantizedLayer, HQQQuantizedLayer, StaticLayer, and StaticSlidingWindowLayer; LinearAttentionCacheLayerMixin ABC with LinearAttentionLayer); and the cache container hierarchy (Cache base class with DynamicCache, StaticCache, QuantizedCache, and EncoderDecoderCache).
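1. Module Overview
File path: src/transformers/cache_utils.py (about 1,623 lines)
Module responsibilities
The cache system is core infrastructure for inference in Transformers. During autoregressive generation it stores and reuses Key/Value states so that they are not recomputed at every step. The module implements:
- Layered cache architecture: each model layer owns its own cache layer object, which allows heterogeneous layer types (full attention, sliding window, linear attention, and so on)
- Dynamic and static modes: dynamic caches grow on demand, while static caches preallocate fixed-size tensors to support torch.compile
- Quantized KV cache: supports the Quanto and HQQ backends, compressing KV states to 2 or 4 bits (HQQ also accepts other bit widths, see section 6.3)
- Linear attention cache: stores conv/recurrent states for non-Transformer architectures such as Mamba/SSM
- Encoder-decoder cache: combines self-attention and cross-attention caches to support seq2seq models
- GPU offloading and prefetching: per-layer CPU offloading, combined with an asynchronous stream to overlap computation and transfer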
Class hierarchy
CacheLayerMixin (ABC)
├── DynamicLayer
│   ├── DynamicSlidingWindowLayer
│   └── QuantizedLayer
│       ├── QuantoQuantizedLayer
│       └── HQQQuantizedLayer
└── StaticLayer
    └── StaticSlidingWindowLayer

LinearAttentionCacheLayerMixin (ABC)
├── LinearAttentionLayer
└── (via multiple inheritance)
    LinearAttentionAndFullAttentionLayer
      = LinearAttentionLayer + DynamicLayer

Cache (container base class)
├── DynamicCache
├── StaticCache
├── QuantizedCache
└── EncoderDecoderCache (subclasses Cache and wraps two Cache objects by composition)

SlidingWindowCache = StaticCache (deprecated alias in v5)
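2. CacheLayerMixin: Cache Layer Abstract Base Class
Source location: [cache_utils.py:37-107](file:///workspace/src/transformers/cache_utils.py#L37-L107)
Design intent
CacheLayerMixin is the abstract base class for all attention cache layers and defines the minimal interface contract of a single-layer cache. It is written as a mixin rather than as the root of a strict single-inheritance chain, so that its subclasses (such as DynamicLayer) can be combined with other bases (such as LinearAttentionLayer) through multiple inheritance.
Core attributes
python
class CacheLayerMixin(ABC):
    is_compileable = False          # whether torch.compile is supported (True for static caches)
    layer_type: str | None = None   # key under which the class is auto-registered in LAYER_TYPE_CACHE_MAPPING

    def __init__(self):
        self.keys: torch.Tensor | None = None    # Key state tensor
        self.values: torch.Tensor | None = None  # Value state tensor
        self.is_initialized = False              # lazy-initialization flag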
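Auto-registration mechanism: __init_subclass__
python
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    layer_type = cls.__dict__.get("layer_type", None)  # only what the class itself defines, not inherited values
    if layer_type is not None:
        LAYER_TYPE_CACHE_MAPPING[layer_type] = cls      # registered automatically at import time
Key design point: using cls.__dict__.get() instead of getattr() means a class is registered only when it explicitly defines layer_type in its own body. A subclass that merely inherits layer_type from its parent therefore does not overwrite the parent's registry entry.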
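The difference between `cls.__dict__.get()` and `getattr()` can be illustrated with a small stand-alone sketch. The names `REGISTRY`, `BaseLayer`, `FullLayer`, and `FullLayerVariant` are hypothetical stand-ins chosen for this example; they are not names from cache_utils.py.

```python
REGISTRY = {}  # hypothetical stand-in for LAYER_TYPE_CACHE_MAPPING


class BaseLayer:
    layer_type = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Look only at attributes defined in this class body, not inherited ones.
        layer_type = cls.__dict__.get("layer_type", None)
        if layer_type is not None:
            REGISTRY[layer_type] = cls


class FullLayer(BaseLayer):
    layer_type = "full"  # defined explicitly, so the class is registered


class FullLayerVariant(FullLayer):
    pass  # only inherits layer_type, so it does not replace FullLayer in the registry


registered = REGISTRY["full"]
inherited_value = getattr(FullLayerVariant, "layer_type")
```

Here `getattr` still sees the inherited value `"full"` on `FullLayerVariant`, so a `getattr`-based check would have re-registered the key and pointed it at the variant class.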
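Abstract methods
| Method | Return type | Description |
|---|---|---|
| lazy_initialization(key, value) | None | Initializes the cache when the first real tensors arrive |
| update(key, value, *args, **kwargs) | tuple[Tensor, Tensor] | Updates the cache and returns the full KV states |
| get_mask_sizes(query_length) | tuple[int, int] | Returns (kv_length, kv_offset), used to build the attention mask |
| get_seq_length() | int | Returns the cached sequence length |
| get_max_cache_shape() | int | Returns the maximum cache capacity (dynamic caches return -1) |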
Common methods
python
def offload(self):
    """Offload the cached data to CPU to save GPU memory."""
    if self.is_initialized:
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)

def prefetch(self):
    """Prefetch from CPU back to GPU; with an asynchronous stream this overlaps compute and transfer."""
    if self.is_initialized and self.keys.device != self.device:
        self.keys = self.keys.to(self.device, non_blocking=True)
        self.values = self.values.to(self.device, non_blocking=True)

def reset(self):
    """Reset the cached values but keep the object (avoids reallocating memory)."""
    if self.is_initialized:
        self.keys.zero_()
        self.values.zero_()
    if hasattr(self, "cumulative_length"):
        if isinstance(self.cumulative_length, int):
            self.cumulative_length = 0
        else:
            self.cumulative_length.zero_()  # tensor form, used for torch.compile

def reorder_cache(self, beam_idx):
    """Reorder the cache by beam index during beam search."""
    if self.get_seq_length() > 0:
        self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
        self.values = self.values.index_select(0, beam_idx.to(self.values.device))
3. LAYER_TYPE_CACHE_MAPPING: Layer Type Registry
Source location: [cache_utils.py:26-34](file:///workspace/src/transformers/cache_utils.py#L26-L34) and [871-887](file:///workspace/src/transformers/cache_utils.py#L871-L887)
Design intent
The registry maps each config.layer_types[i] string to the matching cache layer class. DynamicCache and StaticCache can therefore dispatch to the right cache layer type from the model config alone, without a dedicated Cache subclass per model.
Registry contents
python
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}
# Automatic registration: a subclass that defines layer_type = "xxx" is added through __init_subclass__
# Manual registration: the standard layer types are registered here in bulk
LAYER_TYPE_CACHE_MAPPING.update({
    "full_attention": DynamicLayer,                  # full attention (default)
    "sliding_attention": DynamicSlidingWindowLayer,  # sliding-window attention
    "chunked_attention": DynamicSlidingWindowLayer,  # chunked attention (caches like a sliding window)
    "mamba": LinearAttentionLayer,                   # Mamba SSM layer
    "conv": LinearAttentionLayer,                    # convolution layer
    "linear_attention": LinearAttentionLayer,        # pure linear attention
    "moe": LinearAttentionLayer,                     # MoE placeholder (no KV cache needed)
    "hybrid": LinearAttentionAndFullAttentionLayer,  # hybrid layer (e.g. Zamba)
})
Dual registration mechanism
- Automatic registration: when a subclass defines layer_type = "xxx", CacheLayerMixin.__init_subclass__ adds it to the registry
- Manual registration: when several layer_type values share one cache class (for example, "sliding_attention" and "chunked_attention" both map to DynamicSlidingWindowLayer), the entries are added through update()
Usage flow
Model config: config.layer_types = ["full_attention", "sliding_attention", "mamba", ...]
↓
DynamicCache.__init__ iterates over layer_types
↓
LAYER_TYPE_CACHE_MAPPING["full_attention"] → DynamicLayer(config)
LAYER_TYPE_CACHE_MAPPING["sliding_attention"] → DynamicSlidingWindowLayer(config)
LAYER_TYPE_CACHE_MAPPING["mamba"] → LinearAttentionLayer(config)
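The dispatch step can be sketched in a few lines. This is only an illustration under simplifying assumptions: the three classes, `MAPPING`, and `build_layers` are hypothetical stand-ins, the layers take no constructor arguments, and unknown layer types fall back to the full-attention class.

```python
class FullAttentionLayer:
    """Hypothetical stand-in for DynamicLayer."""


class SlidingLayer:
    """Hypothetical stand-in for DynamicSlidingWindowLayer."""


class StateLayer:
    """Hypothetical stand-in for LinearAttentionLayer."""


# Several keys may share one class, which is why manual registration exists.
MAPPING = {
    "full_attention": FullAttentionLayer,
    "sliding_attention": SlidingLayer,
    "chunked_attention": SlidingLayer,
    "mamba": StateLayer,
}


def build_layers(layer_types):
    # One cache layer object per model layer; unknown types use the default class.
    return [MAPPING.get(name, FullAttentionLayer)() for name in layer_types]


layers = build_layers(["full_attention", "chunked_attention", "mamba", "unknown"])
names = [type(layer).__name__ for layer in layers]
```

The container never needs to know which concrete layer classes exist; it only looks them up by key.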
4. Dynamic Cache Layers
4.1 DynamicLayer: full-attention dynamic layer
Source location: [cache_utils.py:109-188](file:///workspace/src/transformers/cache_utils.py#L109-L188)
The defining property of a dynamic layer is that the cache grows with each generation step: every update appends the new tokens' KV states with torch.cat.
python
class DynamicLayer(CacheLayerMixin):
    is_sliding = False

    def lazy_initialization(self, key_states, value_states):
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)  # empty placeholder tensor
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(self, key_states, value_states, *args, **kwargs):
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
        # Core step: append the new KV states, shape [batch, heads, seq_len, head_dim]
        self.keys = torch.cat([self.keys, key_states], dim=-2)
        self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values
Key design points:
- dim=-2 is the sequence-length dimension; torch.cat appends along it
- The return value is the full KV state (history plus new tokens), ready for the attention computation
- get_max_cache_shape() returns -1, meaning there is no upper bound
- Helper methods such as crop(), batch_repeat_interleave(), and batch_select_indices() are provided
4.2 DynamicSlidingWindowLayer: sliding-window dynamic layer
Source location: [cache_utils.py:190-275](file:///workspace/src/transformers/cache_utils.py#L190-L275)
On top of dynamic growth, the sliding-window layer keeps only the most recent sliding_window - 1 tokens, which bounds the cache size for long sequences.
python
class DynamicSlidingWindowLayer(DynamicLayer):
    is_sliding = True

    def __init__(self, config=None, sliding_window=None):
        super().__init__()
        if sliding_window is None:
            if config is None:
                raise ValueError("Either `config` or `sliding_window` must be provided.")
            sliding_window = getattr(config, "sliding_window", None) or \
                getattr(config, "attention_chunk_size", None)
        self.sliding_window = sliding_window
        self.cumulative_length = 0  # total tokens processed so far (including discarded ones)

    def update(self, key_states, value_states, *args, **kwargs):
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
        self.cumulative_length += key_states.shape[-2]
        # 1. Concatenate first to get the full states
        full_key_states = torch.cat([self.keys, key_states], dim=-2)
        full_value_states = torch.cat([self.values, value_states], dim=-2)
        # 2. Cache only the last sliding_window - 1 tokens
        self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :]
        self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]
        # 3. Return the full states (attention needs the complete context)
        return full_key_states, full_value_states
Core design:
- cumulative_length tracks the total number of tokens processed, including those already discarded
- The cache keeps only sliding_window - 1 tokens (minus 1 because the current query token itself occupies one position in the window)
- update returns the full concatenated KV states rather than the truncated ones, because the attention computation for the current step must see every token inside the window
- get_mask_sizes() computes the mask offset depending on whether cumulative_length has reached sliding_window
python
def get_mask_sizes(self, query_length):
    is_full = self.cumulative_length >= self.sliding_window
    kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
    if is_full:
        kv_length = self.sliding_window - 1 + query_length
    else:
        kv_length = self.cumulative_length + query_length
    return kv_length, kv_offset
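A worked example makes the "keep sliding_window - 1, return everything" rule concrete. The sketch below replaces tensors with plain Python lists of token positions; `sliding_update` is a hypothetical helper, and `get_mask_sizes` is rewritten as a free function that mirrors the arithmetic shown above. It assumes sliding_window >= 2.

```python
def sliding_update(cache, new_tokens, sliding_window):
    """Return (tokens kept in the cache, tokens handed to attention)."""
    full = cache + new_tokens
    # Keep the last sliding_window - 1 tokens; the next query fills the final slot.
    kept = full[-(sliding_window - 1):]
    return kept, full


def get_mask_sizes(cumulative_length, sliding_window, query_length):
    is_full = cumulative_length >= sliding_window
    kv_offset = max(cumulative_length - sliding_window + 1, 0)
    if is_full:
        kv_length = sliding_window - 1 + query_length
    else:
        kv_length = cumulative_length + query_length
    return kv_length, kv_offset


window = 4
# Prefill with 6 tokens at positions 0..5: attention sees all 6, the cache keeps 3.
cache, prefill_states = sliding_update([], list(range(6)), window)
# One decode step: 6 tokens processed so far, 1 new query token.
mask_sizes = get_mask_sizes(cumulative_length=6, sliding_window=window, query_length=1)
cache, decode_states = sliding_update(cache, [6], window)
```

After the prefill the cache holds positions 3, 4, and 5. The decode step then attends over positions 3 to 6, which matches the computed kv_length of 4 and kv_offset of 3.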
5. Static Cache Layers (Compilable)
5.1 StaticLayer: full-attention static layer
Source location: [cache_utils.py:277-385](file:///workspace/src/transformers/cache_utils.py#L277-L385)
The static layer is designed for torch.compile. It preallocates fixed-size tensors and updates them in place, which avoids dynamic memory allocation and keeps the compiled graph structure stable.
python
class StaticLayer(CacheLayerMixin):
    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len
        # Important: cumulative_length must be a tensor to avoid recompilation
        self.cumulative_length = torch.tensor([0], dtype=int)

    def lazy_initialization(self, key_states, value_states):
        self.dtype, self.device = key_states.dtype, key_states.device
        self.max_batch_size, self.num_heads = key_states.shape[:2]
        self.v_head_dim = value_states.shape[-1]
        self.k_head_dim = key_states.shape[-1]
        # Preallocate full-size zero tensors
        self.keys = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.k_head_dim),
            dtype=self.dtype, device=self.device,
        )
        self.values = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.v_head_dim),
            dtype=self.dtype, device=self.device,
        )
        self.cumulative_length = self.cumulative_length.to(self.device)
        # Mark static addresses so in-place updates do not cause cudagraph skips or graph breaks
        if not is_torchdynamo_compiling():
            torch._dynamo.mark_static_address(self.keys)
            torch._dynamo.mark_static_address(self.values)
            torch._dynamo.mark_static_address(self.cumulative_length)
        self.is_initialized = True
Role of mark_static_address:
With torch.compile, and with CUDA graphs in particular, a tensor that is mutated in place across calls (for example by index_copy_) has to live at a stable memory address. mark_static_address tells Dynamo that the tensor's address will not change between calls, so the tensor can be treated as a static input and the in-place updates do not cause graph breaks or force CUDA graphs to be skipped.
python
def update(self, key_states, value_states, *args, **kwargs):
    if not self.is_initialized:
        self.lazy_initialization(key_states, value_states)
    kv_length = key_states.shape[-2]
    cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
    # Update the cumulative length in place (keeps the static address)
    self.cumulative_length.add_(kv_length)
    # Write the new KV states in place with index_copy_
    try:
        self.keys.index_copy_(2, cache_position, key_states)
        self.values.index_copy_(2, cache_position, value_states)
    except NotImplementedError:
        # Devices such as MPS may not support index_copy_; fall back to indexed assignment
        self.keys[:, :, cache_position] = key_states
        self.values[:, :, cache_position] = value_states
    return self.keys, self.values
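The preallocate-then-write-in-place idea can be sketched without tensors. `ToyStaticLayer` below is a hypothetical list-based stand-in: the list plays the role of the preallocated tensor, and item assignment plays the role of index_copy_.

```python
class ToyStaticLayer:
    """Hypothetical list-based stand-in for a preallocated cache layer."""

    def __init__(self, max_cache_len):
        self.buffer = [0] * max_cache_len  # allocated once, never replaced
        self.cumulative_length = 0

    def update(self, new_tokens):
        start = self.cumulative_length
        for offset, token in enumerate(new_tokens):
            self.buffer[start + offset] = token  # in-place write at the next free positions
        self.cumulative_length += len(new_tokens)
        return self.buffer  # always full size, including still-unused zero slots


layer = ToyStaticLayer(max_cache_len=6)
buffer_id = id(layer.buffer)
layer.update([11, 12, 13])   # prefill with three tokens
out = layer.update([14])     # one decode step
```

Because the returned buffer always has max_cache_len entries, the unused tail has to be hidden from attention, which is the job of the mask sizes reported by get_mask_sizes().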
5.2 StaticSlidingWindowLayer: sliding-window static layer
Source location: [cache_utils.py:387-512](file:///workspace/src/transformers/cache_utils.py#L387-L512)
This is the most complex cache layer: it has to implement the "rolling" behavior of a sliding window on a static tensor while staying compile-friendly.
python
class StaticSlidingWindowLayer(StaticLayer):
    is_sliding = True

    def __init__(self, max_cache_len, sliding_window):
        effective_max_cache_len = min(sliding_window, max_cache_len)
        super().__init__(max_cache_len=effective_max_cache_len)
        # Track the length as a Python int to avoid data-dependent control flow and recompilation
        self.cumulative_length_int = 0
The four branches of update:
python
def update(self, key_states, value_states, *args, **kwargs):
    if not self.is_initialized:
        self.lazy_initialization(key_states, value_states)
    kv_length = key_states.shape[-2]
    current_length = self.cumulative_length_int
    is_full = current_length >= self.max_cache_len
    self.cumulative_length_int += kv_length
    if is_full:
        if key_states.shape[-2] == 1:
            # Branch 1: full + single-token decoding, uses roll + assignment
            new_keys = self.keys.roll(-1, dims=-2)
            new_values = self.values.roll(-1, dims=-2)
            index = torch.tensor([-1], dtype=int, device=self.device)
            new_keys[:, :, index] = key_states
            new_values[:, :, index] = value_states
            self.keys.copy_(new_keys)  # copy_ keeps the static address
            self.values.copy_(new_values)
            return self.keys, self.values
        else:
            # Branch 2: full + multiple tokens (e.g. prefill caching, chat continuation)
            full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
    elif current_length + kv_length > self.max_cache_len:
        # Branch 3: about to overflow, truncation needed
        if current_length == 0:
            full_key_states = key_states  # fast path: the cache is empty
            full_value_states = value_states
        else:
            full_key_states = torch.cat((self.keys[:, :, :current_length, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, :current_length, :], value_states), dim=-2)
    else:
        # Branch 4: not full, standard in-place write with index_copy_
        cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
        self.keys.index_copy_(2, cache_position, key_states)
        self.values.index_copy_(2, cache_position, value_states)
        self.cumulative_length.add_(kv_length)  # keep the tensor version in sync
        return self.keys, self.values
    # Shared post-processing for branches 2/3: truncate to max_cache_len and copy_ back into the static tensors
    self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
    self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
    return full_key_states, full_value_states
Design notes:
- The roll trick: for single-token decoding, roll(-1) shifts every element one position to the left and the last position is then overwritten, so no concatenation is needed and the cache keeps its fixed shape
- copy_ instead of assignment: self.keys.copy_(new_keys) keeps the static address unchanged, whereas self.keys = new_keys would rebind the attribute to a tensor at a different address and invalidate cudagraphs
- Dual length tracking: cumulative_length_int (Python int) drives the control flow, while cumulative_length (Tensor) is used to compute positions for index_copy_; the two must be kept in sync
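The roll-and-overwrite step of branch 1 can be mimicked with a plain list. `roll_left` and `full_window_decode_step` are hypothetical helpers; slice assignment (`buffer[:] = ...`) stands in for copy_, since it changes the contents of the existing list rather than rebinding the name.

```python
def roll_left(items, shift=1):
    """List equivalent of tensor.roll(-shift) along the sequence dimension."""
    return items[shift:] + items[:shift]


def full_window_decode_step(buffer, new_token):
    rolled = roll_left(buffer, 1)  # the oldest entry wraps around to the last slot
    rolled[-1] = new_token         # and is overwritten by the new token
    buffer[:] = rolled             # copy back in place, so the buffer object is unchanged
    return buffer


window = [10, 11, 12, 13]          # a full window of four cached tokens
window_id = id(window)
full_window_decode_step(window, 14)
```

The buffer keeps its identity and length while its contents advance by one token, which is the property the static layer relies on.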
6. Quantized Cache Layers
6.1 QuantizedLayer: quantization base class
Source location: [cache_utils.py:514-588](file:///workspace/src/transformers/cache_utils.py#L514-L588)
Following the approach of the KIVI paper, the quantized cache maintains two stores: an original-precision cache and a quantized cache. When the original-precision cache reaches residual_length, it is quantized and merged into the quantized cache.
python
class QuantizedLayer(DynamicLayer):
    def __init__(self, nbits=4, axis_key=0, axis_value=0,
                 q_group_size=64, residual_length=128):
        super().__init__()
        self.nbits = nbits
        self.axis_key = axis_key
        self.axis_value = axis_value
        self.q_group_size = q_group_size
        self.residual_length = residual_length
        self.cumulative_length = 0

    def update(self, key_states, value_states, *args, **kwargs):
        self.cumulative_length += key_states.shape[-2]
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
            # First call: quantize directly
            self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
            return key_states, value_states
        # Dequantize the stored quantized cache
        dequant_keys = self._dequantize(self._quantized_keys)
        dequant_values = self._dequantize(self._quantized_values)
        # Concatenate: quantized part + original-precision part + new part
        keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
        values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)
        # Once the original-precision part reaches residual_length, quantize everything
        if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
            self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
            self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
        else:
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        return keys_to_return, values_to_return

    @abstractmethod
    def _quantize(self, tensor, axis): ...

    @abstractmethod
    def _dequantize(self, q_tensor): ...
Quantize/dequantize flow:
New KV → appended to the original-precision cache → residual_length reached?
├── Yes: quantize everything → clear the original-precision cache
└── No: keep it in the original-precision cache
On read: dequantized part + original-precision part + new part → full KV returned
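The two-store bookkeeping can be followed step by step with a toy model. Everything here is a hypothetical illustration: lists of floats replace tensors, and `toy_quantize` (rounding to multiples of 0.5) is a deliberately crude stand-in for a real backend such as Quanto or HQQ. Only the control flow mirrors the update method shown above.

```python
def toy_quantize(values, step=0.5):
    """Hypothetical lossy quantizer: store each value as an integer multiple of `step`."""
    return [round(v / step) for v in values]


def toy_dequantize(codes, step=0.5):
    return [c * step for c in codes]


class ToyQuantizedLayer:
    def __init__(self, residual_length):
        self.residual_length = residual_length
        self.quantized = None  # compressed store
        self.residual = []     # recent tokens kept at original precision

    def update(self, new_values):
        if self.quantized is None:
            # First call: quantize the prompt, but return it at original precision.
            self.quantized = toy_quantize(new_values)
            return list(new_values)
        to_return = toy_dequantize(self.quantized) + self.residual + new_values
        if len(self.residual) + 1 >= self.residual_length:
            # Residual buffer is full: re-quantize everything and clear it.
            self.quantized = toy_quantize(to_return)
            self.residual = []
        else:
            self.residual = self.residual + new_values
        return to_return


layer = ToyQuantizedLayer(residual_length=3)
first = layer.update([0.26, 1.1])  # prompt
second = layer.update([0.7])       # residual grows to 1
third = layer.update([0.9])        # residual grows to 2
fourth = layer.update([0.3])       # threshold reached: everything is quantized
```

Note that the prompt values come back rounded (0.5 and 1.0) from the second call onward, while recent tokens stay exact until the residual buffer is flushed.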
6.2 QuantoQuantizedLayer: Quanto backend
Source location: [cache_utils.py:590-643](file:///workspace/src/transformers/cache_utils.py#L590-L643)
python
class QuantoQuantizedLayer(QuantizedLayer):
    def __init__(self, nbits=4, axis_key=0, axis_value=0, q_group_size=64, residual_length=128):
        super().__init__(nbits, axis_key, axis_value, q_group_size, residual_length)
        # Requires optimum-quanto >= 0.2.5
        from optimum.quanto import MaxOptimizer, qint2, qint4
        self.qtype = qint4 if self.nbits == 4 else qint2  # only 2/4 bits are supported
        self.optimizer = MaxOptimizer()  # the only optimizer for per-channel quantization

    def _quantize(self, tensor, axis):
        from optimum.quanto import quantize_weight
        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
        return qtensor

    def _dequantize(self, qtensor):
        return qtensor.dequantize()
6.3 HQQQuantizedLayer: HQQ backend
Source location: [cache_utils.py:645-699](file:///workspace/src/transformers/cache_utils.py#L645-L699)
python
class HQQQuantizedLayer(QuantizedLayer):
    def __init__(self, nbits=4, axis_key=0, axis_value=0, q_group_size=64, residual_length=128):
        super().__init__(nbits, axis_key, axis_value, q_group_size, residual_length)
        # Supports 1/2/3/4/8 bits
        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        qtensor, meta = self.quantizer.quantize(
            tensor, axis=axis, device=self.keys.device,
            compute_dtype=self.keys.dtype, nbits=self.nbits, group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.keys.dtype
        self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device)
        return qtensor, meta  # returns a tuple (quantized tensor + metadata)

    def _dequantize(self, qtensor):
        quant_tensor, meta = qtensor  # unpack the tuple
        return self.quantizer.dequantize(quant_tensor, meta)
7. Linear Attention Cache Layers
7.1 LinearAttentionCacheLayerMixin: abstract base class for linear attention layers
Source location: [cache_utils.py:702-761](file:///workspace/src/transformers/cache_utils.py#L702-L761)
Linear attention layers (Mamba/SSM and similar) do not use a conventional KV cache. Instead they maintain conv_states (convolution states) and recurrent_states (recurrent states).
python
class LinearAttentionCacheLayerMixin(ABC):
    is_compileable = True  # the state shapes of linear attention are static by nature

    def __init__(self):
        self.conv_states: torch.Tensor | None = None
        self.recurrent_states: torch.Tensor | None = None
        self.is_conv_states_initialized = False
        self.is_recurrent_states_initialized = False
        self.has_previous_state = False  # whether a previous state exists (affects the conv update logic)

    @abstractmethod
    def lazy_initialization(self, conv_states=None, recurrent_states=None): ...

    @abstractmethod
    def update_conv_state(self, conv_states): ...

    @abstractmethod
    def update_recurrent_state(self, recurrent_states): ...
7.2 LinearAttentionLayer: linear attention layer implementation
Source location: [cache_utils.py:764-838](file:///workspace/src/transformers/cache_utils.py#L764-L838)
python
class LinearAttentionLayer(LinearAttentionCacheLayerMixin):
    def update_conv_state(self, conv_states, **kwargs):
        if not self.is_conv_states_initialized:
            self.lazy_initialization(conv_states=conv_states)
        if not self.has_previous_state:
            # First call: copy directly (copy_ keeps the static address)
            self.conv_states.copy_(conv_states)
            self.has_previous_state = True
        else:
            num_new_tokens = conv_states.shape[-1]
            if num_new_tokens >= self.conv_kernel_size:
                # At least as many new tokens as the kernel size: replace directly
                self.conv_states.copy_(conv_states[..., -self.conv_kernel_size:])
            else:
                # Roll: shift left to make room, then fill the tail with the new states
                new_conv_states = self.conv_states.roll(shifts=-num_new_tokens, dims=-1)
                new_conv_states[:, :, -num_new_tokens:] = conv_states
                self.conv_states.copy_(new_conv_states)
        return self.conv_states

    def update_recurrent_state(self, recurrent_states, **kwargs):
        if not self.is_recurrent_states_initialized:
            self.lazy_initialization(recurrent_states=recurrent_states)
        # The recurrent state is simply overwritten
        self.recurrent_states.copy_(recurrent_states)
        return self.recurrent_states
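The two paths of the convolution-state update (replace versus roll) can be sketched on a one-dimensional list. `update_conv_state` below is a hypothetical free function that covers only the case where a previous state already exists; the list holds the last kernel_size inputs, oldest first.

```python
def update_conv_state(state, new_columns, kernel_size):
    """Update a fixed-size list holding the last `kernel_size` inputs, in place."""
    n = len(new_columns)
    if n == 0:
        return state                           # nothing to add
    if n >= kernel_size:
        state[:] = new_columns[-kernel_size:]  # the new input alone fills the window
    else:
        rolled = state[n:] + state[:n]         # roll left by n positions
        rolled[-n:] = new_columns              # overwrite the wrapped-around tail
        state[:] = rolled                      # in-place copy; the list object is kept
    return state


state = [1, 2, 3, 4]
update_conv_state(state, [5], kernel_size=4)                 # single-token decode step
after_decode = list(state)
update_conv_state(state, [6, 7, 8, 9, 10], kernel_size=4)    # long chunk replaces the window
```

In both paths the state keeps its fixed length, which is why this kind of cache is a natural fit for compilation.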
7.3 LinearAttentionAndFullAttentionLayer: hybrid layer
Source location: [cache_utils.py:840-865](file:///workspace/src/transformers/cache_utils.py#L840-L865)
Through multiple inheritance this class provides both a linear attention cache and a full attention cache. It is used by hybrid architectures such as Zamba.
python
class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer):
    is_compileable = False  # the DynamicLayer part makes it non-compilable

    def lazy_initialization(self, *args, **kwargs):
        # Dispatch to the matching parent initializer based on how the arguments are passed
        if len(args) == 2 and len(kwargs) == 0:
            DynamicLayer.lazy_initialization(self, *args)  # KV cache initialization
        if len(args) == 0 and len(kwargs) == 1:
            LinearAttentionLayer.lazy_initialization(self, **kwargs)  # conv/recurrent initialization

    def reset(self):
        LinearAttentionLayer.reset(self)
        DynamicLayer.reset(self)

    def reorder_cache(self, beam_idx):
        LinearAttentionLayer.reorder_cache(self, beam_idx)
        DynamicLayer.reorder_cache(self, beam_idx)
8. Cache: Cache Container Base Class
Source location: [cache_utils.py:890-1227](file:///workspace/src/transformers/cache_utils.py#L890-L1227)
Design intent
Cache is the base class of all cache containers. It is essentially a list of CacheLayerMixin objects with a uniform interface for cross-layer operations. It follows the composition pattern: the container holds the layer objects and delegates operations to them.
Construction
python
class Cache:
    def __init__(self, layers=None, layer_class_to_replicate=None,
                 offloading=False, offload_only_non_sliding=True):
        # Exactly one of two construction modes:
        # 1. Pass a pre-built list of layers
        # 2. Pass layer_class_to_replicate and create layers lazily on demand
        if layers is not None and layer_class_to_replicate is not None:
            raise ValueError("...")
        if layers is None and layer_class_to_replicate is None:
            raise ValueError("...")
        self.layers = layers if layers is not None else []
        self.layer_class_to_replicate = layer_class_to_replicate
        self.offloading = offloading
        if self.offloading:
            self.only_non_sliding = offload_only_non_sliding
            self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 \
                else torch.cuda.Stream()
The core update method
python
def update(self, key_states, value_states, layer_idx, *args, **kwargs):
    # Lazy creation: if layer_idx is beyond the current layers list, append new layers
    if self.layer_class_to_replicate is not None:
        while len(self.layers) <= layer_idx:
            self.layers.append(self.layer_class_to_replicate())
    # GPU offloading: wait for the prefetch stream, then prefetch the next layer
    if self.offloading:
        torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
        self.prefetch(layer_idx + 1, self.only_non_sliding)
    # Delegate to the concrete layer
    keys, values = self.layers[layer_idx].update(key_states, value_states, *args, **kwargs)
    # GPU offloading: the current layer is done, offload it to CPU
    if self.offloading:
        self.offload(layer_idx, self.only_non_sliding)
    return keys, values
Pipeline design of GPU offloading:
Timeline:
Layer 0 compute ──→ offload Layer 0 to CPU ──→
prefetch Layer 1 to GPU ──→ Layer 1 compute ──→ offload Layer 1 ──→ ...
- prefetch uses a non-default stream and does not block the main compute stream
- offload runs on the default stream, which ensures the layer's update computation has finished
- Sliding-window layers are usually small and are not offloaded by default (offload_only_non_sliding=True)
Methods specific to linear attention
python
def update_conv_state(self, conv_states, layer_idx, **kwargs):
    if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
        raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
    return self.layers[layer_idx].update_conv_state(conv_states, **kwargs)

def update_recurrent_state(self, recurrent_states, layer_idx, **kwargs):
    if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
        raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
    return self.layers[layer_idx].update_recurrent_state(recurrent_states, **kwargs)
Early initialization
python
def early_initialization(self, batch_size, num_heads, head_dim, dtype, device):
    """Used for torch.export: initializes all layers before the first forward pass."""
    if isinstance(num_heads, int):
        num_heads = [num_heads] * len(self)
    if isinstance(head_dim, int):
        head_dim = [head_dim] * len(self)
    for layer, layer_num_heads, layer_head_dim in zip(self.layers, num_heads, head_dim):
        # Trigger initialization with a fake tensor of sequence length 0 (holds no actual data)
        fake_kv_tensor = torch.zeros(
            (batch_size, layer_num_heads, 0, layer_head_dim), dtype=dtype, device=device
        )
        layer.lazy_initialization(fake_kv_tensor, fake_kv_tensor)
Key properties
python
@property
def is_compileable(self) -> bool:
    if len(self.layers) == 0:
        return False
    return all(layer.is_compileable for layer in self.layers)

@property
def is_sliding(self) -> list[bool]:
    return [getattr(layer, "is_sliding", False) for layer in self.layers]

@property
def max_batch_size(self) -> int:
    values = [layer.max_batch_size for layer in self.layers]
    if len(set(values)) > 1:
        raise ValueError(f"Max batch size is not consistent across layers: {values}")
    return values[0]

@property
def max_cache_len(self) -> int:
    values = [layer.max_cache_len for layer in self.layers]
    return max(values)
9. DynamicCache: Dynamic Cache
Source location: [cache_utils.py:1229-1334](file:///workspace/src/transformers/cache_utils.py#L1229-L1334)
Design intent
DynamicCache is the most commonly used cache class and suits the vast majority of generative models. It picks the cache type of each layer automatically from the model config.
Construction logic
python
class DynamicCache(Cache):
    def __init__(self, ddp_cache_data=None, config=None,
                 offloading=False, offload_only_non_sliding=False):
        layers = []
        # Path 1: build heterogeneous layers from the config
        if config is not None:
            decoder_config = config.get_text_config(decoder=True)
            sliding_window = getattr(decoder_config, "sliding_window", None) or \
                getattr(decoder_config, "attention_chunk_size", None)
            layer_types = getattr(decoder_config, "layer_types", None)
            if layer_types is None:
                # No explicit layer_types: infer them from sliding_window
                layer_types = []
                for _ in range(decoder_config.num_hidden_layers):
                    if sliding_window is not None:
                        layer_types.append("sliding_attention")
                    else:
                        layer_types.append("full_attention")
            # Skip layers that share KV (e.g. Gemma3n)
            if hasattr(decoder_config, "num_kv_shared_layers"):
                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
            for layer_type in layer_types:
                cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer)
                layers.append(cache_cls(decoder_config))
        # Path 2: populate from DDP data
        if ddp_cache_data is not None:
            for layer_idx, kv_and_optional_sliding in enumerate(ddp_cache_data):
                if config is None:
                    sliding_window_tensor = kv_and_optional_sliding[2] \
                        if len(kv_and_optional_sliding) == 3 else None
                    if sliding_window_tensor is not None:
                        layers.append(DynamicSlidingWindowLayer(
                            sliding_window=sliding_window_tensor[0].item()))
                    else:
                        layers.append(DynamicLayer())
                _, _ = layers[layer_idx].update(
                    kv_and_optional_sliding[0], kv_and_optional_sliding[1])
        # Path 3: no config and no data, create DynamicLayer objects lazily
        if len(layers) == 0:
            super().__init__(layer_class_to_replicate=DynamicLayer, ...)
        else:
            super().__init__(layers=layers, ...)
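The layer-type inference of path 1 is simple enough to restate as a pure function. `infer_layer_types` is a hypothetical helper written for this illustration; it takes plain arguments instead of a config object. The explicit truthiness guard on num_kv_shared_layers is an addition of this sketch, not something the excerpt above shows.

```python
def infer_layer_types(num_hidden_layers, layer_types=None,
                      sliding_window=None, num_kv_shared_layers=None):
    if layer_types is None:
        # No explicit list: every layer gets the same type, chosen by sliding_window.
        kind = "sliding_attention" if sliding_window is not None else "full_attention"
        layer_types = [kind] * num_hidden_layers
    if num_kv_shared_layers:
        # Trailing layers that share KV need no cache of their own.
        # Guard against 0, because layer_types[:-0] would be an empty list.
        layer_types = layer_types[:-num_kv_shared_layers]
    return list(layer_types)


plain = infer_layer_types(3)
sliding = infer_layer_types(3, sliding_window=4096)
shared = infer_layer_types(
    4,
    layer_types=["full_attention", "sliding_attention", "full_attention", "sliding_attention"],
    num_kv_shared_layers=2,
)
```

Each resulting string is then looked up in LAYER_TYPE_CACHE_MAPPING to pick the cache layer class for that position.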
Iterator
python
def __iter__(self):
    for layer in self.layers:
        yield layer.keys, layer.values, getattr(layer, "_sliding_window_tensor", None)
Each element is a (keys, values, sliding_window_tensor) triple; sliding_window_tensor may be None.
10. StaticCache: Static Cache
Source location: [cache_utils.py:1336-1421](file:///workspace/src/transformers/cache_utils.py#L1336-L1421)
Design intent
StaticCache is designed for torch.compile and torch.export. Its attention layers use statically preallocated tensors.
Construction logic
python
class StaticCache(Cache):
    def __init__(self, config, max_cache_len, offloading=False,
                 offload_only_non_sliding=True, **kwargs):
        config = config.get_text_config(decoder=True)
        layer_types = getattr(config, "layer_types", None)
        if layer_types is None:
            if getattr(config, "sliding_window", None) is not None:
                layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
            elif getattr(config, "attention_chunk_size", None) is not None:
                layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
            else:
                layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
        if hasattr(config, "num_kv_shared_layers"):
            layer_types = layer_types[: -config.num_kv_shared_layers]
        # Identify the sliding-window-like layer types
        sliding_layer_types = {
            name for name, cls in LAYER_TYPE_CACHE_MAPPING.items()
            if isinstance(cls, type) and issubclass(cls, DynamicSlidingWindowLayer)
            and name != "chunked_attention"
        }
        layers = []
        for layer_type in layer_types:
            if layer_type == "chunked_attention":
                layer = StaticSlidingWindowLayer(
                    max_cache_len=max_cache_len,
                    sliding_window=config.attention_chunk_size)
            elif layer_type in sliding_layer_types:
                layer = StaticSlidingWindowLayer(
                    max_cache_len=max_cache_len,
                    sliding_window=config.sliding_window)
            elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
                layer = LinearAttentionLayer()
            else:
                layer = StaticLayer(max_cache_len=max_cache_len)
            layers.append(layer)
        super().__init__(layers=layers, ...)
Key differences from DynamicCache:
- DynamicCache uses DynamicLayer/DynamicSlidingWindowLayer (grows via torch.cat)
- StaticCache uses StaticLayer/StaticSlidingWindowLayer (preallocation plus in-place index_copy_ updates)
- StaticCache requires both config and max_cache_len and does not support lazy layer creation
11. QuantizedCache: Quantized Cache
Source location: [cache_utils.py:1423-1476](file:///workspace/src/transformers/cache_utils.py#L1423-L1476)
python
class QuantizedCache(Cache):
    def __init__(self, backend, config, nbits=4, axis_key=0,
                 axis_value=0, q_group_size=64, residual_length=128):
        if backend == "quanto":
            layer_class = QuantoQuantizedLayer
        elif backend == "hqq":
            layer_class = HQQQuantizedLayer
        else:
            raise ValueError(f"Unknown quantization backend `{backend}`")
        config = config.get_text_config(decoder=True)
        layers = [
            layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
            for _ in range(config.num_hidden_layers)
        ]
        super().__init__(layers=layers)
Note: the quantized cache currently does not support heterogeneous layer types; every layer uses the same quantization strategy. This contrasts with the heterogeneous layer support of DynamicCache and StaticCache.
12. EncoderDecoderCache: Encoder-Decoder Cache
Source location: [cache_utils.py:1479-1623](file:///workspace/src/transformers/cache_utils.py#L1479-L1623)
Design intent
Encoder-decoder models (such as Whisper and T5) need two independent caches: a self-attention cache and a cross-attention cache. EncoderDecoderCache uses composition to wrap the two Cache objects together.
Construction
python
class EncoderDecoderCache(Cache):
    def __init__(self, *caches):
        if len(caches) == 1:
            # DDP compatibility: split the combined data into self-attention and cross-attention parts
            self_attention_cache_data, cross_attention_cache_data = [], []
            for combined_cache_data in caches[0]:
                if len(combined_cache_data) == 6:  # (self_k, self_v, self_sw, cross_k, cross_v, cross_sw)
                    self_attention_cache_data.append(combined_cache_data[:3])
                    cross_attention_cache_data.append(combined_cache_data[3:])
                elif len(combined_cache_data) == 4:  # legacy format without sliding_window
                    self_attention_cache_data.append(combined_cache_data[:2])
                    cross_attention_cache_data.append(combined_cache_data[2:])
            self.self_attention_cache = DynamicCache(self_attention_cache_data)
            self.cross_attention_cache = DynamicCache(cross_attention_cache_data)
        elif len(caches) == 2:
            # Standard usage: pass two Cache objects
            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
                raise TypeError("...")
            self.self_attention_cache = caches[0]
            self.cross_attention_cache = caches[1]
Tracking cross-attention updates
python
        self.is_updated = {}
        for layer_idx in range(len(self.cross_attention_cache)):
            self.is_updated[layer_idx] = bool(
                self.cross_attention_cache.get_seq_length(layer_idx) > 0
            )
The is_updated dictionary records, per layer, whether the cross-attention cache has already been filled. This matters in encoder-decoder models because the cross-attention cache is normally filled once, after the encoder forward pass, and does not change during the subsequent decoding steps.
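How a decoder layer might use such a flag can be sketched as follows. `ToyCrossAttention` is a hypothetical illustration, not code from any modeling file: doubling each value stands in for the key/value projection of the encoder output, and a counter records how often that projection actually runs.

```python
class ToyCrossAttention:
    """Hypothetical sketch of using a per-layer `is_updated` flag."""

    def __init__(self):
        self.is_updated = {}   # layer_idx -> bool
        self.cross_cache = {}  # layer_idx -> projected encoder states
        self.projections = 0   # how many times the encoder output was projected

    def cross_states(self, layer_idx, encoder_output):
        if self.is_updated.get(layer_idx, False):
            return self.cross_cache[layer_idx]    # reuse the cached states
        self.projections += 1
        states = [2 * x for x in encoder_output]  # stand-in for the K/V projection
        self.cross_cache[layer_idx] = states
        self.is_updated[layer_idx] = True
        return states


attn = ToyCrossAttention()
encoder_output = [1, 2, 3]
outputs = [attn.cross_states(0, encoder_output) for _ in range(4)]  # four decoding steps
```

Across four decoding steps the projection runs only once; the remaining steps read the cached states.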
Method delegation
python
def get_seq_length(self, layer_idx=0):
    return self.self_attention_cache.get_seq_length(layer_idx)

def reset(self):
    self.self_attention_cache.reset()
    self.cross_attention_cache.reset()
    for layer_idx in self.is_updated:
        self.is_updated[layer_idx] = False

def reorder_cache(self, beam_idx):
    self.self_attention_cache.reorder_cache(beam_idx)
    self.cross_attention_cache.reorder_cache(beam_idx)
Iterator
python
def __iter__(self):
    """Yields (self_k, self_v, self_sw, cross_k, cross_v, cross_sw) six-tuples."""
    for self_attention_layer, cross_attention_layer in \
            zip(self.self_attention_cache, self.cross_attention_cache):
        yield self_attention_layer + cross_attention_layer
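13. Design Principles and Architecture Summary
13.1 Two-level architecture: Layer + Cache (container)
┌────────────────────────────────────────────────────────┐
│                    Cache container                     │
│  ┌────────────┐ ┌────────────┐     ┌────────────┐      │
│  │  Layer 0   │ │  Layer 1   │ ... │  Layer N   │      │
│  │(full attn) │ │ (sliding)  │     │  (Mamba)   │      │
│  └────────────┘ └────────────┘     └────────────┘      │
│                                                        │
│  Unified interface: update / reset / reorder / crop    │
│  Container features: offloading / prefetch / iteration │
└────────────────────────────────────────────────────────┘
- Layer: responsible for storing and updating the KV state of a single layer
- Container: manages the list of layers and provides cross-layer operations and a unified interface
13.2 Lazy initialization pattern
All cache layers use lazy initialization: memory is allocated only when the first real tensors arrive. Benefits:
- Deferred device binding: the device does not need to be known before the first GPU tensor arrives
- Deferred shape inference: num_heads, head_dim, and batch_size are inferred from the real input
- torch.export support: early_initialization() can initialize the layers ahead of time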
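13.3 The core trade-off: dynamic vs. static
| Property | DynamicLayer | StaticLayer |
|---|---|---|
| Memory allocation | Grows on demand (torch.cat) | Preallocated, fixed size |
| torch.compile | ❌ Not supported | ✅ Supported |
| torch.export | ❌ Not supported | ✅ Supported (requires early_initialization) |
| cudagraphs | ❌ Not supported | ✅ Supported (mark_static_address) |
| Memory efficiency | Uses only what is needed | Preallocation may waste memory |
| Use case | Default generation | High-performance inference, compiled deployment |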
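13.4 The registry pattern and the open-closed principle
With LAYER_TYPE_CACHE_MAPPING, adding a new model layer type:
- does not require modifying the code of DynamicCache or StaticCache
- only requires creating a new CacheLayerMixin subclass and setting layer_type
- works as soon as the class is imported, thanks to auto-registration
13.5 GPU offloading pipeline
┌──────────────────┐
│ Default stream   │
│ Layer N compute  │
└────────┬─────────┘
         │ compute finished
┌────────▼─────────┐
│ Offload to CPU   │ ← blocks until the computation is done
└──────────────────┘
┌──────────────────┐
│ Prefetch stream  │
│ Layer N+1        │ ← asynchronous, non-blocking
│ CPU → GPU        │
└──────────────────┘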
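14. Relationship to Other Modules
14.1 Interaction with the model forward pass
A model's attention layer calls the cache inside forward:
python
# Typical call pattern (inside the model's Attention class)
key_states, value_states = past_key_values.update(
    key_states, value_states, self.layer_idx
)
For linear attention layers:
python
conv_states = past_key_values.update_conv_state(conv_states, self.layer_idx)
recurrent_states = past_key_values.update_recurrent_state(recurrent_states, self.layer_idx)
14.2 Interaction with the configuration system
- DynamicCache(config=model.config): infers the layer types from the config
- StaticCache(config=model.config, max_cache_len=...): built from the config plus a maximum length
- Key config fields: layer_types, sliding_window, attention_chunk_size, num_hidden_layers, num_kv_shared_layers
14.3 Interaction with the attention mask system
get_mask_sizes() supplies the (kv_length, kv_offset) information needed for mask generation:
python
# In the model's forward pass
kv_length, kv_offset = past_key_values.get_mask_sizes(query_length, layer_idx)
# Used to build the causal mask and the sliding-window mask
14.4 Interaction with the generation system
- reorder_cache(beam_idx): reorders the cache during beam search
- crop(max_length): crops the cache for assisted decoding and contrastive search
- reset(): resets the cache in multi-turn conversations
14.5 Interaction with torch.compile / torch.export
- StaticCache + mark_static_address → cudagraphs support
- early_initialization() → torch.export support
- is_compileable property → lets the generation system decide whether the cache can be compiled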
14.6 Import graph
cache_utils.py
├── imports configuration_utils.PreTrainedConfig (configuration system)
├── imports utils.is_torch_greater_or_equal (version check)
├── imports utils.is_torchdynamo_compiling (compilation state check)
├── imports utils.is_hqq_available / is_optimum_quanto_available (quantization backend checks)
│
├── exported by __init__.py: Cache, DynamicCache, StaticCache, EncoderDecoderCache, ...
├── imported by modeling_layers.py: Cache
├── imported by modeling_outputs.py: Cache, EncoderDecoderCache
├── imported by masking_utils.py: Cache
├── imported by _typing.py: Cache
└── used by each model's modeling_*.py: past_key_values.update(key, value, layer_idx)