RAG论文阅读笔记

RAG(Retrieval-Augmented Generation,检索增强生成)由Facebook在2020年发表的论文《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》中提出,应用于知识敏感的NLP任务,如问答。RAG将问题求解划分为检索和生成两阶段,先通过检索,查找与问题相关的文档,再将文档和问题一并输入模型,由模型推理给出最终的答案,从而解决模型无法扩展知识和产生"幻觉"的问题。目前,RAG架构已逐步应用于各个领域,比如在金融领域,对于某公司的财务分析问题,可以从该公司的财报中查找和问题相关的段落,将其和问题一并输入模型,由模型推理给出财务分析结果。目前,很多开源的大模型解决方案也支持RAG架构,例如LangChain中的RetrievalQA。

本文是对论文《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》的阅读笔记,并基于LangChain中的RetrievalQA给出一个简单的示例。

介绍

预训练语言模型可以从数据中学习知识,在不访问外部知识库的情况下,直接作为参数化的隐式知识库,但其也存在以下缺点:模型无法扩展或修改知识,比如用某天前的数据预训练的模型无法直接回答该天后发生的事实,并且模型可能产生"幻觉",比如给出和事实不一致的问答。

因此,论文提出了RAG(Retrieval-Augmented Generation,检索增强生成)架构,如图1所示,其包括参数记忆(预训练语言模型作为生成器)与非参数记忆(预训练文档检索器)两部分。

对于问答,例如回答"Define 'middle ear'"这一问题,RAG架构首先将问题输入非参数记忆部分。非参数记忆部分包含两个子部分:查询编码器(Query Encoder) <math xmlns="http://www.w3.org/1998/Math/MathML"> q \mathbf{q} </math>q,其将问题进行向量化,文档索引(Document Index),其预先通过另一个编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> d \mathbf{d} </math>d将文档进行向量化,并构建文档向量索引。问题输入非参数记忆部分后,先通过查询编码器转化为问题向量,然后从文档向量索引中以最大内积搜索(Maximum Inner Product Search,MIPS)方式查找前K个文档。RAG再将查找出的K个文档和问题合并输入非参数记忆部分。非参数记忆部分即预训练语言模型,其推理生成相应的问答,例如"The middle ear includes the tympanic cavity and the three ossicles."

方法

使用公式描述RAG架构。令问题的词元序列为 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x,文档检索器为 <math xmlns="http://www.w3.org/1998/Math/MathML"> p η ( z ∣ x ) p_\eta(z|x) </math>pη(z∣x),其根据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x返回最相关的K个文档 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( y i ∣ x , z , y 1 : i − 1 ) p_\theta(y_i|x,z,y_{1:i-1}) </math>pθ(yi∣x,z,y1:i−1),其采用自回归的方式,根据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x和 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,以及回答的词元序列的前 <math xmlns="http://www.w3.org/1998/Math/MathML"> i − 1 i-1 </math>i−1个词元,预测第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个词元,直至生成完整的回答。

模型

论文采用端到端的方式对检索器和生成器进行联合训练,并且采用两种模型。

第一种模型是RAG-Sequence模型,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p RAG-Sequence ( y ∣ x ) ≈ ∑ z ∈ top- k ( p ( ⋅ ∣ x ) ) p η ( z ∣ x ) p θ ( y ∣ x , z ) = ∑ z ∈ top- k ( p ( ⋅ ∣ x ) ) p η ( z ∣ x ) ∏ i N p θ ( y i ∣ x , z , y i : i − 1 ) p_{\text{RAG-Sequence}}(y|x)\approx\sum_{z\in\text{top-}k(p(\cdot|x))}{p_\eta(z|x)p_\theta(y|x,z)}=\sum_{z\in\text{top-}k(p(\cdot|x))}{p_\eta(z|x)\prod_i^N{p_\theta(y_i|x,z,y_{i:i-1})}} </math>pRAG-Sequence(y∣x)≈z∈top-k(p(⋅∣x))∑pη(z∣x)pθ(y∣x,z)=z∈top-k(p(⋅∣x))∑pη(z∣x)i∏Npθ(yi∣x,z,yi:i−1)

