深入浅出ColBERT模型——实现高效精确的信息检索

为什么看这个论文

最近鼓捣RAG,RAG里面有一个重要部分就是,需要计算prompt和知识库中存储知识的相似度,从而找到最相关知识,因此翻了一下提升检索效果的相关论文。

什么是 ColBERT?

"ColBERT" 指"基于 BERT 的上下文化延迟交互" (Contextualized Late Interaction over BERT),这是一个由斯坦福大学开发的模型。它充分利用了 BERT 的深度语言理解能力,同时引入了一种被称为"延迟交互"的交互机制,通过在检索过程的最后阶段之前分别处理查询和文档,实现了高效而精确的检索。目前,ColBERT 主要有两个版本:

ColBERT: 由 Omar Khattab 和 Matei Zaharia 发表,发表在 SIGIR 2020 上。这篇 ColBERT 论文首次介绍了"延迟交互"的概念。

ColBERTv2: 在初代 ColBERT 的基础上,Omar Khattab 继续深入研究,与 Barlas Oguz、Matei Zaharia 和 Michael S. Bernstein 合作,在 SIGIR 2021 会议上提出了ColBERTv2。这个 ColBERT 的升级版本引入了去噪监督和残差压缩,提高了模型的检索效果和存储效率。

解析 ColBERT 的设计理念

尽管 ColBERTv2 在架构上与原始 ColBERT 极为相似,其创新在训练技术和压缩机制上,我们先深入了解一下原始 ColBERT 的基本概念。

ColBERT 中的延迟交互

在信息检索领域,"交互" (interaction) 指的是通过比较查询(query)和文档(document)的表示(representations)来评估它们之间相关性的过程。

"延迟交互" (late interaction) 是 ColBERT 的核心创新:查询和文档表示之间的交互被推迟到处理的最后阶段,即在两者被独立编码之后才进行。这与"早期交互" (early interaction) 模型形成鲜明对比,后者在模型编码过程中或之前就进行查询和文档嵌入的交互。

交互类型 对应模型
早期交互 BERT, ANCE, DPR, Sentence-BERT, DRMM, KNRM, Conv-KNRM
延迟交互 ColBERT, ColBERTv2

早期交互模型虽然可能更精确,但计算复杂度高,因为它需要考虑所有可能的查询-文档组合。

相比之下,ColBERT 这类延迟交互模型通过预先计算文档表示,并在最后采用更轻量级的交互步骤,大大提高了效率和可扩展性。这种设计使得检索速度更快,计算需求更低,特别适合处理海量文档集合。

(a) 基于表示的相似性 (Representation-based Similarity): 在这种方法中,文档片段通过某些深度神经网络进行离线编码 (即预先处理),而查询片段则以类似的方式进行在线编码 (即实时处理)。然后计算查询与所有文档之间的成对相似性(通常是cos相似度)得分,并返回得分最高的几个文档。

(b) 查询-文档交互 (Query-Document Interaction): 在这种方法中,通过使用 n-gram (即连续的 n 个词组成的序列) 计算查询和文档中所有单词 之间的词语和短语级关系,作为一种特征工程的形式。这些关系被转换为交互矩阵,然后作为输入提供给卷积网络。

(c) 全对全交互 (All-to-all Interaction): 这可以被视为方法 (b) 的泛化,它考虑了查询和文档内部以及彼此之间的所有成对交互。这是通过自注意力机制 (self-attention) 实现的。在实践中,通常使用 BERT (Bidirectional Encoder Representations from Transformers) 来实现,因为它是双向的,因此能够真正模拟所有成对交互,而不仅仅是因果关系。

(d) 后期交互 (Late Interaction): 这是该论文引入的新方法。在这种方法中,可以使用 BERT 离线计算文档嵌入,然后在线计算查询嵌入。之后,他们在查询和文档的嵌入之间应用了一个称为 MaxSim 运算符 (最大相似度运算符,具体解释将在后文给出) 的操作。这种架构的可视化如下所示:

