05-Hugging Face Transformers 缓存系统深度分析

Hugging Face Transformers 缓存系统深度分析

相关文章:
Hugging Face Transformers 源码全景解读
01-Hugging Face Transformers 核心基础设施深度分析
02-Hugging Face Transformers 配置系统深度分析
03-Hugging Face Transformers 模型系统深度分析
04-Hugging Face Transformers 注意力与掩码系统深度分析

分析版本:Transformers v5.8.0.dev0

分析范围:cache_utils.py --- KV Cache 完整架构,含层抽象、注册表、动态/静态/量化/线性注意力缓存


目录

  1. 模块总览
  2. [CacheLayerMixin --- 缓存层抽象基类](#CacheLayerMixin — 缓存层抽象基类)
  3. [LAYER_TYPE_CACHE_MAPPING --- 层类型注册表](#LAYER_TYPE_CACHE_MAPPING — 层类型注册表)
  4. 动态缓存层
  5. 静态缓存层(可编译)
  6. 量化缓存层
  7. 线性注意力缓存层
  8. [Cache --- 缓存容器基类](#Cache — 缓存容器基类)
  9. [DynamicCache --- 动态缓存](#DynamicCache — 动态缓存)
  10. [StaticCache --- 静态缓存](#StaticCache — 静态缓存)
  11. [QuantizedCache --- 量化缓存](#QuantizedCache — 量化缓存)
  12. [EncoderDecoderCache --- 编码器-解码器缓存](#EncoderDecoderCache — 编码器-解码器缓存)
  13. 设计原理与架构总结
  14. 与其他模块的关系

缓存系统架构总览

注册表
缓存容器体系
缓存层体系
CacheLayerMixin ABC
DynamicLayer
DynamicSlidingWindowLayer
QuantizedLayer
QuantoQuantizedLayer
HQQQuantizedLayer
StaticLayer
StaticSlidingWindowLayer
LinearAttentionCacheLayerMixin ABC
LinearAttentionLayer
Cache 基类
DynamicCache
StaticCache
QuantizedCache
EncoderDecoderCache
LAYER_TYPE_CACHE_MAPPING


1. 模块总览

文件路径 : src/transformers/cache_utils.py(约 1623 行)

模块职责

缓存系统是 Transformers 推理引擎的核心基础设施,负责在自回归生成过程中存储和复用 Key/Value 状态,避免重复计算。该模块实现了:

  1. 分层缓存架构 --- 每个模型层拥有独立的缓存层对象,支持异构层类型(全注意力、滑动窗口、线性注意力等)
  2. 动态与静态双模式 --- 动态缓存按需增长,静态缓存预分配固定张量以支持 torch.compile
  3. 量化 KV 缓存 --- 支持 Quanto 和 HQQ 两种后端,将 KV 状态压缩至 2/4 bit
  4. 线性注意力缓存 --- 为 Mamba/SSM 等非 Transformer 架构提供 conv/recurrent 状态缓存
  5. 编码器-解码器缓存 --- 组合自注意力与交叉注意力缓存,支持 seq2seq 模型
  6. GPU 卸载与预取 --- 层级 CPU 卸载机制,配合异步流实现计算-传输重叠

类继承体系

复制代码
CacheLayerMixin (ABC)                    LinearAttentionCacheLayerMixin (ABC)
├── DynamicLayer                         ├── LinearAttentionLayer
│   ├── DynamicSlidingWindowLayer        └── (via multiple inheritance)
│   └── QuantizedLayer                       LinearAttentionAndFullAttentionLayer
│       ├── QuantoQuantizedLayer              = LinearAttentionLayer + DynamicLayer
│       └── HQQQuantizedLayer
└── StaticLayer
    └── StaticSlidingWindowLayer

Cache (容器基类)
├── DynamicCache
├── StaticCache
├── QuantizedCache
└── EncoderDecoderCache (组合模式,非继承)

SlidingWindowCache = StaticCache  (v5 弃用别名)

2. CacheLayerMixin --- 缓存层抽象基类

源码位置: [cache_utils.py:37-107](file:///workspace/src/transformers/cache_utils.py#L37-L107)

设计意图

CacheLayerMixin 是所有注意力缓存层的抽象基类,定义了单层缓存的最小接口契约 。它采用 Mixin 模式而非纯继承,使得缓存层可以与其他基类(如 DynamicLayerLinearAttentionLayer)灵活组合。

核心属性

python 复制代码
class CacheLayerMixin(ABC):
    is_compileable = False          # 是否支持 torch.compile(静态缓存为 True)
    layer_type: str | None = None   # 自动注册到 LAYER_TYPE_CACHE_MAPPING 的键名

    def __init__(self):
        self.keys: torch.Tensor | None = None       # Key 状态张量
        self.values: torch.Tensor | None = None     # Value 状态张量
        self.is_initialized = False                  # 懒初始化标记

自动注册机制 --- __init_subclass__

python 复制代码
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    layer_type = cls.__dict__.get("layer_type", None)  # 只取类自身定义的,不取继承的
    if layer_type is not None:
        LAYER_TYPE_CACHE_MAPPING[layer_type] = cls     # 导入时自动注册

关键设计 :使用 cls.__dict__.get() 而非 getattr(),确保只有子类显式定义layer_type 才会注册,避免父类的 layer_type 被子类意外覆盖。

抽象方法

方法 返回类型 说明
lazy_initialization(key, value) None 首次收到真实张量时初始化缓存
update(key, value, *args, **kwargs) tuple[Tensor, Tensor] 更新缓存并返回完整 KV 状态
get_mask_sizes(query_length) tuple[int, int] 返回 (kv_length, kv_offset) 用于生成注意力掩码
get_seq_length() int 返回已缓存的序列长度
get_max_cache_shape() int 返回最大缓存容量(动态缓存返回 -1)

通用方法

python 复制代码
def offload(self):
    """将缓存数据卸载到 CPU,节省 GPU 显存"""
    if self.is_initialized:
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)

def prefetch(self):
    """从 CPU 预取回 GPU,配合异步流实现计算-传输重叠"""
    if self.is_initialized and self.keys.device != self.device:
        self.keys = self.keys.to(self.device, non_blocking=True)
        self.values = self.values.to(self.device, non_blocking=True)

def reset(self):
    """重置缓存值但保留对象(避免重新分配内存)"""
    if self.is_initialized:
        self.keys.zero_()
        self.values.zero_()
    if hasattr(self, "cumulative_length"):
        if isinstance(self.cumulative_length, int):
            self.cumulative_length = 0
        else:
            self.cumulative_length.zero_()  # 张量形式用于 torch.compile

def reorder_cache(self, beam_idx):
    """Beam Search 时按 beam 索引重排缓存"""
    if self.get_seq_length() > 0:
        self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
        self.values = self.values.index_select(0, beam_idx.to(self.values.device))

3. LAYER_TYPE_CACHE_MAPPING --- 层类型注册表

源码位置: [cache_utils.py:26-34](file:///workspace/src/transformers/cache_utils.py#L26-L34) 及 [871-887](file:///workspace/src/transformers/cache_utils.py#L871-L887)

设计意图

注册表模式将 config.layer_types[i] 字符串映射到对应的缓存层类,使 DynamicCacheStaticCache 能根据模型配置自动分发 正确的缓存层类型,无需为每种模型创建专属 Cache 子类。

注册表内容

python 复制代码
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}

# 自动注册:子类定义 layer_type = "xxx" 时通过 __init_subclass__ 自动加入
# 手动注册:标准层类型在此处批量注册
LAYER_TYPE_CACHE_MAPPING.update({
    "full_attention":      DynamicLayer,                    # 全注意力(默认)
    "sliding_attention":   DynamicSlidingWindowLayer,       # 滑动窗口注意力
    "chunked_attention":   DynamicSlidingWindowLayer,       # 分块注意力(缓存行为同滑动窗口)
    "mamba":               LinearAttentionLayer,            # Mamba SSM 层
    "conv":                LinearAttentionLayer,            # 卷积层
    "linear_attention":    LinearAttentionLayer,            # 纯线性注意力
    "moe":                 LinearAttentionLayer,            # MoE 占位(无 KV 缓存需求)
    "hybrid":              LinearAttentionAndFullAttentionLayer,  # 混合层(如 Zamba)
})

双重注册机制

  1. 自动注册 :子类定义 layer_type = "xxx" 时,CacheLayerMixin.__init_subclass__ 自动将其加入注册表
  2. 手动注册 :对于多个 layer_type 共享同一缓存类的情况(如 "sliding_attention""chunked_attention" 都映射到 DynamicSlidingWindowLayer),通过 update() 手动添加

使用流程

复制代码
模型配置 config.layer_types = ["full_attention", "sliding_attention", "mamba", ...]
                                    ↓
DynamicCache.__init__ 遍历 layer_types
                                    ↓
LAYER_TYPE_CACHE_MAPPING["full_attention"] → DynamicLayer(config)
LAYER_TYPE_CACHE_MAPPING["sliding_attention"] → DynamicSlidingWindowLayer(config)
LAYER_TYPE_CACHE_MAPPING["mamba"] → LinearAttentionLayer(config)

4. 动态缓存层

4.1 DynamicLayer --- 全注意力动态层

源码位置: [cache_utils.py:109-188](file:///workspace/src/transformers/cache_utils.py#L109-L188)

动态层的核心特征:缓存随生成步骤增长 ,每次 update 通过 torch.cat 追加新 token 的 KV 状态。

python 复制代码
class DynamicLayer(CacheLayerMixin):
    is_sliding = False

    def lazy_initialization(self, key_states, value_states):
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)   # 空张量占位
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(self, key_states, value_states, *args, **kwargs):
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
        # 核心:拼接追加新 KV 状态,形状 [batch, heads, seq_len, head_dim]
        self.keys = torch.cat([self.keys, key_states], dim=-2)
        self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

关键设计点

  • dim=-2 是序列长度维度,torch.cat 在该维度上追加
  • 返回的是完整 KV 状态(包含历史 + 新增),供注意力计算使用
  • get_max_cache_shape() 返回 -1,表示无上限
  • 提供 crop()batch_repeat_interleave()batch_select_indices() 等辅助方法

4.2 DynamicSlidingWindowLayer --- 滑动窗口动态层

源码位置: [cache_utils.py:190-275](file:///workspace/src/transformers/cache_utils.py#L190-L275)

滑动窗口层在动态增长的基础上,只保留最近 sliding_window - 1 个 token 的缓存,大幅减少显存占用。

python 复制代码
class DynamicSlidingWindowLayer(DynamicLayer):
    is_sliding = True

    def __init__(self, config=None, sliding_window=None):
        super().__init__()
        if sliding_window is None:
            if config is None:
                raise ValueError("Either `config` or `sliding_window` must be provided.")
            sliding_window = getattr(config, "sliding_window", None) or \
                             getattr(config, "attention_chunk_size", None)
        self.sliding_window = sliding_window
        self.cumulative_length = 0  # 跟踪已处理的总 token 数(含已丢弃的)

    def update(self, key_states, value_states, *args, **kwargs):
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        self.cumulative_length += key_states.shape[-2]

        # 1. 先拼接得到完整状态
        full_key_states = torch.cat([self.keys, key_states], dim=-2)
        full_value_states = torch.cat([self.values, value_states], dim=-2)
        # 2. 只缓存最后 sliding_window - 1 个 token
        self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :]
        self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]
        # 3. 返回完整状态(注意力计算需要完整上下文)
        return full_key_states, full_value_states

核心设计

  • cumulative_length 跟踪已处理 token 数,即使部分已被丢弃
  • 缓存只保留 sliding_window - 1 个 token(减 1 是因为当前 query 本身也算一个窗口位置)
  • update 返回完整 KV 状态而非截断后的,因为注意力计算需要看到窗口内的所有 token
  • get_mask_sizes() 根据 cumulative_length 是否已超过 sliding_window 计算掩码偏移
python 复制代码
def get_mask_sizes(self, query_length):
    is_full = self.cumulative_length >= self.sliding_window
    kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
    if is_full:
        kv_length = self.sliding_window - 1 + query_length
    else:
        kv_length = self.cumulative_length + query_length
    return kv_length, kv_offset

5. 静态缓存层(可编译)

5.1 StaticLayer --- 全注意力静态层

源码位置: [cache_utils.py:277-385](file:///workspace/src/transformers/cache_utils.py#L277-L385)

静态层为 torch.compile 设计,预分配固定大小的张量,通过原地更新(in-place mutation)避免动态内存分配,确保编译图结构稳定。

python 复制代码
class StaticLayer(CacheLayerMixin):
    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len
        # 关键:cumulative_length 必须是张量,避免 recompile
        self.cumulative_length = torch.tensor([0], dtype=int)

    def lazy_initialization(self, key_states, value_states):
        self.dtype, self.device = key_states.dtype, key_states.device
        self.max_batch_size, self.num_heads = key_states.shape[:2]
        self.v_head_dim = value_states.shape[-1]
        self.k_head_dim = key_states.shape[-1]

        # 预分配完整大小的零张量
        self.keys = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.k_head_dim),
            dtype=self.dtype, device=self.device,
        )
        self.values = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.v_head_dim),
            dtype=self.dtype, device=self.device,
        )
        self.cumulative_length = self.cumulative_length.to(self.device)

        # 标记静态地址,防止 cudagraph 因原地修改跳过或图断裂
        if not is_torchdynamo_compiling():
            torch._dynamo.mark_static_address(self.keys)
            torch._dynamo.mark_static_address(self.values)
            torch._dynamo.mark_static_address(self.cumulative_length)

        self.is_initialized = True

mark_static_address 的关键作用

torch.compile 在编译时追踪张量的内存地址。如果张量被原地修改(如 index_copy_),Dynamo 默认会认为地址变化了,导致图断裂或 cudagraph 失效。mark_static_address 告诉 Dynamo:"这个张量的地址不会变",从而允许原地修改操作被正确编译。

python 复制代码
def update(self, key_states, value_states, *args, **kwargs):
    if not self.is_initialized:
        self.lazy_initialization(key_states, value_states)

    kv_length = key_states.shape[-2]
    cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
    # 原地更新累计长度(保持静态地址)
    self.cumulative_length.add_(kv_length)

    # 使用 index_copy_ 原地写入新 KV 状态
    try:
        self.keys.index_copy_(2, cache_position, key_states)
        self.values.index_copy_(2, cache_position, value_states)
    except NotImplementedError:
        # MPS 等设备可能不支持 index_copy_,回退到切片赋值
        self.keys[:, :, cache_position] = key_states
        self.values[:, :, cache_position] = value_states

    return self.keys, self.values

5.2 StaticSlidingWindowLayer --- 滑动窗口静态层

源码位置: [cache_utils.py:387-512](file:///workspace/src/transformers/cache_utils.py#L387-L512)

这是最复杂的缓存层,需要在静态张量上实现滑动窗口的"滚动"效果,同时保持编译友好。

python 复制代码
class StaticSlidingWindowLayer(StaticLayer):
    is_sliding = True

    def __init__(self, max_cache_len, sliding_window):
        effective_max_cache_len = min(sliding_window, max_cache_len)
        super().__init__(max_cache_len=effective_max_cache_len)
        # 用 Python int 跟踪长度,避免数据依赖控制流导致 recompile
        self.cumulative_length_int = 0

update 方法的四路分支

python 复制代码
def update(self, key_states, value_states, *args, **kwargs):
    if not self.is_initialized:
        self.lazy_initialization(key_states, value_states)

    kv_length = key_states.shape[-2]
    current_length = self.cumulative_length_int
    is_full = current_length >= self.max_cache_len
    self.cumulative_length_int += kv_length

    if is_full:
        if key_states.shape[-2] == 1:
            # 分支1: 已满 + 单 token 解码 --- 使用 roll + 赋值
            new_keys = self.keys.roll(-1, dims=-2)
            new_values = self.values.roll(-1, dims=-2)
            index = torch.tensor([-1], dtype=int, device=self.device)
            new_keys[:, :, index] = key_states
            new_values[:, :, index] = value_states
            self.keys.copy_(new_keys)       # copy_ 保持静态地址
            self.values.copy_(new_values)
            return self.keys, self.values
        else:
            # 分支2: 已满 + 多 token(如 prefill 缓存、chat 续接)
            full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
    elif current_length + kv_length > self.max_cache_len:
        # 分支3: 即将满 --- 需要截断
        if current_length == 0:
            full_key_states = key_states          # 快速路径:缓存为空
            full_value_states = value_states
        else:
            full_key_states = torch.cat((self.keys[:, :, :current_length, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, :current_length, :], value_states), dim=-2)
    else:
        # 分支4: 未满 --- 标准 index_copy_ 原地写入
        cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
        self.keys.index_copy_(2, cache_position, key_states)
        self.values.index_copy_(2, cache_position, value_states)
        self.cumulative_length.add_(kv_length)    # 同步更新张量版本
        return self.keys, self.values

    # 分支2/3 的公共后处理:截断到 max_cache_len 并 copy_ 回静态张量
    self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
    self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
    return full_key_states, full_value_states

设计要点

  • roll 技巧 :单 token 解码时,用 roll(-1) 将所有元素左移一位,再覆盖最后一个位置,避免 cat 产生新张量
  • copy_ 而非赋值self.keys.copy_(new_keys) 保持静态地址不变,赋值 self.keys = new_keys 会改变地址导致 cudagraph 失效
  • 双长度追踪cumulative_length_int(Python int)用于控制流判断,cumulative_length(Tensor)用于 index_copy_ 的位置计算,两者需要同步

6. 量化缓存层

6.1 QuantizedLayer --- 量化基类

源码位置: [cache_utils.py:514-588](file:///workspace/src/transformers/cache_utils.py#L514-L588)

基于 KIVI 论文的思路,量化缓存维护双存储 :原始精度缓存 + 量化缓存。当原始精度缓存超过 residual_length 时,将其量化并合并到量化缓存中。

python 复制代码
class QuantizedLayer(DynamicLayer):
    def __init__(self, nbits=4, axis_key=0, axis_value=0,
                 q_group_size=64, residual_length=128):
        super().__init__()
        self.nbits = nbits
        self.axis_key = axis_key
        self.axis_value = axis_value
        self.q_group_size = q_group_size
        self.residual_length = residual_length
        self.cumulative_length = 0

    def update(self, key_states, value_states, *args, **kwargs):
        self.cumulative_length += key_states.shape[-2]

        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
            # 首次直接量化
            self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
            return key_states, value_states

        # 反量化已存储的量化缓存
        dequant_keys = self._dequantize(self._quantized_keys)
        dequant_values = self._dequantize(self._quantized_values)
        # 拼接:量化部分 + 原始精度部分 + 新增部分
        keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
        values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)

        # 当原始精度部分超过 residual_length 时,全部量化
        if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
            self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
            self.values = torch.tensor([], dtype=key_states.dtype, key_states.device)
        else:
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)

        return keys_to_return, values_to_return

    @abstractmethod
    def _quantize(self, tensor, axis): ...

    @abstractmethod
    def _dequantize(self, q_tensor): ...

量化-反量化流程

复制代码
新增 KV → 拼接到原始精度缓存 → 超过 residual_length?
                                    ├── 是:量化全部 → 清空原始精度缓存
                                    └── 否:保留在原始精度缓存
读取时:反量化部分 + 原始精度部分 + 新增部分 → 返回完整 KV

6.2 QuantoQuantizedLayer --- Quanto 后端

源码位置: [cache_utils.py:590-643](file:///workspace/src/transformers/cache_utils.py#L590-L643)

python 复制代码
class QuantoQuantizedLayer(QuantizedLayer):
    def __init__(self, nbits=4, axis_key=0, axis_value=0, q_group_size=64, residual_length=128):
        super().__init__(nbits, axis_key, axis_value, q_group_size, residual_length)
        # 需要 optimum-quanto >= 0.2.5
        from optimum.quanto import MaxOptimizer, qint2, qint4
        self.qtype = qint4 if self.nbits == 4 else qint2  # 仅支持 2/4 bit
        self.optimizer = MaxOptimizer()  # 逐通道量化的唯一优化器

    def _quantize(self, tensor, axis):
        from optimum.quanto import quantize_weight
        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
        return qtensor

    def _dequantize(self, qtensor):
        return qtensor.dequantize()

6.3 HQQQuantizedLayer --- HQQ 后端

源码位置: [cache_utils.py:645-699](file:///workspace/src/transformers/cache_utils.py#L645-L699)

python 复制代码
class HQQQuantizedLayer(QuantizedLayer):
    def __init__(self, nbits=4, axis_key=0, axis_value=0, q_group_size=64, residual_length=128):
        super().__init__(nbits, axis_key, axis_value, q_group_size, residual_length)
        # 支持 1/2/3/4/8 bit
        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        qtensor, meta = self.quantizer.quantize(
            tensor, axis=axis, device=self.keys.device,
            compute_dtype=self.keys.dtype, nbits=self.nbits, group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.keys.dtype
        self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device)
        return qtensor, meta  # 返回元组(量化张量 + 元数据)

    def _dequantize(self, qtensor):
        quant_tensor, meta = qtensor  # 解包元组
        return self.quantizer.dequantize(quant_tensor, meta)

7. 线性注意力缓存层

7.1 LinearAttentionCacheLayerMixin --- 线性注意力层抽象基类

源码位置: [cache_utils.py:702-761](file:///workspace/src/transformers/cache_utils.py#L702-L761)

线性注意力层(Mamba/SSM 等)不使用传统的 KV 缓存,而是维护 conv_states(卷积状态)和 recurrent_states(循环状态)。

python 复制代码
class LinearAttentionCacheLayerMixin(ABC):
    is_compileable = True  # 线性注意力的状态形状天然是静态的

    def __init__(self):
        self.conv_states: torch.Tensor | None = None
        self.recurrent_states: torch.Tensor | None = None
        self.is_conv_states_initialized = False
        self.is_recurrent_states_initialized = False
        self.has_previous_state = False  # 标记是否有前序状态(影响 conv 更新逻辑)

    @abstractmethod
    def lazy_initialization(self, conv_states=None, recurrent_states=None): ...

    @abstractmethod
    def update_conv_state(self, conv_states): ...

    @abstractmethod
    def update_recurrent_state(self, recurrent_states): ...

7.2 LinearAttentionLayer --- 线性注意力层实现

源码位置: [cache_utils.py:764-838](file:///workspace/src/transformers/cache_utils.py#L764-L838)

python 复制代码
class LinearAttentionLayer(LinearAttentionCacheLayerMixin):
    def update_conv_state(self, conv_states, **kwargs):
        if not self.is_conv_states_initialized:
            self.lazy_initialization(conv_states=conv_states)

        if not self.has_previous_state:
            # 首次:直接复制(copy_ 保持静态地址)
            self.conv_states.copy_(conv_states)
            self.has_previous_state = True
        else:
            num_new_tokens = conv_states.shape[-1]
            if num_new_tokens >= self.conv_kernel_size:
                # 新 token 数 >= 卷积核大小:直接替换
                self.conv_states.copy_(conv_states[..., -self.conv_kernel_size:])
            else:
                # 滚动:左移腾出空间,末尾填入新状态
                new_conv_states = self.conv_states.roll(shifts=-num_new_tokens, dims=-1)
                new_conv_states[:, :, -num_new_tokens:] = conv_states
                self.conv_states.copy_(new_conv_states)

        return self.conv_states

    def update_recurrent_state(self, recurrent_states, **kwargs):
        if not self.is_recurrent_states_initialized:
            self.lazy_initialization(recurrent_states=recurrent_states)
        # 循环状态直接覆盖
        self.recurrent_states.copy_(recurrent_states)
        return self.recurrent_states

7.3 LinearAttentionAndFullAttentionLayer --- 混合层

源码位置: [cache_utils.py:840-865](file:///workspace/src/transformers/cache_utils.py#L840-L865)

通过多重继承同时拥有线性注意力缓存和全注意力缓存的能力,用于 Zamba 等混合架构模型。

python 复制代码
class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer):
    is_compileable = False  # DynamicLayer 部分使其不可编译

    def lazy_initialization(self, *args, **kwargs):
        # 根据参数形式分发到对应的父类初始化
        if len(args) == 2 and len(kwargs) == 0:
            DynamicLayer.lazy_initialization(self, *args)       # KV 缓存初始化
        if len(args) == 0 and len(kwargs) == 1:
            LinearAttentionLayer.lazy_initialization(self, **kwargs)  # conv/recurrent 初始化

    def reset(self):
        LinearAttentionLayer.reset(self)
        DynamicLayer.reset(self)

    def reorder_cache(self, beam_idx):
        LinearAttentionLayer.reorder_cache(self, beam_idx)
        DynamicLayer.reorder_cache(self, beam_idx)

8. Cache --- 缓存容器基类

源码位置: [cache_utils.py:890-1227](file:///workspace/src/transformers/cache_utils.py#L890-L1227)

设计意图

Cache 是所有缓存容器的基类,本质上是一个 CacheLayerMixin 对象的列表,提供统一的层间操作接口。它采用组合模式:容器持有层对象,操作委托给各层。

构造方式

python 复制代码
class Cache:
    def __init__(self, layers=None, layer_class_to_replicate=None,
                 offloading=False, offload_only_non_sliding=True):
        # 两种构造方式二选一:
        # 1. 传入预创建的 layers 列表
        # 2. 传入 layer_class_to_replicate,按需懒创建
        if layers is not None and layer_class_to_replicate is not None:
            raise ValueError("...")
        if layers is None and layer_class_to_replicate is None:
            raise ValueError("...")
        self.layers = layers if layers is not None else []
        self.layer_class_to_replicate = layer_class_to_replicate
        self.offloading = offloading
        if self.offloading:
            self.only_non_sliding = offload_only_non_sliding
            self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 \
                                   else torch.cuda.Stream()

核心 update 方法

python 复制代码
def update(self, key_states, value_states, layer_idx, *args, **kwargs):
    # 懒创建:如果 layer_idx 超出当前 layers 列表长度,自动追加
    if self.layer_class_to_replicate is not None:
        while len(self.layers) <= layer_idx:
            self.layers.append(self.layer_class_to_replicate())

    # GPU 卸载:等待预取流完成,同时预取下一层
    if self.offloading:
        torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
        self.prefetch(layer_idx + 1, self.only_non_sliding)

    # 委托给具体层
    keys, values = self.layers[layer_idx].update(key_states, value_states, *args, **kwargs)

    # GPU 卸载:当前层计算完毕,卸载到 CPU
    if self.offloading:
        self.offload(layer_idx, self.only_non_sliding)

    return keys, values

GPU 卸载的流水线设计

复制代码
时间线:
  Layer 0 计算 ──→ 卸载 Layer 0 到 CPU ──→
                    预取 Layer 1 到 GPU ──→ Layer 1 计算 ──→ 卸载 Layer 1 ──→ ...
  • prefetch 使用非默认流,不阻塞主计算流
  • offload 在默认流上执行,确保层的 update 计算已完成
  • 滑动窗口层通常较小,默认不卸载(offload_only_non_sliding=True

线性注意力专用方法

python 复制代码
def update_conv_state(self, conv_states, layer_idx, **kwargs):
    if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
        raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
    return self.layers[layer_idx].update_conv_state(conv_states, **kwargs)

def update_recurrent_state(self, recurrent_states, layer_idx, **kwargs):
    if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
        raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
    return self.layers[layer_idx].update_recurrent_state(recurrent_states, **kwargs)

提前初始化

python 复制代码
def early_initialization(self, batch_size, num_heads, head_dim, dtype, device):
    """用于 torch.export,在首次 forward 之前初始化所有层"""
    if isinstance(num_heads, int):
        num_heads = [num_heads] * len(self)
    if isinstance(head_dim, int):
        head_dim = [head_dim] * len(self)

    for layer, layer_num_heads, layer_head_dim in zip(self.layers, num_heads, head_dim):
        # 使用 size 0 的伪张量触发初始化(不分配实际数据)
        fake_kv_tensor = torch.zeros(
            (batch_size, layer_num_heads, 0, layer_head_dim), dtype=dtype, device=device
        )
        layer.lazy_initialization(fake_kv_tensor, fake_kv_tensor)

关键属性

python 复制代码
@property
def is_compileable(self) -> bool:
    if len(self.layers) == 0:
        return False
    return all(layer.is_compileable for layer in self.layers)

@property
def is_sliding(self) -> list[bool]:
    return [getattr(layer, "is_sliding", False) for layer in self.layers]

@property
def max_batch_size(self) -> int:
    values = [layer.max_batch_size for layer in self.layers]
    if len(set(values)) > 1:
        raise ValueError(f"Max batch size is not consistent across layers: {values}")
    return values[0]

@property
def max_cache_len(self) -> int:
    values = [layer.max_cache_len for layer in self.layers]
    return max(values)

9. DynamicCache --- 动态缓存

源码位置: [cache_utils.py:1229-1334](file:///workspace/src/transformers/cache_utils.py#L1229-L1334)

设计意图

DynamicCache最常用的缓存类,适用于绝大多数生成式模型。它根据模型配置自动选择每层的缓存类型。

构造逻辑

python 复制代码
class DynamicCache(Cache):
    def __init__(self, ddp_cache_data=None, config=None,
                 offloading=False, offload_only_non_sliding=False):
        layers = []

        # 路径1: 根据 config 自动构建异构层
        if config is not None:
            decoder_config = config.get_text_config(decoder=True)
            sliding_window = getattr(decoder_config, "sliding_window", None) or \
                             getattr(decoder_config, "attention_chunk_size", None)
            layer_types = getattr(decoder_config, "layer_types", None)
            if layer_types is None:
                # 无显式 layer_types,根据 sliding_window 推断
                layer_types = []
                for _ in range(decoder_config.num_hidden_layers):
                    if sliding_window is not None:
                        layer_types.append("sliding_attention")
                    else:
                        layer_types.append("full_attention")
            # 跳过共享 KV 层(如 Gemma3n)
            if hasattr(decoder_config, "num_kv_shared_layers"):
                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]

            for layer_type in layer_types:
                cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer)
                layers.append(cache_cls(decoder_config))

        # 路径2: 从 DDP 数据填充
        if ddp_cache_data is not None:
            for layer_idx, kv_and_optional_sliding in enumerate(ddp_cache_data):
                if config is None:
                    sliding_window_tensor = kv_and_optional_sliding[2] \
                        if len(kv_and_optional_sliding) == 3 else None
                    if sliding_window_tensor is not None:
                        layers.append(DynamicSlidingWindowLayer(
                            sliding_window=sliding_window_tensor[0].item()))
                    else:
                        layers.append(DynamicLayer())
                _, _ = layers[layer_idx].update(
                    kv_and_optional_sliding[0], kv_and_optional_sliding[1])

        # 路径3: 无 config 无数据 --- 懒创建 DynamicLayer
        if len(layers) == 0:
            super().__init__(layer_class_to_replicate=DynamicLayer, ...)
        else:
            super().__init__(layers=layers, ...)

迭代器

python 复制代码
def __iter__(self):
    for layer in self.layers:
        yield layer.keys, layer.values, getattr(layer, "_sliding_window_tensor", None)

每个元素是 (keys, values, sliding_window_tensor) 三元组,sliding_window_tensor 可能为 None


10. StaticCache --- 静态缓存

源码位置: [cache_utils.py:1336-1421](file:///workspace/src/transformers/cache_utils.py#L1336-L1421)

设计意图

StaticCache 专为 torch.compiletorch.export 设计,所有层使用静态预分配张量。

构造逻辑

python 复制代码
class StaticCache(Cache):
    def __init__(self, config, max_cache_len, offloading=False,
                 offload_only_non_sliding=True, **kwargs):
        config = config.get_text_config(decoder=True)
        layer_types = getattr(config, "layer_types", None)
        if layer_types is None:
            if getattr(config, "sliding_window", None) is not None:
                layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
            elif getattr(config, "attention_chunk_size", None) is not None:
                layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
            else:
                layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
        if hasattr(config, "num_kv_shared_layers"):
            layer_types = layer_types[: -config.num_kv_shared_layers]

        # 识别滑动窗口类层类型
        sliding_layer_types = {
            name for name, cls in LAYER_TYPE_CACHE_MAPPING.items()
            if isinstance(cls, type) and issubclass(cls, DynamicSlidingWindowLayer)
            and name != "chunked_attention"
        }

        layers = []
        for layer_type in layer_types:
            if layer_type == "chunked_attention":
                layer = StaticSlidingWindowLayer(
                    max_cache_len=max_cache_len,
                    sliding_window=config.attention_chunk_size)
            elif layer_type in sliding_layer_types:
                layer = StaticSlidingWindowLayer(
                    max_cache_len=max_cache_len,
                    sliding_window=config.sliding_window)
            elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
                layer = LinearAttentionLayer()
            else:
                layer = StaticLayer(max_cache_len=max_cache_len)
            layers.append(layer)

        super().__init__(layers=layers, ...)

与 DynamicCache 的关键区别

  • DynamicCache 使用 DynamicLayer/DynamicSlidingWindowLayertorch.cat 增长)
  • StaticCache 使用 StaticLayer/StaticSlidingWindowLayer(预分配 + index_copy_ 原地更新)
  • StaticCache 必须 传入 configmax_cache_len,不支持懒创建

11. QuantizedCache --- 量化缓存

源码位置: [cache_utils.py:1423-1476](file:///workspace/src/transformers/cache_utils.py#L1423-L1476)

python 复制代码
class QuantizedCache(Cache):
    def __init__(self, backend, config, nbits=4, axis_key=0,
                 axis_value=0, q_group_size=64, residual_length=128):
        if backend == "quanto":
            layer_class = QuantoQuantizedLayer
        elif backend == "hqq":
            layer_class = HQQQuantizedLayer
        else:
            raise ValueError(f"Unknown quantization backend `{backend}`")

        config = config.get_text_config(decoder=True)
        layers = [
            layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
            for _ in range(config.num_hidden_layers)
        ]
        super().__init__(layers=layers)

注意 :量化缓存目前不支持异构层类型,所有层使用相同的量化策略。这与 DynamicCache/StaticCache 的异构层支持形成对比。


12. EncoderDecoderCache --- 编码器-解码器缓存

源码位置: [cache_utils.py:1479-1623](file:///workspace/src/transformers/cache_utils.py#L1479-L1623)

设计意图

编码器-解码器模型(如 Whisper、T5)需要两套独立的缓存 :自注意力缓存和交叉注意力缓存。EncoderDecoderCache 使用组合模式将两个 Cache 对象封装在一起。

构造方式

python 复制代码
class EncoderDecoderCache(Cache):
    def __init__(self, *caches):
        if len(caches) == 1:
            # DDP 兼容:从合并数据拆分出自注意力和交叉注意力
            self_attention_cache_data, cross_attention_cache_data = [], []
            for combined_cache_data in caches[0]:
                if len(combined_cache_data) == 6:  # (self_k, self_v, self_sw, cross_k, cross_v, cross_sw)
                    self_attention_cache_data.append(combined_cache_data[:3])
                    cross_attention_cache_data.append(combined_cache_data[3:])
                elif len(combined_cache_data) == 4:  # 旧格式无 sliding_window
                    self_attention_cache_data.append(combined_cache_data[:2])
                    cross_attention_cache_data.append(combined_cache_data[2:])
            self.self_attention_cache = DynamicCache(self_attention_cache_data)
            self.cross_attention_cache = DynamicCache(cross_attention_cache_data)
        elif len(caches) == 2:
            # 标准用法:传入两个 Cache 对象
            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
                raise TypeError("...")
            self.self_attention_cache = caches[0]
            self.cross_attention_cache = caches[1]

交叉注意力更新追踪

python 复制代码
self.is_updated = {}
for layer_idx in range(len(self.cross_attention_cache)):
    self.is_updated[layer_idx] = bool(
        self.cross_attention_cache.get_seq_length(layer_idx) > 0
    )

is_updated 字典追踪每层的交叉注意力缓存是否已被更新。这在编码器-解码器模型中很重要,因为交叉注意力缓存通常在编码器前向传播后一次性填充,后续解码步骤中不再变化。

方法委托

python 复制代码
def get_seq_length(self, layer_idx=0):
    return self.self_attention_cache.get_seq_length(layer_idx)

def reset(self):
    self.self_attention_cache.reset()
    self.cross_attention_cache.reset()
    for layer_idx in self.is_updated:
        self.is_updated[layer_idx] = False

def reorder_cache(self, beam_idx):
    self.self_attention_cache.reorder_cache(beam_idx)
    self.cross_attention_cache.reorder_cache(beam_idx)

迭代器

python 复制代码
def __iter__(self):
    """返回 (self_k, self_v, self_sw, cross_k, cross_v, cross_sw) 六元组"""
    for self_attention_layer, cross_attention_layer in \
            zip(self.self_attention_cache, self.cross_attention_cache):
        yield self_attention_layer + cross_attention_layer

13. 设计原理与架构总结

13.1 两级架构:层(Layer)+ 容器(Cache)

复制代码
┌─────────────────────────────────────────────────┐
│                    Cache 容器                     │
│  ┌─────────┐ ┌─────────┐       ┌─────────┐     │
│  │ Layer 0 │ │ Layer 1 │  ...  │ Layer N │     │
│  │ (全注意力)│ │(滑动窗口)│       │ (Mamba) │     │
│  └─────────┘ └─────────┘       └─────────┘     │
│                                                  │
│  统一接口: update / reset / reorder / crop       │
│  容器级功能: offloading / prefetch / iteration   │
└─────────────────────────────────────────────────┘
  • :负责单层的 KV 状态存储与更新逻辑
  • 容器:管理层列表,提供跨层操作和统一接口

13.2 懒初始化模式

所有缓存层都采用懒初始化:在首次收到真实张量时才分配内存。好处:

  1. 延迟设备绑定:在 GPU 张量到来之前不需要知道设备类型
  2. 延迟形状推断 :从真实输入推断 num_headshead_dimbatch_size
  3. 支持 torch.export :通过 early_initialization() 可提前初始化

13.3 动态 vs 静态的核心权衡

特性 DynamicLayer StaticLayer
内存分配 按需增长(torch.cat 预分配固定大小
torch.compile ❌ 不支持 ✅ 支持
torch.export ❌ 不支持 ✅ 支持(需 early_initialization
cudagraphs ❌ 不支持 ✅ 支持(mark_static_address
内存效率 只用所需 预分配可能浪费
适用场景 默认生成 高性能推理、编译部署

13.4 注册表模式实现开放-封闭原则

LAYER_TYPE_CACHE_MAPPING 使得新增模型层类型时:

  • 无需修改 DynamicCacheStaticCache 的代码
  • 只需 创建新的 CacheLayerMixin 子类并设置 layer_type
  • 自动注册机制确保导入即可用

13.5 GPU 卸载流水线

复制代码
                    ┌──────────────┐
                    │  默认流       │
                    │  Layer N 计算 │
                    └──────┬───────┘
                           │ 计算完成
                    ┌──────▼───────┐
                    │  卸载到 CPU   │ ← 阻塞确保计算完成
                    └──────────────┘

                    ┌──────────────┐
                    │  预取流       │
                    │  Layer N+1   │ ← 异步,不阻塞
                    │  CPU → GPU   │
                    └──────────────┘

14. 与其他模块的关系

14.1 与模型前向传播的交互

模型注意力层在 forward 中调用缓存:

python 复制代码
# 典型调用模式(在模型的 Attention 类中)
key_states, value_states = past_key_values.update(
    key_states, value_states, self.layer_idx
)

对于线性注意力层:

python 复制代码
conv_states = past_key_values.update_conv_state(conv_states, self.layer_idx)
recurrent_states = past_key_values.update_recurrent_state(recurrent_states, self.layer_idx)

14.2 与配置系统的交互

  • DynamicCache(config=model.config) --- 从配置推断层类型
  • StaticCache(config=model.config, max_cache_len=...) --- 从配置 + 最大长度构建
  • 关键配置字段:layer_typessliding_windowattention_chunk_sizenum_hidden_layersnum_kv_shared_layers

14.3 与注意力掩码系统的交互

get_mask_sizes() 方法为掩码生成提供 (kv_length, kv_offset) 信息:

python 复制代码
# 在模型的前向传播中
kv_length, kv_offset = past_key_values.get_mask_sizes(query_length, layer_idx)
# 据此生成因果掩码 + 滑动窗口掩码

14.4 与生成(Generation)系统的交互

  • reorder_cache(beam_idx) --- Beam Search 时重排缓存
  • crop(max_length) --- 辅助解码/对比搜索时裁剪缓存
  • reset() --- 多轮对话时重置缓存

14.5 与 torch.compile / torch.export 的交互

  • StaticCache + mark_static_address → 支持 cudagraphs
  • early_initialization() → 支持 torch.export
  • is_compileable 属性 → 生成系统判断是否可编译

14.6 导入关系图

复制代码
cache_utils.py
├── 导入 configuration_utils.PreTrainedConfig  (配置系统)
├── 导入 utils.is_torch_greater_or_equal       (版本检测)
├── 导入 utils.is_torchdynamo_compiling         (编译状态检测)
├── 导入 utils.is_hqq_available / is_optimum_quanto_available  (量化后端检测)
│
├── 被 __init__.py 导出: Cache, DynamicCache, StaticCache, EncoderDecoderCache, ...
├── 被 modeling_layers.py 导入: Cache
├── 被 modeling_outputs.py 导入: Cache, EncoderDecoderCache
├── 被 masking_utils.py 导入: Cache
├── 被 _typing.py 导入: Cache
└── 被各模型的 modeling_*.py 使用: past_key_values.update(key, value, layer_idx)
相关推荐
cg.family14 小时前
Spring生态启动过程
spring
__zRainy__14 小时前
Redis系列:核心数据类型与基础 API 解读
数据库·redis·缓存
大数据三康15 小时前
Java静态常量与静态导入:计算圆面积
java·开发语言
凤山老林15 小时前
68-Java ConcurrentHashMap
java·开发语言
憧憬成为java架构高手的小白15 小时前
苍穹外卖--day10(订单状态定时处理、来单提醒和客户催单)
java·spring boot
ch.ju16 小时前
Java Programming Chapter 4——Construction method
java·开发语言
小龙报16 小时前
【优选算法】双指针专项:1.移动零 2. 复写零 3.快乐数
java·c语言·数据结构·c++·python·算法·面试
AI行业学习16 小时前
CC-Switch Windows + macOS 下载安装配置全流程
java·开发语言·人工智能·python
Niliuershangba16 小时前
ChestnutCMS 栗子内容管理系统:从入门到模板开发实战
java·git·开源·gitlab·github·开源软件·gitcode