Building and Searching a Vector Store of txt Files with Langchain

The source code here comes mainly from the vector-store part of Langchain-ChatGLM, with some modifications and wrapping so that, given a question, database tables can be retrieved quickly from txt files that describe them (the file name is the table name, and the file content lists the table's fields and their descriptions).
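
For reference, the knowledge base described above is a folder of txt files, each named after a table and containing that table's fields and their descriptions, for example one field per line. A made-up sample file (订单表 = orders table) might look like this:

订单表.txt
    订单编号: 订单的唯一标识
    下单时间: 订单创建的时间, 精确到秒
    订单金额: 订单的总金额, 单位为元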

Chinese text splitter class

splitter.py

from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, sentence_size: int = 100, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r"\s", " ", text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗。！？]["’”」』]{0,2}|：(?=["‘“「『]{1,2}|$))')  # ： and ； deliberately excluded
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:   ## this logic still needs further optimization
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r'([;；!?。！？\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis ("......")
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis ("……")
        text = re.sub(r'([;；!?。！？\?]["’”」』]{0,2})([^;；!?，。！？\?])', r'\1\n\2', text)
        # if a closing quote follows a sentence terminator, the quote is the real end of the sentence,
        # so the \n goes after the quote; note that the rules above deliberately keep the quotes attached
        text = text.rstrip()  # drop any redundant trailing \n at the end of the paragraph
        # many rule sets also break on the semicolon ;, but it is ignored here, as are dashes,
        # English double quotes, etc.; adjust the patterns above if you need them
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,，]["’”」』]{0,2})([^,，])', r'\1\n\2', ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
                                                                                                       ele2_id + 1:]
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]

                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
        return ls
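
A minimal usage sketch of the splitter on its own (the sample text and sentence_size below are arbitrary):

from embeddings.splitter import ChineseTextSplitter  # same package layout as used by embedder.py below

splitter = ChineseTextSplitter(pdf=False, sentence_size=100)
sample = "这是第一句话。这是第二句话！最后还有一个问题吗？"
for sentence in splitter.split_text(sample):
    print(sentence)
# prints one sentence per line, split on the Chinese end-of-sentence punctuation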

FAISS vector store class

myfaiss.py

from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import
from typing import Any, Callable, List, Dict
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
import numpy as np
import copy
import os


class MyFAISS(FAISS, VectorStore):
    def __init__(
            self,
            embedding_function: Callable,
            index: Any,
            docstore: Docstore,
            index_to_docstore_id: Dict[int, str],
            normalize_L2: bool = False,
    ):
        super().__init__(embedding_function=embedding_function,
                         index=index,
                         docstore=docstore,
                         index_to_docstore_id=index_to_docstore_id,
                         normalize_L2=normalize_L2)

    def index_to_docstore_source(self, i: int) -> str:
        # helper required by seperate_list: map a FAISS index position to the
        # source file recorded in the corresponding document's metadata
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        return doc.metadata["source"]

    def seperate_list(self, ls: List[int]) -> List[List[int]]:
        # group consecutive indices that belong to the same source file
        lists = []
        ls1 = [ls[0]]
        source1 = self.index_to_docstore_source(ls[0])
        for i in range(1, len(ls)):
            if ls[i - 1] + 1 == ls[i] and self.index_to_docstore_source(ls[i]) == source1:
                ls1.append(ls[i])
            else:
                lists.append(ls1)
                ls1 = [ls[i]]
                source1 = self.index_to_docstore_source(ls[i])
        lists.append(ls1)
        return lists

    def similarity_search_with_score_by_vector(
            self, embedding: List[float], k: int = 4
    ) -> List[Document]:
        faiss = dependable_faiss_import()
        # embedding of the query, shape (1, embedding_dim), e.g. (1, 1024)
        vector = np.array([embedding], dtype=np.float32)
        # False by default
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        # both arrays have shape (1, k)
        scores, indices = self.index.search(vector, k)
        docs = []
        id_set = set()
        # holds key sentences (kept from the original code, unused in this shortened version)
        keysentences = []
        # iterate over the indices of the k nearest documents;
        # top-k is the first filter, the score could serve as a second one
        for j, i in enumerate(indices[0]):
            if i in self.index_to_docstore_id:
                _id = self.index_to_docstore_id[i]
            # indices not in the mapping (e.g. the -1 padding FAISS returns) are skipped
            else:
                continue
            # index → id → document
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            # the score is the L2 distance: smaller means more similar
            doc.metadata["score"] = int(scores[0][j])
            docs.append(doc)
            # note: what is collected here are actually the FAISS indices
            id_set.add(i)
        # sort ascending so the closest matches come first
        docs.sort(key=lambda doc: doc.metadata['score'])
        return docs
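
Since the override above stores the distance in metadata["score"] and returns plain Documents, the inherited similarity_search_with_score (which embeds the query and then delegates to this method) also yields Documents instead of (Document, score) tuples. A minimal usage sketch, assuming an index has already been saved locally; the model and index paths are placeholders:

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from embeddings.myfaiss import MyFAISS

embeddings = HuggingFaceEmbeddings(model_name="/path/to/chinese/embedding/model")  # placeholder
vector_store = MyFAISS.load_local("/path/to/saved/vector_store", embeddings)        # placeholder
for doc in vector_store.similarity_search_with_score("示例查询", k=4):
    # metadata["source"] is the txt file the sentence came from,
    # metadata["score"] is the truncated L2 distance set in the override above
    print(doc.metadata["source"], doc.metadata["score"], doc.page_content)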

Embedding and retrieval class

embedder.py

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from embeddings.splitter import ChineseTextSplitter
from embeddings.myfaiss import MyFAISS
import os
import torch
from config import *

def torch_gc():
    if torch.cuda.is_available():
        # with torch.cuda.device(DEVICE):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception as e:
            print(e)
            print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")

class Embedder:
    def __init__(self, config):
        self.model = HuggingFaceEmbeddings(model_name=
            "/home/df1500/NLP/LLM/pretrained_model/WordEmbeddings/"+config.emb_model,
            model_kwargs={'device': 'cuda'})
        self.config = config
        self.create_vector_store()
        self.vector_store = MyFAISS.load_local(self.config.db_vs_path, self.model)

    def load_file(self, filepath):
        # split the file into sentence-level documents
        if filepath.lower().endswith(".txt"):
            loader = TextLoader(filepath, autodetect_encoding=True)
            textsplitter = ChineseTextSplitter(pdf=False, sentence_size=self.config.sentence_size)
            docs = loader.load_and_split(textsplitter)
        else:
            raise Exception("{}文件不是txt格式".format(filepath))
        return docs
    
    def txt2vector_store(self, filepaths):
        # build the knowledge base from a batch of files
        docs = []
        for filepath in filepaths:
            try:
                docs += self.load_file(filepath)
            except Exception as e:
                raise Exception("{}文件加载失败".format(filepath))
        print("文件加载完毕,正在生成向量库")
        vector_store = MyFAISS.from_documents(docs, self.model)
        torch_gc()
        vector_store.save_local(self.config.db_vs_path)

    def create_vector_store(self):
        # build the vector store once if no FAISS index exists yet
        os.makedirs(self.config.db_vs_path, exist_ok=True)  # os.listdir would fail on the first run otherwise
        if "index.faiss" not in os.listdir(self.config.db_vs_path):
            filepaths = os.listdir(self.config.db_doc_path)
            filepaths = [os.path.join(self.config.db_doc_path, filepath) for filepath in filepaths]
            self.txt2vector_store(filepaths)
        print("Vector store is ready")

    def get_topk_db(self, query):
        # retrieve the top-k most similar sentences and map each one back to its table
        related_dbs_with_score = self.vector_store.similarity_search_with_score(query, k=self.config.sim_k)
        # keys: 匹配句 = matched sentence, 数据库 = table name (txt file name without ".txt"), 得分 = L2 distance
        topk_db = [{'匹配句': db_data.page_content,
                    '数据库': os.path.basename(db_data.metadata['source'])[:-4],
                    '得分': db_data.metadata['score']}
                   for db_data in related_dbs_with_score]
        return topk_db

Test code

Config is the class used to pass parameters; its definition is omitted here.
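
For reference, a minimal sketch of what Config could look like, inferred from the attributes accessed in embedder.py; all values are placeholders, and get_config returning the instance itself is an assumption made to match the test code:

class Config:
    def __init__(self):
        self.emb_model = "text2vec-large-chinese"  # placeholder: folder name of the embedding model
        self.sentence_size = 100                   # max sentence length used by ChineseTextSplitter
        self.db_doc_path = "data/db_docs"          # placeholder: directory with the table-description txt files
        self.db_vs_path = "data/db_vs"             # placeholder: directory where the FAISS index is stored
        self.sim_k = 5                             # number of sentences to retrieve

    def get_config(self):
        # assumption: simply return the populated instance
        return self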

if __name__ == '__main__':
    Conf = Config()
    configs = Conf.get_config()
    embedder = Embedder(configs)
    query = "公司哪个月的出勤率是最高的?"
    topk_db = embedder.get_topk_db(query)
    print(topk_db)
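
Each element of topk_db is one of the dicts built in get_topk_db, so the printed result has roughly the following shape (the values are illustrative placeholders, not real output):

# [{'匹配句': '<matched sentence from a table-description txt file>',
#   '数据库': '<txt file name without ".txt", i.e. the table name>',
#   '得分': <integer L2 distance, smaller means more similar>},
#  ...]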