RAG论文阅读笔记

RAG(Retrieval-Augmented Generation,检索增强生成)由Facebook在2020年发表的论文《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》中提出,应用于知识敏感的NLP任务,如问答。RAG将问题求解划分为检索和生成两阶段,先通过检索,查找与问题相关的文档,再将文档和问题一并输入模型,由模型推理给出最终的答案,从而解决模型无法扩展知识和产生"幻觉"的问题。目前,RAG架构已逐步应用于各个领域,比如在金融领域,对于某公司的财务分析问题,可以从该公司的财报中查找和问题相关的段落,将其和问题一并输入模型,由模型推理给出财务分析结果。目前,很多开源的大模型解决方案也支持RAG架构,例如LangChain中的RetrievalQA。

本文是对论文《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》的阅读笔记,并基于LangChain中的RetrievalQA给出一个简单的示例。

介绍

预训练语言模型可以从数据中学习知识,在不访问外部知识库的情况下,直接作为参数化的隐式知识库,但其也存在以下缺点:模型无法扩展或修改知识,比如用某天前的数据预训练的模型无法直接回答该天后发生的事实,并且模型可能产生"幻觉",比如给出和事实不一致的问答。

因此,论文提出了RAG(Retrieval-Augmented Generation,检索增强生成)架构,如图1所示,其包括参数记忆(预训练语言模型作为生成器)与非参数记忆(预训练文档检索器)两部分。

对于问答,例如回答"Define 'middle ear'"这一问题,RAG架构首先将问题输入非参数记忆部分。非参数记忆部分包含两个子部分:查询编码器(Query Encoder) <math xmlns="http://www.w3.org/1998/Math/MathML"> q \mathbf{q} </math>q,其将问题进行向量化,文档索引(Document Index),其预先通过另一个编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> d \mathbf{d} </math>d将文档进行向量化,并构建文档向量索引。问题输入非参数记忆部分后,先通过查询编码器转化为问题向量,然后从文档向量索引中以最大内积搜索(Maximum Inner Product Search,MIPS)方式查找前K个文档。RAG再将查找出的K个文档和问题合并输入非参数记忆部分。非参数记忆部分即预训练语言模型,其推理生成相应的问答,例如"The middle ear includes the tympanic cavity and the three ossicles."

方法

使用公式描述RAG架构。令问题的词元序列为 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x,文档检索器为 <math xmlns="http://www.w3.org/1998/Math/MathML"> p η ( z ∣ x ) p_\eta(z|x) </math>pη(z∣x),其根据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x返回最相关的K个文档 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,生成器 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( y i ∣ x , z , y 1 : i − 1 ) p_\theta(y_i|x,z,y_{1:i-1}) </math>pθ(yi∣x,z,y1:i−1),其采用自回归的方式,根据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x和 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,以及回答的词元序列的前 <math xmlns="http://www.w3.org/1998/Math/MathML"> i − 1 i-1 </math>i−1个词元,预测第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个词元,直至生成完整的回答。

模型

论文采用端到端的方式对检索器和生成器进行联合训练,并且采用两种模型。

第一种模型是RAG-Sequence模型,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p RAG-Sequence ( y ∣ x ) ≈ ∑ z ∈ top- k ( p ( ⋅ ∣ x ) ) p η ( z ∣ x ) p θ ( y ∣ x , z ) = ∑ z ∈ top- k ( p ( ⋅ ∣ x ) ) p η ( z ∣ x ) ∏ i N p θ ( y i ∣ x , z , y i : i − 1 ) p_{\text{RAG-Sequence}}(y|x)\approx\sum_{z\in\text{top-}k(p(\cdot|x))}{p_\eta(z|x)p_\theta(y|x,z)}=\sum_{z\in\text{top-}k(p(\cdot|x))}{p_\eta(z|x)\prod_i^N{p_\theta(y_i|x,z,y_{i:i-1})}} </math>pRAG-Sequence(y∣x)≈z∈top-k(p(⋅∣x))∑pη(z∣x)pθ(y∣x,z)=z∈top-k(p(⋅∣x))∑pη(z∣x)i∏Npθ(yi∣x,z,yi:i−1)

