简单看看langchain中的一点qwen源码

文章目录

前言

本文主要是继续深挖Tongyi类,并进一步探究详细的流程。个人理解不够全面,能够为大家给出的解释有限。

导入Tongyi类

Tongyi类是langchain_community.llms中的一个类。实际上,这个类是在langchain_community.llms文件夹下tongyi.py中的一个类,只不过因为langchain_community.llms文件夹下的__init__.py文件追加了一个方法:

python 复制代码
def _import_tongyi() -> Type[BaseLLM]:
  from langchain_community.llms.tongyi import Tongyi
  return Tongyi

最终,在一个包含了无数个if-else__getattr__方法中,会根据传入name的值判断执到底执行哪一个大模型的import方法。

这个意思就是说,我们假设现在新开发了一个大模型,叫做Ninedays(就当这个叫做九天吧),并存入ninedays.py。我们想要导入这个Ninedays大模型,也就可以通过from langchain_community.llms import Ninedays导入。

导入的过程将首先经过__init__.py方法中的__getattr__方法,用于访问没有直接定义出来的数据。此时,在__getattr__方法中增加:

python 复制代码
if name == "Ninedays":
  from langchain_community.llms.ninedays import Ninedays

这个意思就是,我经过__getattr__访问到了Ninedayes这个name,并且通过大量的if-else查询到了这个执行条件,于是开始导入大模型。

这种方法是一种懒加载的实现方法,非常方便。

配置Tongyi类

langchain_community.llms文件夹下的tongyi.py文件中,里面有这么几个属性:

python 复制代码
client: Any  #: :meta private:
model_name: str = "qwen-plus"

"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)

top_p: float = 0.8
"""Total probability mass of tokens to consider at each step."""

dashscope_api_key: Optional[str] = None
"""Dashscope api key provide by Alibaba Cloud."""

streaming: bool = False
"""Whether to stream the results or not."""

max_retries: int = 10
"""Maximum number of retries to make when generating."""

其中:

  • client并不确定是什么,没有相关定义,但是会按照llm.client.call执行,其中llmTongyi类的实例。
  • model_nameqwen-plus,表示默认的模型名称。
  • model_kwargs是空字典,表示默认的模型参数。
  • top_p是0.8,是model_kwargs中的top_p参数。
  • dashscope_api_key是通义千问的api-key
  • streaming表示最终的输出是否是流式输出。
  • max_retries是10,表示最多允许的重试次数。

我们初始化的过程中,往往也是直接自定义这些参数:

python 复制代码
llm = Tongyi(
  dashscope_api_key="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
)

读取api-key

读取的方式有很多种,包括yaml配置、xml配置、txt配置、properties配置乃至数据库配置等等,Python也为每一种配置都有特定的工具库,非常方便。下面仅介绍3种推荐配置方法。

os配置

Tongyi初始化的过程中,会执行一个validate_environment方法,将检查环境是否满足要求:

python 复制代码
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
  """Validate that api key and python package exists in environment."""
  values["dashscope_api_key"] = get_from_dict_or_env(
    values, "dashscope_api_key", "DASHSCOPE_API_KEY"
  )
  try:
    import dashscope
  except ImportError:
    raise ImportError(
      "Could not import dashscope python package. "
      "Please install it with `pip install dashscope`."
    )
  try:
    values["client"] = dashscope.Generation
  except AttributeError:
    raise ValueError(
      "`dashscope` has no `Generation` attribute, this is likely "
      "due to an old version of the dashscope package. Try upgrading it "
      "with `pip install --upgrade dashscope`."
    )
  return values

首先一上来就是执行get_from_dict_or_env,这个方法将会在values中获取dashscope_api_key,如果获取不到,则从os.environ中获取DASHSCOPE_API_KEY

既然知道这个,那就好办了:

python 复制代码
import os
os.environ["DASHSCOPE_API_KEY"] = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

这样Tongyi就知道你的api-key了。

env配置