第二种模型是RAG-Token模型,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p RAG-Token ( y ∣ x ) ≈ ∏ i N ∑ z ∈ top- k ( p ( ⋅ ∣ x ) ) p η ( z ∣ x ) p θ ( y i ∣ x , z , y 1 : i − 1 ) p_{\text{RAG-Token}}(y|x)\approx\prod_i^N{\sum_{z\in\text{top-}k(p(\cdot|x))}{p_\eta(z|x)p_\theta(y_i|x,z,y_{1:i-1})}} </math>pRAG-Token(y∣x)≈i∏Nz∈top-k(p(⋅∣x))∑pη(z∣x)pθ(yi∣x,z,y1:i−1)

从公式对这两种模型进行理解,RAG-Sequence模型先计算每个文档条件下回答词元的概率分布,再连乘得到每个文档条件下回答的概率分布,最后再求和得到所有最相关文档条件下回答的概率分布,而RAG-Token模型先计算每个文档条件下回答词元的概率分布,再求和得到所有最相关文档条件下回答词元的概率分布,最后再连乘得到所有最相关文档条件下回答的概率分布。

检索器

检索器采用双编码器架构,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d ( z ) = BERT d ( z ) q ( x ) = BERT q ( x ) p η ( z ∣ x ) ∝ exp ⁡ ( d ( z ) ⊤ q ( x ) ) \begin{align} &\mathbf{d}(z)=\text{BERT}_d(z)\\ &\mathbf{q}(x)=\text{BERT}q(x)\\ &p\eta(z|x)\varpropto\exp(\mathbf{d}(z)^{\top}\mathbf{q}(x)) \end{align} </math>d(z)=BERTd(z)q(x)=BERTq(x)pη(z∣x)∝exp(d(z)⊤q(x))

其中,查询编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> q \mathbf{q} </math>q和文档编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> d \mathbf{d} </math>d均采用BERT。通过上述两个编码器分别得到问题和文档的向量。论文使用FASIS构建文档向量索引,并使用了HNSW算法对向量检索进行加速。向量检索的过程就是对问题和文档的向量求取内积作为相关度量,返回和问题向量内积最大、和问题最相关的前K个文档。

论文指出,检索器中的文档采用维基百科截至2018年12月的全量数据,将每篇维基百科文档切分为互不重叠的包含100个单词的块,每个块作为一个文档,共有2100万个文档。

生成器

生成器使用了Meta发布的大语言模型BART-Large,其采用编码器+解码器架构,共有4亿个参数。论文直接将问题和检索器返回的相关文档拼接在一起作为生成器的输入。

训练

论文的训练数据为问题、回答对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x i , y i ) (x_i,y_i) </math>(xi,yi)集合,损失函数采用负对象似然函数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∑ j − log ⁡ p ( y j ∣ x j ) \sum_j{-\log{p(y_j|x_j)}} </math>j∑−logp(yj∣xj)

使用Adam优化器进行梯度下降,对检索器和生成器进行端到端的联合训练。论文指出,训练时如果调整文档编码器所使用BERT模型的参数,则需要定期重新计算所有文档的向量并构建索引,成本较高。因此,训练时,文档编码器所使用BERT模型的参数固定、不更新,相应的文档向量和索引也不更新,而只对查询编码器所使用BERT模型和生成器所使用BART模型的参数进行更新。

示例

