Multi-iteration summarization of long texts with LangChain

We previously covered RAG-based Q&A with LangChain; if you'd like a refresher, see:

RAG Q&A over real-time content with LangChain and the Hunyuan LLM

Today we'll look at how to build long-text summarization on top of that approach.

Why do we need text summarization?

Meeting transcripts are typically long-winded; extracting the key points can save a great deal of time.

Can't the model just summarize? Why single out "long text"?

Most models cap the input length, so a transcript that exceeds that limit cannot be summarized in a single call.
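
To make this concrete: LangChain models expose get_num_tokens, so you can check whether a transcript fits before calling the model. A minimal sketch (the 4,096-token limit is an assumed figure for illustration; llm is any LangChain model instance):

Python
# Minimal sketch: does the transcript fit the context window?
# `llm` is any LangChain LLM/chat model instance (assumed);
# `long_meeting_text` is the transcript as one string (assumed);
# 4096 is an assumed context limit for illustration.
n_tokens = llm.get_num_tokens(long_meeting_text)
if n_tokens > 4096:
    # Too long for a single call - we have to split and summarize iteratively
    ...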

Approaches

LangChain offers several strategies to choose from; see python.langchain.com/v0.1/docs/u...

  1. stuff: stuff the whole text into a single prompt; this can still exceed the model's limit
  2. map_reduce: split the text into n chunks, summarize each chunk independently, then summarize the combined partial summaries; this gets around the length limit, but it can sever the context between chunks and hurt the final result
  3. refine: also splits the text into n chunks, but works iteratively: summarize the first chunk, then feed that summary together with the second chunk to produce an updated summary, and so on; this preserves the original semantics better (see the sketch after this list)
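
Before customizing anything, here is a minimal sketch of the stock API for reference (llm stands in for whichever model instance you use; chunk sizes are illustrative):

Python
# Minimal sketch of the stock LangChain summarize chains, assuming an `llm` instance.
from langchain.chains.summarize import load_summarize_chain
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
docs = splitter.create_documents([long_meeting_text])  # transcript as one string (assumed)

# chain_type is one of "stuff", "map_reduce", "refine"
chain = load_summarize_chain(llm, chain_type="refine")
summary = chain.invoke({"input_documents": docs})["output_text"]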

Challenges

  1. The code itself
  2. Streaming the response
  3. Detecting the final round: under streaming, every refine round emits summary tokens, so how do we tell that the current round is the last one and forward only its output to the frontend? (See the custom-event sketch below.)
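
The third point is the crux. The mechanism used below is LangChain's custom callback events: a chain can dispatch a custom event mid-run, and a consumer reading the run through astream_events(version="v2") receives it as an on_custom_event. A minimal sketch of the mechanism (the event name last_doc_mark is our own convention, introduced in the refine chain later; assumes Python 3.11+ so context propagates automatically in async code):

Python
import asyncio

from langchain_core.callbacks import adispatch_custom_event
from langchain_core.runnables import RunnableLambda

async def step(text: str) -> str:
    # Consumers of astream_events will see this as an on_custom_event
    await adispatch_custom_event("last_doc_mark", {"chunk": True})
    return text.upper()

async def main():
    async for event in RunnableLambda(step).astream_events("hi", version="v2"):
        if event["event"] == "on_custom_event":
            print(event["name"], event["data"])  # last_doc_mark {'chunk': True}

asyncio.run(main())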

Implementation

Parts of LangChain are implemented fairly monolithically and are awkward to extend, so a few pieces of the source are copied and modified below.

1. Create the document loader used to fetch the text

Python
from typing import Dict, Optional

from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.callbacks import CallbackManagerForChainRun

from modules.BasinessException import BusinessException
from modules.resultCodeEnum import ResultCodeEnum
from service.SubtitleService import SubtitleService
from utils import constants
from utils.logger import logger

class DownloadSummarizeChain(AnalyzeDocumentChain):
    """AnalyzeDocumentChain variant that loads the transcript itself before splitting and combining."""

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:

        docs = self.get_docs(inputs, run_manager)

        # Other keys are assumed to be needed for LLM prediction
        other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
        other_keys[self.combine_docs_chain.input_key] = docs

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.combine_docs_chain(
            other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
        )

    def get_docs(self, inputs, run_manager):

        file_download_url = str(inputs[constants.TRANSCRIPTION_FILE_URL])
        if file_download_url.startswith("http"):
            # Download the transcript file from the given URL
            loader = WebBaseLoader(file_download_url, None, False)
            # Take the raw page content; it is split into chunks below
            document = loader.load()[0].page_content
            if len(document) <= 0:
                logger.error(f"file not exists:{file_download_url}")
                raise BusinessException.new_instance_with_rce(400, ResultCodeEnum.EMPTY_CONTENT)
        else:
            # Fetch the subtitles by enterprise ID and meeting ID instead
            enterprise_id: str = run_manager.metadata.get(constants.ENTERPRISE_ID)
            meeting_id: str = run_manager.metadata.get(constants.MEETING_ID)
            logger.info(f"process task with llm:{enterprise_id}-{meeting_id}")
            document = SubtitleService().fetch_subtitles(enterprise_id=enterprise_id, meeting_id=meeting_id)

        docs = self.text_splitter.create_documents([document])
        logger.info("number of splitting doc parts:{}", len(docs))
        return docs

2. Rewrite the summarize-chain loader so the refine iterations can be tracked

Python
"""Load summarizing chains."""

from typing import Any, Mapping, Optional, Protocol

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate

# The only change vs. LangChain's stock loader: swap in our custom RefineDocumentsChain
from adapters.langchain.chains.refine import RefineDocumentsChain


class LoadingCallable(Protocol):
    """Interface for loading the combine documents chain."""

    def __call__(
            self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
        """Callable to load the combine documents chain."""


def _load_stuff_chain(
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = stuff_prompt.PROMPT,
        document_variable_name: str = "text",
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> StuffDocumentsChain:
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
    # TODO: document prompt
    return StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name=document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )


def _load_map_reduce_chain(
        llm: BaseLanguageModel,
        map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_document_variable_name: str = "text",
        map_reduce_document_variable_name: str = "text",
        collapse_prompt: Optional[BasePromptTemplate] = None,
        reduce_llm: Optional[BaseLanguageModel] = None,
        collapse_llm: Optional[BaseLanguageModel] = None,
        verbose: Optional[bool] = None,
        token_max: int = 3000,
        callbacks: Callbacks = None,
        *,
        collapse_max_retries: Optional[int] = None,
        **kwargs: Any,
) -> MapReduceDocumentsChain:
    map_chain = LLMChain(
        llm=llm,
        prompt=map_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
        llm=_reduce_llm,
        prompt=combine_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    # TODO: document prompt
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
    )
    if collapse_prompt is None:
        collapse_chain = None
        if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
    else:
        _collapse_llm = collapse_llm or llm
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,  # type: ignore[arg-type]
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        **kwargs,
    )


def _load_refine_chain(
        llm: BaseLanguageModel,
        question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
        refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
        document_variable_name: str = "text",
        initial_response_name: str = "existing_answer",
        refine_llm: Optional[BaseLanguageModel] = None,
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> RefineDocumentsChain:
    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
    _refine_llm = refine_llm or llm
    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
    return RefineDocumentsChain(
        initial_llm_chain=initial_chain,
        refine_llm_chain=refine_chain,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )


def load_summarize_chain(
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that this
            applies to all chains that make up the final chain.

    Returns:
        A chain to use for summarizing.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)

3. The custom refine chain

Python
"""Combine documents by doing a first pass and then refining on more documents."""

from __future__ import annotations

from typing import Any, Dict, List, Tuple

from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks, dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field, model_validator

from utils.logger import logger


def _get_default_document_prompt() -> PromptTemplate:
    return PromptTemplate(input_variables=["page_content"], template="{page_content}")


class RefineDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by doing a first pass and then refining on more documents.

    This algorithm first calls `initial_llm_chain` on the first document, passing
    that first document in with the variable name `document_variable_name`, and
    produces a new variable with the variable name `initial_response_name`.

    Then, it loops over every remaining document. This is called the "refine" step.
    It calls `refine_llm_chain`,
    passing in that document with the variable name `document_variable_name`
    as well as the previous response with the variable name `initial_response_name`.

    Example:
        .. code-block:: python

            from langchain.chains import RefineDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                 template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
            initial_response_name = "prev_response"
            # The prompt here should take as an input variable the
            # `document_variable_name` as well as `initial_response_name`
            prompt_refine = PromptTemplate.from_template(
                "Here's your first summary: {prev_response}. "
                "Now add to it based on the following context: {context}"
            )
            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
            chain = RefineDocumentsChain(
                initial_llm_chain=initial_llm_chain,
                refine_llm_chain=refine_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name,
                initial_response_name=initial_response_name,
            )
    """

    initial_llm_chain: LLMChain
    """LLM chain to use on initial document."""
    refine_llm_chain: LLMChain
    """LLM chain to use when refining."""
    document_variable_name: str
    """The variable name in the initial_llm_chain to put the documents in.
    If only one variable in the initial_llm_chain, this need not be provided."""
    initial_response_name: str
    """The variable name to format the initial response in when refining."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    return_intermediate_steps: bool = False
    """Return the results of the refine steps in the output."""

    @property
    def output_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
        return _output_keys

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def get_return_intermediate_steps(cls, values: Dict) -> Any:
        """For backwards compatibility."""
        if "return_refine_steps" in values:
            values["return_intermediate_steps"] = values["return_refine_steps"]
            del values["return_refine_steps"]
        return values

    @model_validator(mode="before")
    @classmethod
    def get_default_document_variable_name(cls, values: Dict) -> Any:
        """Get default document variable name, if not provided."""
        if "initial_llm_chain" not in values:
            raise ValueError("initial_llm_chain must be provided")

        llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    def combine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        # Tell downstream consumers (via astream_events) that this is not yet the final round
        dispatch_custom_event("last_doc_mark", {"chunk": False})
        doc_length = len(docs)
        if doc_length == 1:
            # Only one chunk, so the first round is also the final one
            dispatch_custom_event("last_doc_mark", {"chunk": True})
        logger.info(f"refine_docs index:1/{doc_length} of {kwargs}")
        res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for index, doc in enumerate(docs[1:], start=1):
            logger.info(f"refine_docs index:{index + 1}/{doc_length} of {kwargs}")
            if index == doc_length - 1:
                # Entering the final round: mark it so the stream filter forwards its tokens
                dispatch_custom_event("last_doc_mark", {"chunk": True})
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        logger.info(f"refine_docs finished of {kwargs}, result:{res}")
        return self._construct_result(refine_steps, res)

    async def acombine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Async combine by mapping a first chain over all, then stuffing
         into a final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for doc in docs[1:]:
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        return self._construct_result(refine_steps, res)

    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
        if self.return_intermediate_steps:
            extra_return_dict = {"intermediate_steps": refine_steps}
        else:
            extra_return_dict = {}
        return res, extra_return_dict

    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
        return {
            self.document_variable_name: format_document(doc, self.document_prompt),
            self.initial_response_name: res,
        }

    def _construct_initial_inputs(
            self, docs: List[Document], **kwargs: Any
    ) -> Dict[str, Any]:
        base_info = {"page_content": docs[0].page_content}
        base_info.update(docs[0].metadata)
        document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
        base_inputs: dict = {
            self.document_variable_name: self.document_prompt.format(**document_info)
        }
        inputs = {**base_inputs, **kwargs}
        return inputs

    @property
    def _chain_type(self) -> str:
        return "refine_documents_chain"

4. Invoke the chain

Python
import json
from typing import Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool

# Project-local imports (ModelAdapter, ModelTypeEnum, PromptSynchronizer,
# QuestionTypeEnum, constants, logger, and the two chains defined above)
# are omitted here for brevity.


def process(tool: BaseTool, prompt_type: QuestionTypeEnum, input_dict: dict,
            run_manager: Optional[CallbackManagerForToolRun] = None):
    # Resolve the model instance from the request metadata
    model_type = ModelTypeEnum.from_string(run_manager.metadata.get(constants.MODEL_TYPE))
    model_instance = ModelAdapter.get_model_instance(model_type)

    # Load the prompt templates for this model and question type
    prompt_template = PromptSynchronizer.get_prompt_template(model_type=model_type, questionType=prompt_type)
    prompt_map = json.loads(prompt_template)
    refine_prompt = PromptTemplate.from_template(prompt_map["refine_template"], template_format="f-string")
    question_prompt = PromptTemplate.from_template(prompt_map["prompt"])

    logger.info("invoke tool input_dicts:{}", input_dict)

    # load_summarize_chain is the rewritten loader from step 2
    combine_docs_chain = load_summarize_chain(llm=model_instance,
                                              chain_type="refine",
                                              question_prompt=question_prompt,
                                              refine_prompt=refine_prompt,
                                              return_intermediate_steps=True,
                                              input_key="text",
                                              output_key="existing_answer",
                                              verbose=True)

    res = (tool.pre_handler
           | DownloadSummarizeChain(combine_docs_chain=combine_docs_chain,
                                    text_splitter=model_instance.get_text_splitter(),
                                    verbose=True,
                                    input_key="input_document")
           | tool.post_handler).invoke(input={"input_document": "", **input_dict}, config=RunnableConfig())
    return res

5. Filter for the final iteration in the stream

Python
from typing import AsyncGenerator

from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import RunnableConfig


async def get_stream_content_async(llm_adapter,
                                   question: str,
                                   runnable_config: RunnableConfig) -> AsyncGenerator[AIMessageChunk, None]:
    last_doc_mark = None
    async for event in llm_adapter.get_chain().astream_events({"input": question},
                                                              config=runnable_config,
                                                              version="v2"):
        # Marker emitted by our refine chain: True means the final round is starting
        if event.get("event") == "on_custom_event" and event.get("name") == "last_doc_mark":
            last_doc_mark = event["data"]["chunk"]
        # Forward model tokens only if no marker was seen (no refine loop) or the final round is running
        elif event.get("event") == "on_chat_model_stream" and (last_doc_mark is None or last_doc_mark):
            chunk: AIMessageChunk = event["data"]["chunk"]
            yield chunk
        # Results served from cache are forwarded as-is
        elif event.get("event") == "on_custom_event" and event.get("name") == "cache_cached_return":
            chunk: AIMessageChunk = event["data"]["chunk"]
            yield chunk
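
A hypothetical usage sketch for the generator above (llm_adapter and its chain wiring are assumed to exist as in the earlier code):

Python
import asyncio

from langchain_core.runnables import RunnableConfig

async def main():
    # Prints only the final refine round's tokens as they stream in
    async for chunk in get_stream_content_async(llm_adapter,  # assumed adapter instance
                                                "Summarize this meeting",
                                                RunnableConfig()):
        print(chunk.content, end="", flush=True)

asyncio.run(main())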

