分层记忆缓冲:AI大模型长文本处理的“记忆宫殿”

在认知科学中,人类记忆并非单一容器,而是由感觉记忆、短时记忆和长时记忆 构成的分层系统。计算机架构师也早已深谙此道------从L1缓存到内存再到磁盘,逐级扩展容量,每一层都平衡着速度与成本。如今,这个思想正在大语言模型领域焕发新生,帮助Transformer突破上下文窗口的限制。这就是本文要探讨的核心:分层记忆缓冲(Hierarchical Memory Buffer)

本文不仅会讲解概念,还会给出可落地的PyTorch代码,覆盖工作记忆、情节记忆、压缩记忆及自主调度等完整实现。


1. 为什么大模型需要一个记忆系统?

标准Transformer的自注意力复杂度随序列长度平方增长,即便有了FlashAttention等优化,处理百万级token的长文档时,仍面临两大顽疾:

  • 遗忘首部信息:过长输入会超出位置编码有效范围,模型"看了后面忘前面"。
  • 推理成本爆炸:KV Cache线性增长,内存和计算不堪重负。

一种自然的思路是:我们不把全部历史压进同一个注意力窗口,而是让模型学会分层存储和召回信息。 这正是分层记忆缓冲的出发点。


2. 分层记忆缓冲的通用蓝图

在神经网络中,分层记忆通常抽象为三层结构:

层级 类比 容量 读写速度 典型实现
工作记忆 L1 缓存 / 短时记忆 几k tokens 极高(直接注意力) 当前窗口的KV Cache
情节记忆 内存 / 长时记忆 几十万tokens 中等(检索/前馈) 外部键值库、kNN索引
语义记忆 磁盘 / 永久知识 近乎无限 较慢(压缩/参数化) 模型参数、向量数据库、摘要树

推理时,模型就像一位带着笔记本的学者:工作记忆是当前段落;情节记忆是手边快速查阅的索引卡片;语义记忆是大脑中长期内化的知识。

下面我们逐层用代码实现。


3. 工作记忆:当前窗口的KV Cache

任何Transformer推理都离不开KV Cache。在分层记忆中,工作记忆就是当前正在处理的片段对应的缓存,通过限制长度来模拟容量上限。

python 复制代码
import torch
import torch.nn as nn

class WorkingMemory(nn.Module):
    def __init__(self, num_layers, num_heads, head_dim, max_len=4096):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_len = max_len
        # 每一层维护K和V的缓存,初始为空
        self.k_cache = None
        self.v_cache = None

    def update(self, new_k, new_v, layer_idx):
        """将新片段的KV追加到缓存,并截断至最大长度"""
        if self.k_cache is None:
            self.k_cache = [None] * self.num_layers
            self.v_cache = [None] * self.num_layers

        if self.k_cache[layer_idx] is None:
            self.k_cache[layer_idx] = new_k
            self.v_cache[layer_idx] = new_v
        else:
            self.k_cache[layer_idx] = torch.cat([self.k_cache[layer_idx], new_k], dim=1)
            self.v_cache[layer_idx] = torch.cat([self.v_cache[layer_idx], new_v], dim=1)

        # 截断,保证工作记忆不溢出
        if self.k_cache[layer_idx].size(1) > self.max_len:
            self.k_cache[layer_idx] = self.k_cache[layer_idx][:, -self.max_len:]
            self.v_cache[layer_idx] = self.v_cache[layer_idx][:, -self.max_len:]

在实际注意力计算时,query不仅关注当前片段的KV,还会关注工作记忆中的KV。这正是标准自回归生成流程,此处不再赘述。


4. 情节记忆:kNN增强的外部记忆(Memorizing Transformers)

Google的Memorizing Transformers将过去所有token的Key-Value存入kNN索引,作为情节记忆。我们使用faiss实现一个简化版。

4.1 构建外部记忆库

python 复制代码
import faiss
import numpy as np

