【深度学习】检索增强生成 RAG

RAG (Retrieval-Augmented Generation) 是由 Facebook AI Research (FAIR) 提出的。具体来说,RAG 是在 2020 年的论文 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 中首次被提出的。

为什么提出 RAG?

在 NLP 中,有很多知识密集型任务(Knowledge-Intensive Tasks),如开放领域问答、知识生成等,这些任务依赖于模型掌握大量的外部知识。然而,传统的预训练生成模型(如 GPT-3、BERT)存在以下问题:

  1. 知识静态性:模型只能使用训练过程中学到的知识,更新知识需要重新训练。
  2. 参数限制:模型大小有限,不可能记住所有知识。
  3. 生成质量的限制:没有外部支持的生成模型可能会生成不准确或不相关的内容。

RAG 的目标 是解决这些问题,通过引入一个外部知识库来增强生成模型的知识能力,同时使得模型更灵活和可扩展。

RAG 的创新

  1. 检索增强生成:
    • 通过引入检索机制,模型在生成答案时可以动态查询外部知识库(如维基百科)。
  2. 端到端训练:
    • 检索模块和生成模块可以端到端地联合训练,从而优化检索和生成的整体性能。
  3. 结合生成和检索的优点:
    • 保留生成模型的语言生成能力,同时利用检索模块动态补充知识,提升生成的准确性和相关性。

RAG 的核心思想

传统生成模型(如 GPT、BERT)在回答问题时依赖于预训练数据的记忆,而 RAG 提供了一种动态查询外部知识库的能力。具体而言:

  1. 检索(Retrieval):
    • 通过检索模型,从外部知识库(如维基百科、企业文档)中找到与输入问题最相关的文档。
  2. 生成(Generation):
    • 将检索到的文档作为上下文输入生成模型,由生成模型(如 BART、GPT)生成答案。

这种设计使 RAG 能够动态获取外部知识,解决生成模型对训练数据依赖的问题。

RAG 的工作流程

  1. 输入问题

    用户提供一个查询(Query),例如"Who won the Nobel Prize in Physics in 2023?"

  2. 检索阶段(Retrieval)

    • 使用基于嵌入的检索模型(如 SentenceTransformer、BM25、Dense Retriever)从知识库中挑选最相关的文档。
    • 知识库的内容通常被预处理为嵌入向量,存储在向量数据库(如 FAISS)中。
    • 输出是检索到的文档集合(例如 5 个文档)。
  3. 生成阶段(Generation)

    • 将检索到的文档与查询合并,作为上下文输入到生成模型中(如 BART、GPT)。
    • 模型基于上下文生成答案。
  4. 输出答案

    最终生成的答案由生成模型直接输出。

RAG 的优点

  1. 动态知识访问
    • 不依赖模型的固定训练数据,可以随时更新知识库,回答实时问题。
  2. 增强生成能力
    • 将生成任务与知识检索结合,可以显著提高答案的准确性和相关性。
  3. 可扩展性
    • 检索和生成阶段可以分别优化和扩展,比如替换更强的检索模型或生成模型。

RAG 的实现框架

RAG 通常由以下组件实现:

1. 检索模型
  • 用于从知识库中找到相关文档。
  • 常用技术:
    • BM25: 基于关键词的传统检索算法。
    • Dense Retriever: 使用嵌入模型(如 SentenceTransformer)生成文档和查询的向量,通过余弦相似度进行匹配。
    • 向量数据库: FAISS、Weaviate、Milvus。
2. 生成模型
  • 用于基于检索到的文档生成答案。
  • 常用模型:
    • BART: 一种强大的序列到序列生成模型。
    • T5 (Text-to-Text Transfer Transformer): 用于多任务生成。
    • GPT 系列: 强大的生成能力,适合处理长文档。
3. 知识库
  • RAG 的核心存储,包含外部知识源的嵌入或索引。
  • 数据来源:
    • 维基百科。
    • 领域特定文档。
    • 实时爬取的网页。
4. 上下文生成
  • 将检索结果与原始查询拼接,构建生成模型的输入。

RAG 的主要应用

  • 开放领域问答(Open-Domain QA)
  • 知识生成(Knowledge Generation)
  • 零样本学习(Zero-shot Learning)

RAG 的代码实现简要示例

以下是一个简单的 RAG 示例,结合检索和生成模型完成问答任务:

python 复制代码
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity

# 加载生成模型(如 BART)
generator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

# 加载检索模型(如 SentenceTransformer)
retriever_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# 知识库(文档集合)
knowledge_base = [
    "Albert Einstein was awarded the Nobel Prize in 1921.",
    "The Nobel Prize in Physics in 2023 was awarded for quantum technologies.",
    "Quantum computing uses the principles of quantum mechanics.",
]

# 创建知识库嵌入
knowledge_embeddings = retriever_model.encode(knowledge_base)

# 用户问题
query = "Who won the Nobel Prize in Physics in 2023?"
query_embedding = retriever_model.encode([query])

# 检索阶段:计算余弦相似度
similarities = cosine_similarity(query_embedding, knowledge_embeddings)
top_k_idx = np.argsort(similarities[0])[-3:][::-1]  # 选取前 3 个文档
retrieved_docs = [knowledge_base[i] for i in top_k_idx]

# 生成阶段:将检索结果与查询拼接
context = " ".join(retrieved_docs)
input_text = f"Question: {query} Context: {context}"
inputs = generator_tokenizer.encode(input_text, return_tensors="pt")

# 生成答案
outputs = generator_model.generate(inputs, max_length=50, num_beams=5)
answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Generated Answer:", answer)

示例运行结果

假设知识库中包含:

  • "Albert Einstein was awarded the Nobel Prize in 1921."
  • "The Nobel Prize in Physics in 2023 was awarded for quantum technologies."

输入问题:

text 复制代码
Who won the Nobel Prize in Physics in 2023?

可能的生成答案:

text 复制代码
The Nobel Prize in Physics in 2023 was awarded for quantum technologies.

总结

RAG 是一个强大的框架,将检索(准确找到相关信息)和生成(生成连贯的答案)相结合,解决了生成式模型知识过时和回答准确性的问题,是现代问答和知识生成任务的重要方向。

相关推荐
蚝油菜花3 分钟前
DeepSite:基于DeepSeek的开源AI前端开发神器,一键生成游戏/网页代码
人工智能·开源
蚝油菜花3 分钟前
PaperBench:OpenAI开源AI智能体评测基准,8316节点精准考核复现能力
人工智能·开源
蚝油菜花7 分钟前
DreamActor-M1:字节跳动推出AI动画黑科技,静态照片秒变生动视频
人工智能·开源
MPCTHU7 分钟前
预测分析(三):基于机器学习的分类预测
人工智能·机器学习·分类
jndingxin14 分钟前
OpenCV 图形API(11)对图像进行掩码操作的函数mask()
人工智能·opencv·计算机视觉
Scc_hy23 分钟前
强化学习_Paper_1988_Learning to predict by the methods of temporal differences
人工智能·深度学习·算法
袁煦丞27 分钟前
【亲测】1.5万搞定DeepSeek满血版!本地部署避坑指南+内网穿透黑科技揭秘
人工智能·程序员·远程工作
大模型真好玩28 分钟前
理论+代码一文带你深入浅出MCP:人工智能大模型与外部世界交互的革命性突破
人工智能·python·mcp
遇码42 分钟前
大语言模型开发框架——LangChain
人工智能·语言模型·langchain·llm·大模型开发·智能体