第二种模型是RAG-Token模型,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p RAG-Token ( y ∣ x ) ≈ ∏ i N ∑ z ∈ top- k ( p ( ⋅ ∣ x ) ) p η ( z ∣ x ) p θ ( y i ∣ x , z , y 1 : i − 1 ) p_{\text{RAG-Token}}(y|x)\approx\prod_i^N{\sum_{z\in\text{top-}k(p(\cdot|x))}{p_\eta(z|x)p_\theta(y_i|x,z,y_{1:i-1})}} </math>pRAG-Token(y∣x)≈i∏Nz∈top-k(p(⋅∣x))∑pη(z∣x)pθ(yi∣x,z,y1:i−1)

从公式对这两种模型进行理解,RAG-Sequence模型先计算每个文档条件下回答词元的概率分布,再连乘得到每个文档条件下回答的概率分布,最后再求和得到所有最相关文档条件下回答的概率分布,而RAG-Token模型先计算每个文档条件下回答词元的概率分布,再求和得到所有最相关文档条件下回答词元的概率分布,最后再连乘得到所有最相关文档条件下回答的概率分布。

检索器

检索器采用双编码器架构,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d ( z ) = BERT d ( z ) q ( x ) = BERT q ( x ) p η ( z ∣ x ) ∝ exp ⁡ ( d ( z ) ⊤ q ( x ) ) \begin{align} &\mathbf{d}(z)=\text{BERT}_d(z)\\ &\mathbf{q}(x)=\text{BERT}q(x)\\ &p\eta(z|x)\varpropto\exp(\mathbf{d}(z)^{\top}\mathbf{q}(x)) \end{align} </math>d(z)=BERTd(z)q(x)=BERTq(x)pη(z∣x)∝exp(d(z)⊤q(x))

其中,查询编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> q \mathbf{q} </math>q和文档编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> d \mathbf{d} </math>d均采用BERT。通过上述两个编码器分别得到问题和文档的向量。论文使用FASIS构建文档向量索引,并使用了HNSW算法对向量检索进行加速。向量检索的过程就是对问题和文档的向量求取内积作为相关度量,返回和问题向量内积最大、和问题最相关的前K个文档。

论文指出,检索器中的文档采用维基百科截至2018年12月的全量数据,将每篇维基百科文档切分为互不重叠的包含100个单词的块,每个块作为一个文档,共有2100万个文档。

生成器

生成器使用了Meta发布的大语言模型BART-Large,其采用编码器+解码器架构,共有4亿个参数。论文直接将问题和检索器返回的相关文档拼接在一起作为生成器的输入。

训练

论文的训练数据为问题、回答对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x i , y i ) (x_i,y_i) </math>(xi,yi)集合,损失函数采用负对象似然函数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∑ j − log ⁡ p ( y j ∣ x j ) \sum_j{-\log{p(y_j|x_j)}} </math>j∑−logp(yj∣xj)

使用Adam优化器进行梯度下降,对检索器和生成器进行端到端的联合训练。论文指出,训练时如果调整文档编码器所使用BERT模型的参数,则需要定期重新计算所有文档的向量并构建索引,成本较高。因此,训练时,文档编码器所使用BERT模型的参数固定、不更新,相应的文档向量和索引也不更新,而只对查询编码器所使用BERT模型和生成器所使用BART模型的参数进行更新。

示例