ColBERT 的查询和文档编码器

ColBERT 的编码策略基于 BERT 模型。ColBERT 为查询或文档中的每个标记 (token) 生成密集的向量表示,分别为查询和文档创建一系列考虑上下文的嵌入向量。这种设计为后续的延迟交互阶段提供了细致入微的比较基础。

ColBERT 的查询编码过程

假设有一个查询 Q,其标记(token)为 q1, q2, ..., ql,处理步骤如下:

  • 将 Q 转换为 BERT 使用的 WordPiece 标记(token) (一种子词分词方法)。
  • 在序列开头添加一个特殊的 [Q] 标记(token),紧随 BERT 的 [CLS] 标记(token)之后,用于标识查询的开始。
  • 如果查询长度不足预设的 Nq 个标记(token),用 [mask] 标记(token)填充;若超过则截断。
  • 将处理后的序列输入 BERT,然后通过卷积神经网络 (CNN) 处理,最后进行归一化。

最终输出的查询嵌入向量集合 Eq 可表示为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> E q : = N o r m a l i z e ( B E R T ( [ Q ] , q 0 , q 1 , ... , q l , [ m a s k ] , [ m a s k ] , ... , [ m a s k ] ) ) Eq := Normalize(BERT([Q], q0, q1, ..., ql, [mask], [mask], ..., [mask])) </math>Eq:=Normalize(BERT([Q],q0,q1,...,ql,[mask],[mask],...,[mask]))

ColBERT 的文档编码过程

对于包含标记 d1, d2, ..., dn 的文档 D,处理步骤类似:

  • 在序列开头添加 [D] 标记,标识文档开始。
  • 无需填充,直接输入 BERT 进行处理。

文档嵌入向量集合 Ed 可表示为: Ed := Filter(Normalize(BERT([D], d0, d1, ..., dn)))

Filter用于去除与标点符号对应的嵌入,从而提升分析速度。这里的查询填充策略 (论文中称为"查询增强")确保了所有查询长度一致,有利于批量处理。而 [Q] 和 [D] 标记则帮助模型区分输入类型,提高了处理效率。

使用 ColBERT 查找最相关的前 K 个文档

一旦我们获得了查询和文档的嵌入 (embeddings,即将文本转换为数值向量��表示),找到最相关的前 K 个文档就变得相对简单.

计算过程包括:

  1. 批量点积计算:用于计算词语级别的相似度。每一个词都和整个文档进行计算
  2. 最大池化 (max-pooling):在文档词语上进行操作,找出每个查询词语的最高相似度。
  3. 求和:对查询词语的相似度分数进行累加,得出文档的总体相关性分数。
  4. 排序:根据总分对文档进行排序。

以下是用 PyTorch 实现这些操作的伪代码:

python 复制代码
import torch

def compute_relevance_scores(query_embeddings, document_embeddings, k):
    """
    计算给定查询的前 k 个最相关文档。
    
    参数:
    query_embeddings: 表示查询嵌入的张量 (tensor),形状: [查询词数, 嵌入维度]
    document_embeddings: 表示 k 个文档嵌入的张量,形状: [k, 文档最大长度, 嵌入维度]
    k: 需要重新排序的顶部文档数量
    
    返回: 基于相关性分数排序的文档索引
    """
    
    # 注: 假设 document_embeddings 已经进行了适当的填充并转移到 GPU

    # 1. 计算查询嵌入和文档嵌入的批量点积
    scores = torch.matmul(query_embeddings.unsqueeze(0), document_embeddings.transpose(1, 2))
    
    # 2. 在文档词语维度上应用最大池化,找出每个查询词语的最大相似度
    max_scores_per_query_term = scores.max(dim=2).values
    
    # 3. 对查询词语的分数求和,得到每个文档的总分
    total_scores = max_scores_per_query_term.sum(dim=1)
    
    # 4. 根据总分对文档进行降序排序
    sorted_indices = total_scores.argsort(descending=True)
    
    return sorted_indices

