基于langchain的长文本多迭代总结

之前我们讲到langchain的rag问答,有兴趣的同学可以做下回顾

langchain基于混元大模型的实时内容的RAG问答

今天我们来了解下如何基于前文的方案实现长文本总结

为什么需要文本总结

通常会议内容是冗长的,如果能够提取关键信息的话,能够帮我们节省大量的时间

模型不能总结吗,为什么单独提出来长文本这个概念

大部分模型都会限制输入长度,如果会议长度超出了模型的限制则无法进行总结

方案

langchain提供了多种方案供我们选择,python.langchain.com/v0.1/docs/u...

  1. stuff:全文本总结,将整个文本全部投入模型;这样仍然可能会超出模型的输入长度限制
  2. MapReduce:将文本拆成n个小段,每个小段分别总结,然后再将各小段的总结结果合并后一起总结;这样虽然能解决长度问题,但是可能会破坏文本的上下文导致最终的结果不理想
  3. refine:和MapReduce相似的是将文本拆成n个小段,但是会以循环的方式先总结第一段,然后将第一段的总结结果和第二段再总结以此类推,此方法能够更好的保留原文的语义

难点

  1. 代码实现
  2. 流式返回
  3. 如何确定是最后一轮的返回(在流式响应的情况下,每轮都会返回总结结果,那么如何确定是最后一轮并只将这一轮的结果返回给前端)

实现

由于langchain的部分实现比较紧凑,导致做二次开发不是很方便,所以可能有部分修改源码的地方

1.创建文本加载工具,用于加载文本 AttachCode

Python 复制代码
from typing import Dict, Optional

from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.callbacks import CallbackManagerForChainRun

from modules.BasinessException import BusinessException
from modules.resultCodeEnum import ResultCodeEnum
from service.SubtitleService import SubtitleService
from utils import constants
from utils.logger import logger

class download_summarize_chain(AnalyzeDocumentChain):
    """AnalyzeDocumentChain variant that loads the text to analyse by itself.

    Instead of reading the document from ``inputs[self.input_key]``, the text is
    either downloaded from the URL stored under
    ``constants.TRANSCRIPTION_FILE_URL`` or fetched through ``SubtitleService``
    using the enterprise/meeting ids found in the run manager's metadata.
    """

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> dict[str, str]:
        """Load and split the document, then run ``combine_docs_chain`` on it.

        Args:
            inputs: Chain inputs. Every key except ``self.input_key`` is
                forwarded to the combine-documents chain.
            run_manager: Callback manager of the current run, if any.

        Returns:
            The outputs of ``combine_docs_chain`` (outputs only).
        """
        docs = self.get_docs(inputs, run_manager)

        # Other keys are assumed to be needed for LLM prediction
        other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
        other_keys[self.combine_docs_chain.input_key] = docs

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.combine_docs_chain(
            other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
        )

    def get_docs(self, inputs, run_manager):
        """Fetch the source text and split it into chunks.

        Args:
            inputs: Chain inputs; ``constants.TRANSCRIPTION_FILE_URL`` is looked
                up here and may be absent or ``None``.
            run_manager: Run manager whose ``metadata`` supplies the enterprise
                and meeting ids when no download URL is given. It is
                dereferenced on that branch, so it must not be ``None`` there.

        Returns:
            The documents produced by ``self.text_splitter.create_documents``.

        Raises:
            BusinessException: If a download URL is given but no content could
                be loaded from it.
        """
        # A missing key or a None value means "no URL" and selects the
        # subtitle-service branch, instead of raising KeyError or being
        # stringified to "None".
        raw_url = inputs.get(constants.TRANSCRIPTION_FILE_URL)
        file_download_url = "" if raw_url is None else str(raw_url)

        if file_download_url.startswith("http"):
            # Download the transcript from the given URL.
            # NOTE(review): the third positional argument (False) is presumably
            # WebBaseLoader's verify_ssl flag -- confirm that disabling it is intended.
            loader = WebBaseLoader(file_download_url, None, False)
            loaded = loader.load()
            # An empty load result is treated the same as an empty page.
            document = loaded[0].page_content if loaded else ""
            if not document:
                logger.error(f"file not exists:{file_download_url}")
                raise BusinessException.new_instance_with_rce(400, ResultCodeEnum.EMPTY_CONTENT)
        else:
            # Fetch the subtitles by enterprise id and meeting id.
            enterprise_id: str = run_manager.metadata.get(constants.ENTERPRISE_ID)
            meeting_id: str = run_manager.metadata.get(constants.MEETING_ID)
            logger.info(f"process task with llm:{enterprise_id}-{meeting_id}")
            document = SubtitleService().fetch_subtitles(enterprise_id=enterprise_id, meeting_id=meeting_id)

        # Split the document into chunks for the combine-documents chain.
        docs = self.text_splitter.create_documents([document])
        logger.info("number of splitting doc parts:{}", len(docs))
        return docs

