In a previous post we walked through RAG-based Q&A with LangChain; feel free to review it if you're interested.
Today we'll look at how to build long-text summarization on top of that same setup.
Why do we need text summarization?
Meeting transcripts are usually long; extracting the key points can save us a lot of time.
Can't the model just summarize? Why single out "long text"?
Most models cap the input length, so if the meeting transcript exceeds that limit it cannot be summarized in a single pass.
Approaches
LangChain offers several approaches to choose from: python.langchain.com/v0.1/docs/u...
- stuff: whole-text summarization; the entire text is fed to the model in one go, so it may still exceed the model's input limit.
- map_reduce: split the text into n chunks, summarize each chunk separately, then summarize the partial summaries together; this works around the length limit, but splitting can break the context between chunks and hurt the final result.
- refine: like map_reduce, the text is split into n chunks, but they are processed in a loop: summarize the first chunk, then feed that summary together with the second chunk to produce an updated summary, and so on. This preserves the original context better. A minimal usage sketch follows this list.
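To make the three strategies concrete, here is a minimal sketch using stock LangChain components; the model name, chunk sizes, and the `long_meeting_text` variable are placeholders, not part of this project's code:

```python
# Placeholders: model name, chunk sizes, long_meeting_text (your transcript string).
from langchain.chains.summarize import load_summarize_chain
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

llm = ChatOpenAI(model="gpt-4o-mini")
docs = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200).create_documents([long_meeting_text])

# chain_type is one of "stuff" (everything at once), "map_reduce" (summarize chunks, then merge),
# or "refine" (fold each chunk into the running summary)
chain = load_summarize_chain(llm, chain_type="refine")
print(chain.invoke({"input_documents": docs})["output_text"])
```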
Challenges
- Implementing it in code
- Streaming the response
- Deciding which round is the final one (with streaming, every refine round produces a summary, so we need to know which round is the last one and return only that result to the frontend)
Implementation
Parts of LangChain's implementation are fairly tightly coupled, which makes extending them awkward, so in a few places we end up modifying the source code.
1. Create the text loading chain, used to load the text
```python
from typing import Dict, Optional

from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.callbacks import CallbackManagerForChainRun

from modules.BasinessException import BusinessException
from modules.resultCodeEnum import ResultCodeEnum
from service.SubtitleService import SubtitleService
from utils import constants
from utils.logger import logger


class download_summarize_chain(AnalyzeDocumentChain):
    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        docs = self.get_docs(inputs, run_manager)
        # Other keys are assumed to be needed for LLM prediction
        other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
        other_keys[self.combine_docs_chain.input_key] = docs
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.combine_docs_chain(
            other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
        )

    def get_docs(self, inputs, run_manager):
        file_download_url = str(inputs[constants.TRANSCRIPTION_FILE_URL])
        if file_download_url is not None and file_download_url.startswith("http"):
            # Download the transcript file from the given URL
            loader = WebBaseLoader(file_download_url, None, False)
            document = loader.load()[0].page_content
            if len(document) <= 0:
                logger.error(f"file not exists:{file_download_url}")
                raise BusinessException.new_instance_with_rce(400, ResultCodeEnum.EMPTY_CONTENT)
        else:
            # Fetch the subtitles by enterprise id and meeting id
            enterprise_id: str = run_manager.metadata.get(constants.ENTERPRISE_ID)
            meeting_id: str = run_manager.metadata.get(constants.MEETING_ID)
            logger.info(f"process task with llm:{enterprise_id}-{meeting_id}")
            document = SubtitleService().fetch_subtitles(enterprise_id=enterprise_id, meeting_id=meeting_id)
        # Split the document into chunks and pass them to the CombineDocumentsChain
        docs = self.text_splitter.create_documents([document])
        logger.info("number of splitting doc parts:{}", len(docs))
        return docs
```
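For comparison, a rough equivalent built with the stock `AnalyzeDocumentChain` that the class above overrides: here the transcript text is passed in directly instead of being fetched from a URL or the subtitle service. The model name, chunk sizes, and `long_meeting_text` are placeholders:

```python
from langchain.chains import AnalyzeDocumentChain
from langchain.chains.summarize import load_summarize_chain
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Stock behavior: the raw text arrives under input_key ("input_document"),
# gets split by the text_splitter, and is handed to the combine chain.
doc_chain = AnalyzeDocumentChain(
    combine_docs_chain=load_summarize_chain(ChatOpenAI(model="gpt-4o-mini"), chain_type="refine"),
    text_splitter=RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200),
)
print(doc_chain.invoke({"input_document": long_meeting_text})["output_text"])
```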
2. Create the rewritten chain loader, so the chain that records the iteration count gets used
```python
"""Load summarizing chains."""
from typing import Any, Mapping, Optional, Protocol

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate

from adapters.langchain.chains.refine import RefineDocumentsChain


class LoadingCallable(Protocol):
    """Interface for loading the combine documents chain."""

    def __call__(
        self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
        """Callable to load the combine documents chain."""


def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "text",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
    # TODO: document prompt
    return StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name=document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )


def _load_map_reduce_chain(
    llm: BaseLanguageModel,
    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_document_variable_name: str = "text",
    map_reduce_document_variable_name: str = "text",
    collapse_prompt: Optional[BasePromptTemplate] = None,
    reduce_llm: Optional[BaseLanguageModel] = None,
    collapse_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    token_max: int = 3000,
    callbacks: Callbacks = None,
    *,
    collapse_max_retries: Optional[int] = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    map_chain = LLMChain(
        llm=llm,
        prompt=map_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
        llm=_reduce_llm,
        prompt=combine_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    # TODO: document prompt
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
    )
    if collapse_prompt is None:
        collapse_chain = None
        if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
    else:
        _collapse_llm = collapse_llm or llm
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,  # type: ignore[arg-type]
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        **kwargs,
    )


def _load_refine_chain(
    llm: BaseLanguageModel,
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
    document_variable_name: str = "text",
    initial_response_name: str = "existing_answer",
    refine_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> RefineDocumentsChain:
    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
    _refine_llm = refine_llm or llm
    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
    return RefineDocumentsChain(
        initial_llm_chain=initial_chain,
        refine_llm_chain=refine_chain,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )


def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that this
            applies to all chains that make up the final chain.

    Returns:
        A chain to use for summarizing.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
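The refine loader expects the initial prompt to take `{text}` and the refine prompt to take `{existing_answer}` plus the next `{text}` chunk (the defaults `document_variable_name="text"` and `initial_response_name="existing_answer"` above). A hypothetical pair of prompts wired to that contract, assuming some `llm` instance; the prompt wording is made up for illustration, and the stock `langchain.chains.summarize.load_summarize_chain` has the same signature as the rewritten loader:

```python
from langchain_core.prompts import PromptTemplate

question_prompt = PromptTemplate.from_template(
    "Write a concise summary of the following meeting transcript:\n{text}\nSUMMARY:"
)
refine_prompt = PromptTemplate.from_template(
    "Existing summary:\n{existing_answer}\n"
    "Refine it (only if needed) using the additional transcript below:\n{text}\nREFINED SUMMARY:"
)

chain = load_summarize_chain(
    llm,                              # any BaseLanguageModel instance
    chain_type="refine",
    question_prompt=question_prompt,
    refine_prompt=refine_prompt,
    return_intermediate_steps=True,
)
```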
3. The refine chain
```python
"""Combine documents by doing a first pass and then refining on more documents."""
from __future__ import annotations

from typing import Any, Dict, List, Tuple

from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks, dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field, model_validator

from utils.logger import logger


def _get_default_document_prompt() -> PromptTemplate:
    return PromptTemplate(input_variables=["page_content"], template="{page_content}")


class RefineDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by doing a first pass and then refining on more documents.

    This algorithm first calls `initial_llm_chain` on the first document, passing
    that first document in with the variable name `document_variable_name`, and
    produces a new variable with the variable name `initial_response_name`.

    Then, it loops over every remaining document. This is called the "refine" step.
    It calls `refine_llm_chain`,
    passing in that document with the variable name `document_variable_name`
    as well as the previous response with the variable name `initial_response_name`.

    Example:
        .. code-block:: python

            from langchain.chains import RefineDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
            initial_response_name = "prev_response"
            # The prompt here should take as an input variable the
            # `document_variable_name` as well as `initial_response_name`
            prompt_refine = PromptTemplate.from_template(
                "Here's your first summary: {prev_response}. "
                "Now add to it based on the following context: {context}"
            )
            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
            chain = RefineDocumentsChain(
                initial_llm_chain=initial_llm_chain,
                refine_llm_chain=refine_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name,
                initial_response_name=initial_response_name,
            )
    """

    initial_llm_chain: LLMChain
    """LLM chain to use on initial document."""
    refine_llm_chain: LLMChain
    """LLM chain to use when refining."""
    document_variable_name: str
    """The variable name in the initial_llm_chain to put the documents in.
    If only one variable in the initial_llm_chain, this need not be provided."""
    initial_response_name: str
    """The variable name to format the initial response in when refining."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    return_intermediate_steps: bool = False
    """Return the results of the refine steps in the output."""

    @property
    def output_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
        return _output_keys

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def get_return_intermediate_steps(cls, values: Dict) -> Any:
        """For backwards compatibility."""
        if "return_refine_steps" in values:
            values["return_intermediate_steps"] = values["return_refine_steps"]
            del values["return_refine_steps"]
        return values

    @model_validator(mode="before")
    @classmethod
    def get_default_document_variable_name(cls, values: Dict) -> Any:
        """Get default document variable name, if not provided."""
        if "initial_llm_chain" not in values:
            raise ValueError("initial_llm_chain must be provided")

        llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        dispatch_custom_event("last_doc_mark", {"chunk": False})
        doc_length = len(docs)
        if doc_length == 1:
            dispatch_custom_event("last_doc_mark", {"chunk": True})
        logger.info(f"refine_docs index:1/{doc_length} of {kwargs}")
        res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for index, doc in enumerate(docs[1:], start=1):
            logger.info(f"refine_docs index:{index + 1}/{doc_length} of {kwargs}")
            if index == doc_length - 1:
                dispatch_custom_event("last_doc_mark", {"chunk": True})
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        logger.info(f"refine_docs finished of {kwargs}, result:{res}")
        return self._construct_result(refine_steps, res)

    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Async combine by mapping a first chain over all, then stuffing
        into a final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for doc in docs[1:]:
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        return self._construct_result(refine_steps, res)

    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
        if self.return_intermediate_steps:
            extra_return_dict = {"intermediate_steps": refine_steps}
        else:
            extra_return_dict = {}
        return res, extra_return_dict

    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
        return {
            self.document_variable_name: format_document(doc, self.document_prompt),
            self.initial_response_name: res,
        }

    def _construct_initial_inputs(
        self, docs: List[Document], **kwargs: Any
    ) -> Dict[str, Any]:
        base_info = {"page_content": docs[0].page_content}
        base_info.update(docs[0].metadata)
        document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
        base_inputs: dict = {
            self.document_variable_name: self.document_prompt.format(**document_info)
        }
        inputs = {**base_inputs, **kwargs}
        return inputs

    @property
    def _chain_type(self) -> str:
        return "refine_documents_chain"
```
4. Invoke the chain
```python
import json
from typing import Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool

# project-specific helpers (ModelAdapter, PromptSynchronizer, ModelTypeEnum,
# QuestionTypeEnum, constants, logger, download_summarize_chain, load_summarize_chain)
# come from the project's own modules and are omitted in the original snippet


def process(tool: BaseTool, prompt_type: QuestionTypeEnum, input_dict: dict,
            run_manager: Optional[CallbackManagerForToolRun] = None):
    # Get the model instance
    model_type = ModelTypeEnum.from_string(run_manager.metadata.get(constants.MODEL_TYPE))
    model_instance = ModelAdapter.get_model_instance(model_type)
    # Prompt template collection
    prompt_template = PromptSynchronizer.get_prompt_template(model_type=model_type, questionType=prompt_type)
    prompt_map = json.loads(prompt_template)
    refine_prompt = PromptTemplate.from_template(prompt_map["refine_template"], template_format="f-string")
    question_prompt = PromptTemplate.from_template(prompt_map["prompt"])
    logger.info("invoke tool input_dicts:{}", input_dict)
    combine_docs_chain = load_summarize_chain(llm=model_instance,
                                              chain_type="refine",
                                              question_prompt=question_prompt,
                                              refine_prompt=refine_prompt,
                                              return_intermediate_steps=True,
                                              input_key="text",
                                              output_key="existing_answer",
                                              verbose=True)
    res = (tool.pre_handler
           | download_summarize_chain(combine_docs_chain=combine_docs_chain,
                                      text_splitter=model_instance.get_text_splitter(),
                                      verbose=True,
                                      input_key="input_document")
           | tool.post_handler).invoke(input={"input_document": "", **input_dict}, config=RunnableConfig())
    return res
```
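With `return_intermediate_steps=True` and `output_key="existing_answer"`, the combine chain's raw output roughly has the shape below (illustrative values only; `tool.post_handler` may reshape it before it reaches the caller):

```python
res = {
    "existing_answer": "Final meeting summary after the last refine round ...",
    "intermediate_steps": [
        "Summary after chunk 1 ...",
        "Summary after chunks 1-2 ...",
    ],
}
```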
5. Filter for the final iteration
```python
from typing import AsyncGenerator

from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import RunnableConfig


async def get_stream_content_async(llm_adapter,
                                   question: str,
                                   runnable_config: RunnableConfig) -> AsyncGenerator[AIMessageChunk, None]:
    last_doc_mark = None
    async for event in llm_adapter.get_chain().astream_events({"input": question},
                                                              config=runnable_config,
                                                              version="v2"):
        # Marker emitted on each refine iteration
        if event.get("event") == "on_custom_event" and event.get("name") == "last_doc_mark":
            last_doc_mark = event["data"]["chunk"]
        # Stream only the tokens of the final iteration (or everything if no marker was emitted)
        elif event.get("event") == "on_chat_model_stream" and (last_doc_mark is None or last_doc_mark):
            chunk: AIMessageChunk = event["data"]["chunk"]
            yield chunk
        # Cached result returned directly
        elif event.get("event") == "on_custom_event" and event.get("name") == "cache_cached_return":
            chunk: AIMessageChunk = event["data"]["chunk"]
            yield chunk
```
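A hypothetical consumer of the generator above, for example printing the streamed tokens to the console; `llm_adapter` and `runnable_config` are assumed to be set up elsewhere in the project (in practice the chunks would be forwarded to the frontend, e.g. over SSE):

```python
import asyncio

async def main() -> None:
    async for chunk in get_stream_content_async(llm_adapter, "Summarize the meeting", runnable_config):
        print(chunk.content, end="", flush=True)

asyncio.run(main())
```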