示例采用LangChain中的RetrievalQA,以下内容节选自笔者的笔记《Mac本地部署大模型体验AIGC能力》 RetrievalQA实现原理如图2所示,先构建本地知识库,包含三步:

  • 加载文档,LangChain提供多种BaseLoader实现进行文档加载;
  • 切分文本段,LangChain同时提供多种TextSplitter实现进行文本段切分;
  • 向量化文本段,使用向量化模型将文本段转化为向量,LangChain也支持多种方式的向量化模型,比如,OpenAIEmbeddings通过调用OpenAI的相关服务进行向量化,HuggingFaceEmbeddings可以远程或本地加载HuggingFace上的模型进行向量化;
  • 对文本段向量构建向量索引,LangChain也支持多种向量索引引擎,包括FaissChromaMilvus等。

再基于本地知识库进行模型推理,包含五步:

  • 输入问题;
  • 向量化问题,和文本段向量化一致,将问题转化为向量;
  • 搜索相关文本段,从向量索引中搜索和问题相关的文本段;
  • 拼接提示,根据提示模板将问题和相关文本段转化为提示;
  • 模型推理,输出答案。

代码文件retrieval_qa_demo.py如下所示:

python 复制代码
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import MarkdownTextSplitter
from langchain.vectorstores import Chroma
from chatglm2_llm import ChatGLM2

if __name__ == "__main__":
    # 加载文档
    loader = UnstructuredMarkdownLoader("/Users/xxx/workspace/docs/creative.md")
    documents = loader.load()
    # 切分文本
    text_splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    # 初始化向量化模型
    embeddings = HuggingFaceEmbeddings(model_name="/Users/xxx/workspace/models/text2vec-large-chinese",)
    # 构建向量索引
    db = Chroma.from_documents(texts, embeddings)
    # 定义模型
    llm = ChatGLM2()
    # 加载模型
    llm.load_model("/Users/xxx/workspace/models/chatglm2-6b-int4")
    # 执行链路
    qa = RetrievalQA.from_chain_type(llm, chain_type="stuff", retriever=db.as_retriever(), verbose=True)
    print(qa.run("怎么创建程序化创意"))

其中,对于知识库文档,笔者使用《超级汇川程序化创意产品手册》这一文档,将其以Markdown格式下载至本地,使用UnstructuredMarkdownLoader进行加载,并使用MarkdownTextSplitter进行切分得到文本段。对于向量化模型,笔者使用HuggingFace上的GanymedeNil/text2vec-large-chinese,并下载至本地:

cd ~/workspace/models/ git lfs install #若ChatGLM-6B部分已执行,则无需再执行 git clone huggingface.co/GanymedeNil...

对于向量索引引擎,笔者使用Chroma;对于大语言模型,笔者使用之前已定义的ChatGLM2。对于问题和从向量索引返回的相关文本段,RetrievalQA按下述提示模板拼接提示:

Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question} Helpful Answer:

retrieval_qa_demo.py运行结果如图3所示。

相关推荐
little redcap1 小时前
第十九次CCF计算机软件能力认证-乔乔和牛牛逛超市
数据结构·c++·算法
muyierfly2 小时前
34.贪心算法1
算法·贪心算法
luthane4 小时前
python 实现average mean平均数算法
开发语言·python·算法
静心问道5 小时前
WGAN算法
深度学习·算法·机器学习
杰九5 小时前
【算法题】46. 全排列-力扣(LeetCode)
算法·leetcode·深度优先·剪枝
manba_5 小时前
leetcode-560. 和为 K 的子数组
数据结构·算法·leetcode
liuyang-neu5 小时前
力扣 11.盛最多水的容器
算法·leetcode·职场和发展
忍界英雄5 小时前
LeetCode:2398. 预算内的最多机器人数目 双指针+单调队列,时间复杂度O(n)
算法·leetcode·机器人
Kenneth風车5 小时前
【机器学习(五)】分类和回归任务-AdaBoost算法-Sentosa_DSML社区版
人工智能·算法·低代码·机器学习·数据分析
AI小白龙*5 小时前
大模型团队招人(校招):阿里巴巴智能信息,2025届春招来了!
人工智能·langchain·大模型·llm·transformer