class EpisodicMemory:
    def __init__(self, key_dim, capacity=100000):
        self.key_dim = key_dim
        self.capacity = capacity
        self.keys = []          # 存储所有过去的Key
        self.values = []        # 存储所有过去的Value
        self.index = faiss.IndexFlatIP(key_dim)  # 内积相似度,与注意力点积对齐

    def add(self, keys, values):
        """keys: [seq_len, key_dim], values: [seq_len, value_dim]"""
        self.keys.extend(keys.detach().cpu().numpy())
        self.values.extend(values.detach().cpu().numpy())
        # 保持容量限制
        if len(self.keys) > self.capacity:
            self.keys = self.keys[-self.capacity:]
            self.values = self.values[-self.capacity:]
        # 重建索引(实际可使用增量索引,此处简化)
        self.index = faiss.IndexFlatIP(self.key_dim)
        if len(self.keys) > 0:
            self.index.add(np.array(self.keys).astype(np.float32))

    def search(self, query, top_k=32):
        """query: [batch*heads, q_len, key_dim]"""
        orig_shape = query.shape
        query_np = query.reshape(-1, self.key_dim).detach().cpu().numpy().astype(np.float32)
        scores, indices = self.index.search(query_np, top_k)
        # 根据索引取出对应的value
        retrieved_vals = []
        for idx_row in indices:
            row_vals = [self.values[i] for i in idx_row]
            retrieved_vals.append(torch.tensor(np.array(row_vals)))
        retrieved_vals = torch.stack(retrieved_vals).view(*orig_shape[:-1], top_k, -1)
        return retrieved_vals, torch.tensor(scores).view(*orig_shape[:-1], top_k)

4.2 将外部记忆融入注意力

修改注意力计算,将检索到的记忆值通过softmax融合,并使用可学习的门控与本地注意力结合。

python 复制代码
def attention_with_memory(query, key, value, episodic_memory, top_k=32):
    # 1. 正常局部注意力
    attn_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    attn_probs = torch.softmax(attn_scores, dim=-1)
    local_output = torch.matmul(attn_probs, value)

    # 2. 从情节记忆检索
    mem_values, mem_scores = episodic_memory.search(query, top_k)
    mem_scores = mem_scores / math.sqrt(query.size(-1))
    mem_probs = torch.softmax(mem_scores, dim=-1)
    mem_output = torch.matmul(mem_probs.unsqueeze(-2), mem_values).squeeze(-2)

    # 3. 可学习门控融合(此处简化为固定值,实际可训练)
    gate = torch.sigmoid(torch.tensor(0.5))
    output = gate * local_output + (1 - gate) * mem_output
    return output

推理一个片段后,将该片段的K、V存入情节记忆:

python 复制代码
episodic_memory.add(layer_k[0], layer_v[0])  # batch中第一个样本

5. 递归压缩记忆:用摘要向量传递(AutoCompressor / Infini-Transformer)

另一种路径是将长序列压缩为固定数量的"记忆token"。这些token作为下一片段的前缀,扮演情节记忆。

5.1 记忆压缩模块

python 复制代码
class CompressiveMemory(nn.Module):
    def __init__(self, dim, num_memory_tokens=16):
        super().__init__()
        # 可学习的记忆查询向量,负责从片段中提取信息
        self.memory_queries = nn.Parameter(torch.randn(num_memory_tokens, dim))
        self.cross_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

    def forward(self, segment_hidden):
        """
        segment_hidden: [batch, seg_len, dim]
        返回压缩后的记忆: [batch, num_memory_tokens, dim]
        """
        queries = self.memory_queries.unsqueeze(0).expand(segment_hidden.size(0), -1, -1)
        compressed, _ = self.cross_attn(queries, segment_hidden, segment_hidden)
        return compressed

5.2 片段间记忆传递

处理长文档时,将前一片段的压缩记忆拼接到当前段embedding之前,实现记忆的递归传递。

python 复制代码
class HierarchicalTransformer(nn.Module):
    def __init__(self, base_transformer, num_memory_tokens=16):
        super().__init__()
        self.transformer = base_transformer
        self.memory_compressor = CompressiveMemory(base_transformer.d_model, num_memory_tokens)
        self.memory = None  # 上一片段的压缩记忆

    def forward(self, input_ids, segment_length=2048):
        segments = input_ids.split(segment_length, dim=1)
        outputs = []
        for seg in segments:
            if self.memory is not None:
                seg_emb = self.transformer.embedding(seg)
                seg_emb = torch.cat([self.memory, seg_emb], dim=1)  # 记忆作为前缀
            else:
                seg_emb = self.transformer.embedding(seg)

            hidden = self.transformer(seg_emb)  # 简化,实际需处理mask
            outputs.append(hidden)
            # 压缩当前段最后一部分作为新记忆
            self.memory = self.memory_compressor(hidden[:, -segment_length:])

        return torch.cat(outputs, dim=1)