2.创建重写chain,实现迭代次数的记录 AttachCode

Python 复制代码
"""Load summarizing chains."""

from typing import Any, Mapping, Optional, Protocol

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate

from adapters.langchain.chains.refine import RefineDocumentsChain


class LoadingCallable(Protocol):
    """Structural type of a factory that builds a combine-documents chain."""

    def __call__(
            self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
        """Build a combine-documents chain for ``llm`` from keyword options."""
        ...


def _load_stuff_chain(
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = stuff_prompt.PROMPT,
        document_variable_name: str = "text",
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> StuffDocumentsChain:
    """Build a "stuff" chain that sends all documents to the LLM in one prompt.

    Args:
        llm: Language model used by the inner ``LLMChain``.
        prompt: Prompt for the inner ``LLMChain``.
        document_variable_name: Prompt variable that receives the documents.
        verbose: Verbosity flag applied to both the inner and the outer chain.
        **kwargs: Extra keyword arguments forwarded to ``StuffDocumentsChain``.

    Returns:
        The configured ``StuffDocumentsChain``.
    """
    summarizer = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=verbose,  # type: ignore[arg-type]
    )
    # TODO: document prompt
    stuff_chain = StuffDocumentsChain(
        llm_chain=summarizer,
        document_variable_name=document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )
    return stuff_chain


def _load_map_reduce_chain(
        llm: BaseLanguageModel,
        map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_document_variable_name: str = "text",
        map_reduce_document_variable_name: str = "text",
        collapse_prompt: Optional[BasePromptTemplate] = None,
        reduce_llm: Optional[BaseLanguageModel] = None,
        collapse_llm: Optional[BaseLanguageModel] = None,
        verbose: Optional[bool] = None,
        token_max: int = 3000,
        callbacks: Callbacks = None,
        *,
        collapse_max_retries: Optional[int] = None,
        **kwargs: Any,
) -> MapReduceDocumentsChain:
    """Build a map-reduce summarization chain.

    Args:
        llm: Model for the map step; also the fallback for the reduce and
            collapse steps.
        map_prompt: Prompt applied to each document in the map step.
        combine_prompt: Prompt used to combine the mapped results.
        combine_document_variable_name: Document variable of the combine and
            collapse prompts.
        map_reduce_document_variable_name: Document variable of the map prompt.
        collapse_prompt: Optional prompt for the collapse step; without it no
            collapse chain is built.
        reduce_llm: Optional model for the combine step.
        collapse_llm: Optional model for the collapse step; only valid together
            with ``collapse_prompt``.
        verbose: Verbosity flag passed to the chains built here.
        token_max: Forwarded to ``ReduceDocumentsChain``.
        callbacks: Callbacks passed to the chains built here.
        collapse_max_retries: Forwarded to ``ReduceDocumentsChain``.
        **kwargs: Extra keyword arguments for ``MapReduceDocumentsChain``.

    Returns:
        The configured ``MapReduceDocumentsChain``.

    Raises:
        ValueError: If ``collapse_llm`` is given without ``collapse_prompt``.
    """
    per_document_chain = LLMChain(
        llm=llm,
        prompt=map_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )

    combining_llm_chain = LLMChain(
        llm=reduce_llm if reduce_llm else llm,
        prompt=combine_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    # TODO: document prompt
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=combining_llm_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
    )

    if collapse_prompt is not None:
        collapsing_llm_chain = LLMChain(
            llm=collapse_llm if collapse_llm else llm,
            prompt=collapse_prompt,
            verbose=verbose,  # type: ignore[arg-type]
            callbacks=callbacks,
        )
        collapse_chain = StuffDocumentsChain(
            llm_chain=collapsing_llm_chain,
            document_variable_name=combine_document_variable_name,
        )
    elif collapse_llm is not None:
        # A collapse model without a collapse prompt cannot be used.
        raise ValueError(
            "collapse_llm provided, but collapse_prompt was not: please "
            "provide one or stop providing collapse_llm."
        )
    else:
        collapse_chain = None

    reducer = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=per_document_chain,
        reduce_documents_chain=reducer,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        **kwargs,
    )