当然,方法是有很多的。除了单纯的使用os显性地配置以外,官方提出,我们最好不要在代码中明码显示我们的api-key,这将会带来一定的风险。

当然,官方也推荐了一种方法:使用.env文件存储api-key

比较取巧的是,.env文件可以通过load_dotenv加载:

python 复制代码
from dotenv import load_dotenv
load_dotenv()

同样也非常方便。

P.S. :我不太确定是在哪里出了问题,现在.env配置失效了,只能使用os.environ的方式配置。

streamlit配置

我们在OpenAI的教程中能够更多地看到streamlit的存在。主要是因为,streamlit能够非常方便地给出一个demo界面,从而为用户更快地测试与迭代。

当然,streamlit也有独特的部分,比如streamlit能够自动的将api-key存储在secrets.toml文件中,从而实现api-key的隐藏。

secrets.toml文件样例为:

toml 复制代码
DASHSCOPE_API_KEY = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

P.S. :需要注意的是,toml文件必须要为字符串增加双引号,否则将会抛出编码错误。

为了读取api-keystreamlit提供了secrets成员变量,其最初定义是这样的:

python 复制代码
class Secrets(Mapping[str, Any]):
  """A dict-like class that stores secrets.
  Parses secrets.toml on-demand. Cannot be externally mutated.

  Safe to use from multiple threads.
  """

  def __init__(self, file_paths: list[str]):
    # Our secrets dict.
    self._secrets: Mapping[str, Any] | None = None
    self._lock = threading.RLock()
    self._file_watchers_installed = False
    self._file_paths = file_paths

    self.file_change_listener = Signal(
      doc="Emitted when a `secrets.toml` file has been changed."
    )

可以看到,secrets表现类似Java中的ConcurrentHashMap一般,既有字典的一部分,又有线程安全的一部分。

而在定义secrets的过程中,通过上锁的方式,在streamlit的生命周期中实现了单例模式:

python 复制代码
# 根据配置文件`secrets.toml`获取锁
SECRETS_FILE_LOCS: Final[list[str]] = [
    file_util.get_streamlit_file_path("secrets.toml"),
    # NOTE: The order here is important! Project-level secrets should overwrite global
    # secrets.
    file_util.get_project_streamlit_file_path("secrets.toml"),
]
# 创建单例
secrets_singleton: Final = Secrets(SECRETS_FILE_LOCS)
# streamlit的成员对象
secrets = _secrets_singleton

这也就意味着,无论我们使用多少个Agent协同进行,都将读取同一个api-key

最终在使用的时候,也就能够直接这么做:

python 复制代码
import streamlit as st
llm = Tongyi(
  dashscope_api_key = st.secrets["DASHSCOPE_API_KEY"]
)

同样也是非常的方便。

PromptTemplate

当我们构造了Tongyi实例之后,我们就需要完善一个prompt模板了,在实际使用中,就会用到PromptTempate。这是个啥呢?是langchain_core.prompts中的一个类。在他的构造函数中,他接受这么几个参数:

  • template -> 模板,也就是我们规定其中固定的部分,然后留下灵活填入的部分,叫做模板。例如:Hello, {name},这串字符串就是模板,需要灵活填入的就是这个name
  • input_variables -> 输入变量,输入变量是一个列表,无论输入一个两个都是一个列表,按照模板中需要填入的变量顺序进行匹配
  • partial_variables -> 预置变量,输入变量是一个dict字典,键是变量名,值是变量值。这个参数主要是为了将变量值直接预先填入prompt模板中,而不像input_variables需要渲染。
  • template_format -> 模板格式,默认是f-string,但是也可以是jinja2,还可以是mustache
  • validate_template -> 是否验证模板,默认是True

P.S. :关于partial_variables,这个参数是langchain提供的一个功能,可以预置变量值,而不用每次都渲染。举个简单的例子:

python 复制代码
template_str = """
    Hello, {name}. Welcome to {place}. Today's weather is {weather}.
"""