这种设计使得记忆规模恒定,不会随时间增长。


6. 操作系统式记忆:LLM自主管理读写(MemGPT)

MemGPT让LLM通过函数调用显式管理外部记忆。我们可借助OpenAI Function Calling的风格实现。

6.1 定义记忆工具

python 复制代码
import json

class MemoryStore:
    def __init__(self):
        self.storage = {}
        self.conversation_history = []

    def read(self, key):
        return self.storage.get(key, "Memory not found.")

    def write(self, key, content):
        self.storage[key] = content
        return f"Stored '{key}'."

    def search(self, query):
        results = {k: v for k, v in self.storage.items() if query in v}
        return json.dumps(results)

# 工具定义(符合OpenAI function calling格式)
tools = [
    {
        "name": "read_memory",
        "description": "Read content from external memory by key.",
        "parameters": {
            "type": "object",
            "properties": {"key": {"type": "string"}},
            "required": ["key"]
        }
    },
    {
        "name": "write_memory",
        "description": "Write a key-content pair to external memory.",
        "parameters": {
            "type": "object",
            "properties": {
                "key": {"type": "string"},
                "content": {"type": "string"}
            },
            "required": ["key", "content"]
        }
    },
    {
        "name": "search_memory",
        "description": "Search memory for a query string.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"]
        }
    }
]

6.2 自主记忆调度

与LLM交互时,让模型决定何时读写记忆。

python 复制代码
def llm_with_memory(user_message, model):
    messages = [
        {"role": "system", "content": "You have an external memory. Use read/write/search_memory to manage it."},
        {"role": "user", "content": user_message}
    ]
    response = model.chat(messages, tools=tools)

    if response.tool_calls:
        for tool_call in response.tool_calls:
            func_name = tool_call.function.name
            args = json.loads(tool_call.function.arguments)
            if func_name == "read_memory":
                result = memory_store.read(args["key"])
            elif func_name == "write_memory":
                result = memory_store.write(args["key"], args["content"])
            elif func_name == "search_memory":
                result = memory_store.search(args["query"])
            messages.append({"role": "tool", "content": result, "name": func_name})
        final_response = model.chat(messages)
        return final_response.content
    else:
        return response.content

模型可以自行将不重要的内容换出,需要时再检索,实现动态上下文扩展。


7. 训练分层记忆:让梯度流过记忆边界

要让模型学会何时写入、如何压缩,记忆操作必须可微或采用强化学习。

7.1 可微的近似检索

在训练时,用全部过去key的softmax近似替代kNN硬检索,使梯度能够回传。

python 复制代码
def differentiable_memory_retrieval(query, all_past_keys, all_past_values, top_k=32):
    scores = torch.matmul(query, all_past_keys.transpose(-2, -1)) / math.sqrt(query.size(-1))
    topk_scores, topk_indices = torch.topk(scores, top_k, dim=-1)
    topk_probs = torch.softmax(topk_scores, dim=-1)
    retrieved_values = torch.gather(all_past_values, 1,
                                    topk_indices.unsqueeze(-1).expand(-1, -1, -1, all_past_values.size(-1)))
    return torch.matmul(topk_probs.unsqueeze(-2), retrieved_values).squeeze(-2)

7.2 压缩记忆的自监督损失

对于压缩记忆,可以要求模型从压缩向量重建原始片段,作为辅助损失。

python 复制代码
def compression_loss(compressed_memory, original_segment, decoder):
    reconstructed = decoder(compressed_memory)
    loss = nn.CrossEntropyLoss()(reconstructed.view(-1, vocab_size), original_segment.view(-1))
    return loss

联合主任务损失一起优化,迫使压缩记忆保留足够细节。


8. 最小可行示例:串起整个系统

下面代码演示了一个极简的分层记忆LLM,结合了工作记忆(KV Cache)和情节记忆(外部存储)。