def _load_refine_chain(
        llm: BaseLanguageModel,
        question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
        refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
        document_variable_name: str = "text",
        initial_response_name: str = "existing_answer",
        refine_llm: Optional[BaseLanguageModel] = None,
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> RefineDocumentsChain:
    """Build a refine chain: summarize the first chunk, then refine per chunk.

    Args:
        llm: Model for the initial step; also used for refining when
            ``refine_llm`` is not given.
        question_prompt: Prompt of the initial step.
        refine_prompt: Prompt of each refine step.
        document_variable_name: Prompt variable that receives the document.
        initial_response_name: Prompt variable that receives the previous
            answer in refine steps.
        refine_llm: Optional separate model for the refine steps.
        verbose: Verbosity flag applied to every chain built here.
        **kwargs: Extra keyword arguments for ``RefineDocumentsChain``.

    Returns:
        The configured ``RefineDocumentsChain``.
    """
    first_pass = LLMChain(
        llm=llm,
        prompt=question_prompt,
        verbose=verbose,  # type: ignore[arg-type]
    )
    later_passes = LLMChain(
        llm=refine_llm if refine_llm else llm,
        prompt=refine_prompt,
        verbose=verbose,  # type: ignore[arg-type]
    )
    return RefineDocumentsChain(
        initial_llm_chain=first_pass,
        refine_llm_chain=later_passes,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )


def load_summarize_chain(
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Create the summarization chain registered under ``chain_type``.

    Args:
        llm: Language model handed to the selected factory.
        chain_type: One of "stuff", "map_reduce" or "refine".
        verbose: Verbosity flag handed to the selected factory.
        **kwargs: Extra keyword arguments handed to the selected factory.

    Returns:
        The chain built by the selected factory.

    Raises:
        ValueError: If ``chain_type`` is not a supported name.
    """
    factories: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    factory = factories.get(chain_type)
    if factory is None:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {factories.keys()}"
        )
    return factory(llm, verbose=verbose, **kwargs)

3.refine chain AttachCode

Python 复制代码
"""Combine documents by doing a first pass and then refining on more documents."""

from __future__ import annotations

from typing import Any, Dict, List, Tuple

from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks, dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field, model_validator

from utils.logger import logger


def _get_default_document_prompt() -> PromptTemplate:
    """Return the default document prompt, which emits ``page_content`` as is."""
    passthrough_template = "{page_content}"
    return PromptTemplate(
        template=passthrough_template,
        input_variables=["page_content"],
    )


class RefineDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by doing a first pass and then refining on more documents.

    This algorithm first calls `initial_llm_chain` on the first document, passing
    that first document in with the variable name `document_variable_name`, and
    produces a new variable with the variable name `initial_response_name`.

    Then, it loops over every remaining document. This is called the "refine" step.
    It calls `refine_llm_chain`,
    passing in that document with the variable name `document_variable_name`
    as well as the previous response with the variable name `initial_response_name`.

    Both `combine_docs` and `acombine_docs` additionally dispatch a custom
    event named ``"last_doc_mark"`` whose payload is ``{"chunk": bool}``:
    ``False`` before the first LLM call and ``True`` right before the LLM call
    for the last document, so a streaming consumer can tell the final round
    apart from the intermediate ones.

    Example:
        .. code-block:: python

            from langchain.chains import RefineDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                 template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
            initial_response_name = "prev_response"
            # The prompt here should take as an input variable the
            # `document_variable_name` as well as `initial_response_name`
            prompt_refine = PromptTemplate.from_template(
                "Here's your first summary: {prev_response}. "
                "Now add to it based on the following context: {context}"
            )
            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
            chain = RefineDocumentsChain(
                initial_llm_chain=initial_llm_chain,
                refine_llm_chain=refine_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name,
                initial_response_name=initial_response_name,
            )
    """

    initial_llm_chain: LLMChain
    """LLM chain to use on initial document."""
    refine_llm_chain: LLMChain
    """LLM chain to use when refining."""
    document_variable_name: str
    """The variable name in the initial_llm_chain to put the documents in.
    If only one variable in the initial_llm_chain, this need not be provided."""
    initial_response_name: str
    """The variable name to format the initial response in when refining."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    return_intermediate_steps: bool = False
    """Return the results of the refine steps in the output."""

    @property
    def output_keys(self) -> List[str]:
        """Output keys of the chain, plus ``intermediate_steps`` when enabled.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
        return _output_keys

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def get_return_intermediate_steps(cls, values: Dict) -> Any:
        """For backwards compatibility: accept ``return_refine_steps`` as an alias."""
        if "return_refine_steps" in values:
            values["return_intermediate_steps"] = values["return_refine_steps"]
            del values["return_refine_steps"]
        return values

    @model_validator(mode="before")
    @classmethod
    def get_default_document_variable_name(cls, values: Dict) -> Any:
        """Get default document variable name, if not provided.

        Raises:
            ValueError: If ``initial_llm_chain`` is missing, if the variable
                name cannot be inferred because the initial prompt has several
                input variables, or if the given name is not one of them.
        """
        if "initial_llm_chain" not in values:
            raise ValueError("initial_llm_chain must be provided")

        llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    def combine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Summarize the first document, then refine the result with each next one.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        # Announce that a multi-round run starts; not the last round yet.
        dispatch_custom_event("last_doc_mark", {"chunk": False})
        doc_length = len(docs)
        if doc_length == 1:
            # A single document means the initial call is also the final one.
            dispatch_custom_event("last_doc_mark", {"chunk": True})
        logger.info(f"refine_docs index:1/{doc_length} of {kwargs}")
        res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for index, doc in enumerate(docs[1:], start=1):
            logger.info(f"refine_docs index:{index+1}/{doc_length} of {kwargs}")
            if index == doc_length - 1:
                # Mark the final round before its LLM call starts streaming.
                dispatch_custom_event("last_doc_mark", {"chunk": True})
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        logger.info(f"refine_docs finished of {kwargs}, result:{res}")
        return self._construct_result(refine_steps, res)

    async def acombine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Async counterpart of `combine_docs`, including the ``last_doc_mark`` events.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        # Imported here so the block is self-contained; the module itself only
        # imports the synchronous dispatcher.
        from langchain_core.callbacks import adispatch_custom_event

        # NOTE(review): LangChain documents that on Python <= 3.10 the config
        # has to be passed to adispatch_custom_event explicitly. When the given
        # callbacks object is a run-scoped manager (it carries a parent run id)
        # it is handed over; otherwise the ambient config is used, as in the
        # synchronous path. Confirm against the langchain_core version in use.
        event_config = (
            {"callbacks": callbacks}
            if getattr(callbacks, "parent_run_id", None) is not None
            else None
        )

        inputs = self._construct_initial_inputs(docs, **kwargs)
        await adispatch_custom_event(
            "last_doc_mark", {"chunk": False}, config=event_config
        )
        doc_length = len(docs)
        if doc_length == 1:
            # A single document means the initial call is also the final one.
            await adispatch_custom_event(
                "last_doc_mark", {"chunk": True}, config=event_config
            )
        logger.info(f"refine_docs index:1/{doc_length} of {kwargs}")
        res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for index, doc in enumerate(docs[1:], start=1):
            logger.info(f"refine_docs index:{index+1}/{doc_length} of {kwargs}")
            if index == doc_length - 1:
                # Mark the final round before its LLM call starts streaming.
                await adispatch_custom_event(
                    "last_doc_mark", {"chunk": True}, config=event_config
                )
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        logger.info(f"refine_docs finished of {kwargs}, result:{res}")
        return self._construct_result(refine_steps, res)

    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
        """Pair the final answer with the optional ``intermediate_steps`` dict."""
        if self.return_intermediate_steps:
            extra_return_dict = {"intermediate_steps": refine_steps}
        else:
            extra_return_dict = {}
        return res, extra_return_dict

    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
        """Build refine-step inputs: the formatted document and the previous answer."""
        return {
            self.document_variable_name: format_document(doc, self.document_prompt),
            self.initial_response_name: res,
        }

    def _construct_initial_inputs(
            self, docs: List[Document], **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the inputs of the initial step from the first document.

        The document prompt is formatted from the first document's page content
        and metadata; ``kwargs`` are merged on top. ``docs`` must not be empty,
        since ``docs[0]`` is read unconditionally.
        """
        base_info = {"page_content": docs[0].page_content}
        base_info.update(docs[0].metadata)
        document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
        base_inputs: dict = {
            self.document_variable_name: self.document_prompt.format(**document_info)
        }
        inputs = {**base_inputs, **kwargs}
        return inputs

    @property
    def _chain_type(self) -> str:
        return "refine_documents_chain"

4.调用chain AttachCode

Python 复制代码
def process(tool: BaseTool, prompt_type: QuestionTypeEnum, input_dict: dict,
            run_manager: Optional[CallbackManagerForToolRun] = None):
    """Run the refine-based summarization pipeline for ``tool``.

    The model type is read from ``run_manager.metadata``, the prompt templates
    for that model and ``prompt_type`` are loaded from a JSON document, and a
    pipeline ``pre_handler | download_summarize_chain | post_handler`` is
    invoked with ``input_dict``.

    Args:
        tool: Tool providing the ``pre_handler`` and ``post_handler`` steps.
        prompt_type: Question type used to select the prompt template.
        input_dict: Inputs merged into the pipeline input.
        run_manager: Run manager whose metadata holds the model type. It is
            dereferenced unconditionally, so it must not be ``None``.

    Returns:
        Whatever the pipeline's ``invoke`` returns.
    """
    # Resolve the model configured for this run.
    model_type = ModelTypeEnum.from_string(run_manager.metadata.get(constants.MODEL_TYPE))
    model_instance = ModelAdapter.get_model_instance(model_type)

    # The stored template is a JSON object holding both prompts.
    templates = json.loads(
        PromptSynchronizer.get_prompt_template(model_type=model_type, questionType=prompt_type)
    )
    refine_prompt = PromptTemplate.from_template(templates["refine_template"], template_format="f-string")
    question_prompt = PromptTemplate.from_template(templates["prompt"])

    logger.info("invoke tool input_dicts:{}", input_dict)

    refine_chain = load_summarize_chain(
        llm=model_instance,
        chain_type="refine",
        question_prompt=question_prompt,
        refine_prompt=refine_prompt,
        return_intermediate_steps=True,
        input_key="text",
        output_key="existing_answer",
        verbose=True,
    )

    # Assemble the pipeline step by step, in the same left-to-right order.
    first_step = tool.pre_handler
    summarize_step = download_summarize_chain(
        combine_docs_chain=refine_chain,
        text_splitter=model_instance.get_text_splitter(),
        verbose=True,
        input_key="input_document",
    )
    pipeline = first_step | summarize_step | tool.post_handler

    payload = {"input_document": "", **input_dict}
    return pipeline.invoke(input=payload, config=RunnableConfig())

5.过滤最后一次迭代 AttachCode

Python 复制代码
async def get_stream_content_async(llm_adapter,
                                   question: str,
                                   runnable_config: RunnableConfig) -> AsyncGenerator[AIMessageChunk, None]:
    """Stream model chunks, keeping only the final round of a multi-round run.

    Events come from ``llm_adapter.get_chain().astream_events`` (version "v2"):

    * the custom event ``last_doc_mark`` updates the final-round marker;
    * ``on_chat_model_stream`` chunks are yielded while the marker is still
      unset (no marker seen yet) or truthy;
    * the custom event ``cache_cached_return`` always has its chunk yielded.

    Args:
        llm_adapter: Object whose ``get_chain()`` returns the chain to stream.
        question: Text passed to the chain under the ``"input"`` key.
        runnable_config: Config forwarded to ``astream_events``.

    Yields:
        The ``data["chunk"]`` payload of each event that passes the filter.
    """
    final_round_flag = None
    event_stream = llm_adapter.get_chain().astream_events(
        {"input": question},
        config=runnable_config,
        version="v2",
    )
    async for event in event_stream:
        kind = event.get("event")

        if kind == "on_custom_event":
            custom_name = event.get("name")
            if custom_name == "last_doc_mark":
                # Marker emitted by the chain around its iterations.
                final_round_flag = event["data"]["chunk"]
            elif custom_name == "cache_cached_return":
                # A cached result is forwarded regardless of the marker.
                yield event["data"]["chunk"]
            continue

        if kind != "on_chat_model_stream":
            continue

        # Forward model output only when no marker was seen or it is truthy.
        if final_round_flag is None or final_round_flag:
            yield event["data"]["chunk"]

此笔记由idea插件辅助生成

idea插件推荐 AnNote - IntelliJ IDEs Plugin | Marketplace 75 折折扣: MGRYF-TJW4N-WZMSJ-MZDLD-LVGJH BTKQ8-XZLPH-L3QH3-MPKBH-BP9RR

本文由博客群发一文多发等运营工具平台 OpenWrite 发布

相关推荐
兰亭序咖啡12 分钟前
学透Spring Boot — 018. 优雅支持多种响应格式
java·spring boot·后端
小雨凉如水16 分钟前
docker 常用命令
java·docker·eureka
高山流水&上善36 分钟前
医药档案区块链系统
java·springboot
南汐以墨1 小时前
探秘JVM内部
java·jvm
Craaaayon1 小时前
Java八股文-List集合
java·开发语言·数据结构·list
信徒_1 小时前
Spring 怎么解决循环依赖问题?
java·后端·spring
2301_794461572 小时前
多线程编程中的锁策略
java·开发语言
老华带你飞2 小时前
木里风景文化|基于Java+vue的木里风景文化管理平台的设计与实现(源码+数据库+文档)
java·数据库·vue.js·毕业设计·论文·风景·木里风景文化管理平台
SofterICer2 小时前
Eclipse Leshan 常见问题解答 (FAQ) 笔记
java·笔记·eclipse
liang89992 小时前
Shiro学习(四):Shiro对Session的处理和缓存
java·学习·缓存