示例采用LangChain中的RetrievalQA,以下内容节选自笔者的笔记《Mac本地部署大模型体验AIGC能力》 RetrievalQA实现原理如图2所示,先构建本地知识库,包含三步:

  • 加载文档,LangChain提供多种BaseLoader实现进行文档加载;
  • 切分文本段,LangChain同时提供多种TextSplitter实现进行文本段切分;
  • 向量化文本段,使用向量化模型将文本段转化为向量,LangChain也支持多种方式的向量化模型,比如,OpenAIEmbeddings通过调用OpenAI的相关服务进行向量化,HuggingFaceEmbeddings可以远程或本地加载HuggingFace上的模型进行向量化;
  • 对文本段向量构建向量索引,LangChain也支持多种向量索引引擎,包括FaissChromaMilvus等。

再基于本地知识库进行模型推理,包含五步:

  • 输入问题;
  • 向量化问题,和文本段向量化一致,将问题转化为向量;
  • 搜索相关文本段,从向量索引中搜索和问题相关的文本段;
  • 拼接提示,根据提示模板将问题和相关文本段转化为提示;
  • 模型推理,输出答案。

代码文件retrieval_qa_demo.py如下所示:

python 复制代码
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import MarkdownTextSplitter
from langchain.vectorstores import Chroma
from chatglm2_llm import ChatGLM2

if __name__ == "__main__":
    # 加载文档
    loader = UnstructuredMarkdownLoader("/Users/xxx/workspace/docs/creative.md")
    documents = loader.load()
    # 切分文本
    text_splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    # 初始化向量化模型
    embeddings = HuggingFaceEmbeddings(model_name="/Users/xxx/workspace/models/text2vec-large-chinese",)
    # 构建向量索引
    db = Chroma.from_documents(texts, embeddings)
    # 定义模型
    llm = ChatGLM2()
    # 加载模型
    llm.load_model("/Users/xxx/workspace/models/chatglm2-6b-int4")
    # 执行链路
    qa = RetrievalQA.from_chain_type(llm, chain_type="stuff", retriever=db.as_retriever(), verbose=True)
    print(qa.run("怎么创建程序化创意"))

其中,对于知识库文档,笔者使用《超级汇川程序化创意产品手册》这一文档,将其以Markdown格式下载至本地,使用UnstructuredMarkdownLoader进行加载,并使用MarkdownTextSplitter进行切分得到文本段。对于向量化模型,笔者使用HuggingFace上的GanymedeNil/text2vec-large-chinese,并下载至本地:

cd ~/workspace/models/ git lfs install #若ChatGLM-6B部分已执行,则无需再执行 git clone huggingface.co/GanymedeNil...

对于向量索引引擎,笔者使用Chroma;对于大语言模型,笔者使用之前已定义的ChatGLM2。对于问题和从向量索引返回的相关文本段,RetrievalQA按下述提示模板拼接提示:

Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question} Helpful Answer:

retrieval_qa_demo.py运行结果如图3所示。

相关推荐
ChoSeitaku2 小时前
链表循环及差集相关算法题|判断循环双链表是否对称|两循环单链表合并成循环链表|使双向循环链表有序|单循环链表改双向循环链表|两链表的差集(C)
c语言·算法·链表
Fuxiao___2 小时前
不使用递归的决策树生成算法
算法
我爱工作&工作love我2 小时前
1435:【例题3】曲线 一本通 代替三分
c++·算法
白-胖-子3 小时前
【蓝桥等考C++真题】蓝桥杯等级考试C++组第13级L13真题原题(含答案)-统计数字
开发语言·c++·算法·蓝桥杯·等考·13级
workflower3 小时前
数据结构练习题和答案
数据结构·算法·链表·线性回归
好睡凯3 小时前
c++写一个死锁并且自己解锁
开发语言·c++·算法
Sunyanhui13 小时前
力扣 二叉树的直径-543
算法·leetcode·职场和发展
一个不喜欢and不会代码的码农3 小时前
力扣105:从先序和中序序列构造二叉树
数据结构·算法·leetcode
前端郭德纲3 小时前
浏览器是加载ES6模块的?
javascript·算法
Just Jump3 小时前
大语言模型LLM综述
llm·大语言模型