这个过程在模型的训练和推理阶段都会用到。ColBERT 模型的训练采用了成对排序损失 (pairwise ranking loss) 。训练数据由三元组 (q, d+, d−) 组成,其中:

  • q 表示查询
  • d+ 是与查询相关的正面文档
  • d− 是与查询不相关的负面文档

模型的目标是学习出这样的表示:查询 q 与相关文档 d+ 之间的相似度分数应该高于 q 与不相关文档 d− 之间的分数。 这个训练目标可以数学化表示为最小化以下损失函数:

Loss = max(0, 1 − S(q, d+) + S(q, d−))

其中 S(q, d) 表示 ColBERT 计算的查询 q 和文档 d 之间的相似度分数。这个分数是通过聚合查询和文档中最佳匹配词语嵌入的最大相似度得到的,遵循了模型架构中描述的"后期交互"模式。

这种方法确保了模型能够有效区分给定查询的相关和不相关文档,通过鼓励相关文档对和不相关文档对之间的相似度分数有更大的差距。这样,ColBERT 就能在实际应用中更准确地检索出与用户查询最相关的文档。

ColBERT 的索引策略与效率提升

与传统的将整个文档压缩成单一向量的方法不同,ColBERT 为文档 (和查询) 中的每个词元 (token) 都创建独立的嵌入 (embedding) 向量。这种方法虽然增加了存储需求,也带来了显著的性能提升。 为了有效管理大量的嵌入向量,ColBERT 采用了以下策略:

  • 使用向量数据库 (如 FAISS,Facebook AI Similarity Search) 进行高效索引和检索。
  • 离线处理文档嵌入:预先计算并存储文档的嵌入向量,充分利用批处理和 GPU 加速。
  • 灵活的存储方案:可以选择使用 32 位或 16 位数值存储每个维度,在精度和存储空间之间取得平衡。

ColBERTv2 进一步改进了存储效率,引入了"残差压缩"技术。这种方法通过只存储嵌入向量与预定义参考点之间的差异,将模型的空间占用减少了 6 到 10 倍,同时保持了检索质量。

相关推荐
凡人的AI工具箱4 分钟前
每天40分玩转Django:Django类视图
数据库·人工智能·后端·python·django·sqlite
千天夜10 分钟前
深度学习中的残差网络、加权残差连接(WRC)与跨阶段部分连接(CSP)详解
网络·人工智能·深度学习·神经网络·yolo·机器学习
凡人的AI工具箱14 分钟前
每天40分玩转Django:实操图片分享社区
数据库·人工智能·后端·python·django
小军军军军军军18 分钟前
MLU运行Stable Diffusion WebUI Forge【flux】
人工智能·python·语言模型·stable diffusion
诚威_lol_中大努力中41 分钟前
关于VQ-GAN利用滑动窗口生成 高清图像
人工智能·神经网络·生成对抗网络
中关村科金1 小时前
中关村科金智能客服机器人如何解决客户个性化需求与标准化服务之间的矛盾?
人工智能·机器人·在线客服·智能客服机器人·中关村科金
逸_1 小时前
Product Hunt 今日热榜 | 2024-12-25
人工智能
Luke Ewin1 小时前
基于3D-Speaker进行区分说话人项目搭建过程报错记录 | 通话录音说话人区分以及语音识别 | 声纹识别以及语音识别 | pyannote-audio
人工智能·语音识别·声纹识别·通话录音区分说话人
DashVector1 小时前
如何通过HTTP API检索Doc
数据库·人工智能·http·阿里云·数据库开发·向量检索
说私域1 小时前
无人零售及开源 AI 智能名片 S2B2C 商城小程序的深度剖析
人工智能·小程序·零售