其中,假设place = "Wonderland"。如果placepartial_variables,那么langchain会自动将place的值填入模板中。在PromptTemplate渲染的过程中,我们渲染的并不是 字符串"Hello, {name}. Welcome to {place}. Today's weather is {weather}."而是字符串:

python 复制代码
template_str = """
    Hello, {name}. Welcome to Wonderland. Today's weather is {weather}.
"""

P.P.S. :关于模板,需要说明的是,langchain默认模板为f-string模板,并强烈建议不要使用非受信的jinja2模板。这主要是因为jinja2的功能过于强大,它可以在prompt中执行Python代码。举个比较简单的例子:

python 复制代码
template_str = """
Hello, {{ name }}
{% if 2 + 2 == 4 %}
  2 + 2 is indeed 4
{% endif %}
{% set dangerous_code = '__import__("os").system("ls")' %}
{{ eval(dangerous_code) }}
"""

看到了吗,jinja2模板能够直接执行Python代码,从而直接执行了linuxls命令。如果说写jinja2模板的人将更具备攻击性一点,则可以将服务器中的数据直接全部传出去,或者植入挖矿程序等。

每次使用的时候都这么构造的话,会不会太麻烦了?没关系,官方已经给你想好了。默认情况下,PromptTemplate将使用f-string模板,所以可以直接这么写:

python 复制代码
prompt = PromptTemplate.from_template("Hello, {name}")

P.S.: :直接用PromptTemplate构造函数与PromptTemplate.from_template构造基本没有区别,from_template只是为了方便使用者仅需要考虑输入f-string模板的prompt而设计的。

当然,你可能会疑惑,这样子设置,他怎么知道自己的input_variables呢?答案就在源码中的get_template_variables方法,它能够将传递的prompt模板中的变量解析出来,并返回一个经过排序的列表。

最终,PromptTemplate无论是使用from_template还是直接使用构造函数,最终都会返回一个PromptTemplate对象。

LLMChain

LLMChain的构造函数主要接受两个参数,一个是LLM实例,一个是prompt模板。就像这样:

python 复制代码
llm_chain = LLMChain(llm=llm, prompt=prompt)

本质上,LLMChain在获得LLM实例与prompt模板后,会通过LLMChain_call方法执行:

python 复制代码
def _call(
  self,
  inputs: Dict[str, Any],
  run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
  response = self.generate([inputs], run_manager=run_manager)
  return self.create_outputs(response)[0]

这里主要调用了两个部分,分别是generatecreate_outputs两个。

generate主要是通过调用prep_prompt方法构造提示,然后根据llm实例的类型生成具体的LLMResult类。

最终,也就由LLMChain给出结果。

更换掉LLMChain(可选)

比较尴尬的是,LLMChain有一个注解(或者说你比较喜欢叫他装饰器也行):

python 复制代码
@deprecated(
  since="0.1.17",
  alternative="RunnableSequence, e.g., `prompt | llm`",
  removal="0.3.0",
)

从 0.1.17 0.1.17 0.1.17开始支持,从 0.3.0 0.3.0 0.3.0开始取消支持。如果有需要的话,可以提前修改为TransformChain

修改起来也还算比较简单。原来我们使用:

python 复制代码
llm_chain = LLMChain(llm=llm, prompt=prompt)

需要修改的也就只有这里,PromptTemplate没变,StuffDocumentsChain也没变,变的只有构造chain的方法:

python 复制代码
llm_chain = TransformChain(
  llm=llm,
  input_prompt=prompt,
  output_key="text",
  transform = lambda x: {"text": x}
)

当然,上面这是一个偷懒的方法。实际上,我们需要额外定义一个func方法,然后在构造函数中定义transform = func

其中:

  • llm没有变化,还是以前的llm对象;
  • input_prompt还是以前的PromptTemplate对象,只是名字改了;
  • output_keyllm的输出字段名,这个是新东西,最终标记输出内容的键;
  • transform,原名transform_cbtransform只是transform_cb的一个别名,为该参数赋值的时候是都可以使用的。transform是一个callback函数,将prompt的输出包装为字典,键需要与output_key一致,内容可以原样输出,也可以稍做处理;

当然,考虑到目前最新的版本只有 0.2.1 0.2.1 0.2.1,如果并没有考虑升级事项,LLMChain是完全可以直接使用的。但如果考虑到长期维护,可能就需要更换为TransformChain

StuffDocumentsChain

StuffDocumentsChain的构造函数主要接受两个参数,一个是LLMChain实例,一个是document_prompt模板中代表植入文本的变量:

python 复制代码
stf_chain = StuffDocumentsChain(
  llm_chain = LLMChain(
    llm=Tongyi(),
    prompt=prompt,
    verbose=True,
    memory=memory
  ),
  document_variable_name="text"
)

PromptTemplate类似,通过document_variable_name将文件内容添加到prompt模板中。

这里面需要注意的是,document_variable_name需要与prompt中指定的植入文本的内容一致。

比如,在最开始的时候,我们定义的prompt模板是:

python 复制代码
prompt = """"
{human_input} {chat_history} {text}
"""

那么我们就需要将document_variable_name设置为text

当然,你可能在某些教程上看到是这样的:

python 复制代码
stf_chain = StuffDocumentsChain(
  llm_chain = LLMChain(
    llm=Tongyi(),
    prompt=prompt,
    verbose=True,
    memory=memory,
    output_key="text"
  ),
  document_variable_name="text"
)

这个时候,出现了 2 2 2个text!是不是意味着这两个text必须要对应起来呢?

还记得上文提到的TransformChain吗?一样的,output_key实际上仅代表输出时候的键值,而document_variable_name代表prompt模板中需要植入文本的变量,前者是输出的变量,后者是输入的变量,这两者在意义上是完全不一致的。

使用StuffDocumentsChain包装了prompt后,就需要交付大模型并生成解答了。

目前,使用的方法是run方法:

python 复制代码
response = stf_chain.run(
  human_input=message,
  chat_history="",
  input_documents=load_documents("http://120.26.106.143:5000/disease")
)

在这里,必要的字段包括三个部分,分别是hunman_inputchat_historyinput_documents。这三个部分都比较好理解,分别是用户最新一次输入、聊天记录以及输入文件内容。

其中,human_inputinput_documents就是原样输入,不需要额外处理。

chat_history需要一段数组,说明以往的交互历史。如果明码传输的话,需要这么做:

python 复制代码
chat_history = [
  {"role": "user", "content": "我想了解一下糖尿病的症状。"},
  {"role": "assistant", "content": "糖尿病的主要症状包括频尿、口渴、饥饿感增加及体重下降等。"},
  {"role": "user", "content": "那它的治疗方法呢?"},
  {"role": "assistant", "content": "治疗方法通常涉及饮食控制、规律运动、监测血糖以及可能需要的药物治疗。"}
]

如果你使用的是langchainConversationBufferMemory,并且利用streamlitsession_state作为全局变量存储,那么可以直接利用起来:

python 复制代码
# 初始化会话状态
if "messages" not in st.session_state:
  st.session_state.messages = [{
    "role": "assistant",
    "content": "你好,我是医学千问机器人GraphAcademy。有什么需要帮忙的吗?"
  }]
# 获取交互历史 - 针对多线程场景的懒加载提供`if`容错
chat_history = st.session_state.messages if st.session_state.messages else ""

ConversationBufferMemory

ConversationBufferMemorylangchain提供的一个会话记忆类,用于存储会话历史。ConversationBufferMemory继承自BaseChatMemory类,而BaseChatMemory中有一个chat_memory变量,其类型通过Field方法的default_factory参数指定了默认类型InMemoryChatMessageHistory。看来找到到答案了。继续追溯,发现InMemoryChatMessageHistory拥有成员变量messages,同样通过Field方法的default_factory参数指定了默认类型为list

到头来,ConversationBufferMemory能够存储历史对话信息,究其根源就是因为它本身就是一个列表。也正是上文提到的,使用列表逐行存储历史信息。而追加存储信息也是触发了InMemoryChatMessageHistory类中的add_message方法。

那么,这个数组里面能够存储什么东西呢?按照ConversationBufferMemoryget_buffer_string方法的定义,我们可以看到这么一些内容:

python 复制代码
def get_buffer_string(
    messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
  """Convert a sequence of Messages to strings and concatenate them into one string.

  Args:
      messages: Messages to be converted to strings.
      human_prefix: The prefix to prepend to contents of HumanMessages.
      ai_prefix: THe prefix to prepend to contents of AIMessages.

  Returns:
      A single string concatenation of all input messages.

  Example:
      .. code-block:: python

          from langchain_core import AIMessage, HumanMessage

          messages = [
              HumanMessage(content="Hi, how are you?"),
              AIMessage(content="Good, how are you?"),
          ]
          get_buffer_string(messages)
          # -> "Human: Hi, how are you?\nAI: Good, how are you?"
  """
  string_messages = []
  for m in messages:
    if isinstance(m, HumanMessage):
      role = human_prefix
    elif isinstance(m, AIMessage):
      role = ai_prefix
    elif isinstance(m, SystemMessage):
      role = "System"
    elif isinstance(m, FunctionMessage):
      role = "Function"
    elif isinstance(m, ToolMessage):
      role = "Tool"
    elif isinstance(m, ChatMessage):
      role = m.role
    else:
      raise ValueError(f"Got unsupported message type: {m}")
    message = f"{role}: {m.content}"
    if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
      message += f"{m.additional_kwargs['function_call']}"
    string_messages.append(message)

  return "\n".join(string_messages)

意思就是,如果ConversationBufferMemory如果在某一时刻需要将存储的内容输出出来,就会把存储的HumanMessageAIMessageSystemMessageFunctionMessageToolMessage以及ChatMessage中的角色提取出来,然后拼接在一起,而其他类型就会报错说类型不受支持。

所以,ConversationBufferMemory也就只能够存储HumanMessageAIMessageSystemMessageFunctionMessageToolMessage以及ChatMessage在内的一条或者多条信息。这些也都是langchain为我们提供的消息接口,用于封装各种各样的消息。

这些消息最终也将通过memory参数传入LLMChain中,在后续植入prompt知识库的时候,也会进一步通过chat_history输入StuffDocumentsChain中。

其实,比较神奇的是,虽然比较推荐传入ConversationBufferMemory,其实无论是LLMChainmemory还是StuffDocumentsChainchat_history,直接传入数组也都是一点问题没有的。因为本质上都是数组,只需要保证里面的内容是使用langchain所提供的各种Message对象即可。

相关推荐
豌豆花下猫几秒前
Python 潮流周刊#78:async/await 是糟糕的设计(摘要)
后端·python·ai
只因在人海中多看了你一眼3 分钟前
python语言基础
开发语言·python
小技与小术11 分钟前
数据结构之树与二叉树
开发语言·数据结构·python
hummhumm37 分钟前
第 25 章 - Golang 项目结构
java·开发语言·前端·后端·python·elasticsearch·golang
杜小满41 分钟前
周志华深度森林deep forest(deep-forest)最新可安装教程,仅需在pycharm中完成,超简单安装教程
python·随机森林·pycharm·集成学习
databook2 小时前
『玩转Streamlit』--布局与容器组件
python·机器学习·数据分析
nuclear20112 小时前
使用Python 在Excel中创建和取消数据分组 - 详解
python·excel数据分组·创建excel分组·excel分类汇总·excel嵌套分组·excel大纲级别·取消excel分组
Lucky小小吴3 小时前
有关django、python版本、sqlite3版本冲突问题
python·django·sqlite
GIS 数据栈3 小时前
每日一书 《基于ArcGIS的Python编程秘笈》
开发语言·python·arcgis
爱分享的码瑞哥3 小时前
Python爬虫中的IP封禁问题及其解决方案
爬虫·python·tcp/ip