python 复制代码
class MiniHierarchicalLLM:
    def __init__(self, transformer, episodic_memory_capacity=10000):
        self.model = transformer
        self.working_memory = WorkingMemory(
            num_layers=transformer.num_layers,
            num_heads=transformer.num_heads,
            head_dim=transformer.head_dim,
            max_len=4096
        )
        self.episodic = EpisodicMemory(key_dim=transformer.d_model, capacity=episodic_memory_capacity)
        self.max_seg_len = 2048

    def generate(self, input_ids, max_new_tokens=100):
        # 分段处理输入,更新记忆
        segments = input_ids.split(self.max_seg_len, dim=1)
        for seg in segments:
            hidden = self.model(seg, use_cache=True,
                                past_key_values=self.working_memory.k_cache)
            # 更新工作记忆
            self.working_memory.k_cache = hidden.past_key_values
            # 将当前段的K,V存入情节记忆(取最后一层)
            last_k, last_v = hidden.past_key_values[-1]
            self.episodic.add(last_k.squeeze(0), last_v.squeeze(0))

        generated = []
        current = input_ids[:, -1:]  # 从最后一个token开始自回归生成
        for _ in range(max_new_tokens):
            output = self.model(
                current, use_cache=True,
                past_key_values=self.working_memory.k_cache,
                episodic_memory=self.episodic  # 需要自行修改forward支持
            )
            next_token = output.logits[:, -1:].argmax(dim=-1)
            generated.append(next_token)
            current = next_token
            # 工作记忆的缓存在模型内部自动更新
        return torch.cat(generated, dim=1)

你可以从最简单的外部向量存储+检索开始,逐步加入压缩、自主调度和可微训练,让你的模型拥有真正的长时记忆。


9. 挑战与未来

分层记忆缓冲已在代码库理解、终生对话代理等任务上展现潜力,但仍面临挑战:

  • 记忆冗余与遗忘:如何优雅地淘汰旧信息?能否模拟记忆的"再巩固"过程?
  • 跨层级重组:能否增加离线阶段,自动将情节记忆提炼进语义记忆(模型参数)?
  • 隐私与安全:外部记忆可能包含敏感信息,选择性遗忘机制至关重要。
  • 多模态统一记忆:能否将文本、图像、音频映射到同一套键值空间?

10. 结语

分层记忆缓冲并非要让大模型变成笨重的数据库系统,而是赋予它一种组织自身经验的能力。正如记忆术中的"记忆宫殿"------将信息放置在熟悉的空间结构中,逐层导引,随时提取。

本文给出的代码片段为你提供了构建记忆系统的基石。无论你是想为聊天机器人增加长期记忆,还是让代码助手理解整个仓库,都可以从这里开始。随着我们向通用人工智能迈进,记忆的架构可能比模型本身更能定义其思考的深度与连贯性。


延伸阅读:

  • Memorizing Transformers (Wu et al., 2022)
  • MemGPT: Towards LLMs as Operating Systems (Packer et al., 2023)
  • Infini-Transformer: Infinite Context with Compressive Memory
  • AutoCompressor: Long Context Compression via Summary Tokens
相关推荐
沉默王二1 小时前
老板:“请说出一个录用你的理由。”我脱口而出:“每个月 AI 支出都超过我的生活费了!”老板愣了一下,随即哈哈大笑:“好吧,你被录用了。”
人工智能·ai编程·claude
小林ixn1 小时前
一文搞懂AI Agent核心概念:从LLM、Tools到记忆体,手把手带你实现一个能查股价的智能体
agent·ai编程
乘风gg3 小时前
OpenClaw 爆火,但”飞书"赢麻了!!!
前端·ai编程·claude
袋鱼不重17 小时前
我的神奇同事,AI 用多了居然写了个 Open In Codex
前端·后端·ai编程
量子位18 小时前
刚刚,Fable-5之下,智谱开源的GLM-5.2拿下AI编程第一!
ai编程
量子位18 小时前
SpaceX一分现金没花收购Cursor,马斯克吞下AI编程工具第一名
ai编程
程序员黑豆18 小时前
JDK 下载安装与配置详细教程
java·前端·ai编程
孟健18 小时前
我装了 Hermes Desktop,但最后还是回到 Telegram
ai编程