ChatGLM3 langchain_demo 代码解析

ChatGLM3 langchain_demo 代码解析

  • [0. 背景](#0. 背景)
  • [1. 项目代码结构](#1. 项目代码结构)
  • [2. 代码解析](#2. 代码解析)

0. 背景

学习 ChatGLM3 的项目内容,过程中使用 AI 代码工具,对代码进行解释,帮助自己快速理解代码。这篇文章记录 ChatGLM3 langchain_demo 的代码解析内容。

1. 项目代码结构

2. 代码解析

2-1. utils.py

import os
import yaml


def tool_config_from_file(tool_name, directory="Tool/"):
    """search tool yaml and return json format"""
    for filename in os.listdir(directory):
        if filename.endswith('.yaml') and tool_name in filename:
            file_path = os.path.join(directory, filename)
            with open(file_path, encoding='utf-8') as f:
                return yaml.safe_load(f)
    return None

这段代码定义了一个函数 tool_config_from_file,用于从文件中加载工具的配置信息。

该函数接受两个参数:tool_name 表示要加载的工具名称,directory 表示存储工具配置文件的目录,默认为 "Tool/"。

在函数中,首先使用 os.listdir 函数获取指定目录下的所有文件名。然后,通过遍历文件名列表,找到以 ".yaml" 结尾且包含指定工具名称的文件。如果找到了匹配的文件,就构造文件的完整路径,并使用 open 函数打开文件。接着,使用 yaml.safe_load 函数加载文件内容,并将其转换为 JSON 格式的数据返回。

如果遍历完所有文件后仍未找到匹配的工具配置文件,则返回 None。

总体而言,这段代码定义了一个函数 tool_config_from_file,用于根据工具名称和文件目录获取工具的配置信息,并将其转换为 JSON 格式的数据返回。

2-2. ChatGLM3.py

import json
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Optional
from utils import tool_config_from_file

这段代码导入了一些模块和函数,并定义了一些类型注解。

首先,导入了 json 模块,用于处理 JSON 数据。 然后,导入了 LLM 类和 AutoTokenizer、AutoModel、AutoConfig 类,这些来自 langchain.llms.base 和 transformers 模块,用于构建和配置语言模型。 接下来,导入了 List 和 Optional 类型,用于类型注解。 最后,导入了 tool_config_from_file 函数,该函数来自 utils 模块,用于加载工具的配置信息。

总体而言,这段代码导入了所需的模块、类和函数,以及定义了一些类型注解。

class ChatGLM3(LLM):
    max_token: int = 8192
    do_sample: bool = False
    temperature: float = 0.8
    top_p = 0.8
    tokenizer: object = None
    model: object = None
    history: List = []
    tool_names: List = []
    has_search: bool = False

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3"

    def load_model(self, model_name_or_path=None):
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        ).half().cuda()

    def _tool_history(self, prompt: str):
        ans = []
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")

        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names
        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                ValueError(
                    f"Tool {tool} config not found! It's description is {tool_prompts[i]}"
                )

        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        query = f"""{prompt.split("Human: ")[-1].strip()}"""
        return ans, query

    def _extract_observation(self, prompt: str):
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return

    def _extract_tool(self):
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action: 

{json.dumps(action_json, ensure_ascii=False)}

""" 复制代码
        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action: 

{json.dumps(final_answer_json, ensure_ascii=False)}

""" 复制代码
    def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
        print("======")
        print(prompt)
        print("======")
        if not self.has_search:
            self.history, query = self._tool_history(prompt)
        else:
            self._extract_observation(prompt)
            query = ""
        # print("======")
        # print(history)
        # print("======")
        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response

这段代码定义了一个名为 ChatGLM3 的类,该类继承自 LLM 类。

以下是这个类的成员变量和方法的详细解析:

  • max_token: int = 8192:最大令牌数,默认为 8192。

  • do_sample: bool = False:是否进行采样,默认为 False。

  • temperature: float = 0.8:采样温度,默认为 0.8。

  • top_p = 0.8:top-p 采样的概率阈值,默认为 0.8。

  • tokenizer: object = None:tokenizer 对象,默认为 None。

  • model: object = None:模型对象,默认为 None。

  • history: List = []:对话历史记录列表,默认为空列表。

  • tool_names: List = []:工具名称列表,默认为空列表。

  • has_search: bool = False:是否进行工具搜索的标志位,默认为 False。

  • init(self):类的构造函数,调用父类 LLM 的构造函数。

  • _llm_type(self) -> str:类属性,返回字符串 "ChatGLM3"。

  • load_model(self, model_name_or_path=None):加载模型的方法。根据模型的名称或路径,使用 AutoConfig、AutoTokenizer 和 AutoModel 类从预训练模型中加载配置、tokenizer 和模型,并将模型转化为半精度浮点数并放置在 CUDA 设备上。

  • _tool_history(self, prompt: str):提取工具历史记录的方法。根据提示字符串,从中提取出工具名称和工具配置信息,并将其存储到 tool_names 和 tools_json 中,最后将结果作为字典添加到 ans 列表中,并返回 ans 列表和查询字符串。

  • _extract_observation(self, prompt: str):提取观察信息的方法。从提示字符串中提取观察信息,将其添加到历史记录列表 history 中。

  • _extract_tool(self):提取工具信息的方法。根据最后一条历史记录的元数据和内容,判断是否存在工具调用。如果存在,遍历工具名称列表,如果某个工具名称出现在元数据中,提取出输入参数,并构造一个动作 JSON 对象。同时,将 has_search 设置为 True。如果不存在工具调用,构造一个最终回答的动作 JSON 对象,并将 has_search 设置为 False。最后,返回包含动作 JSON 的字符串。

  • _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):执行对话的方法。根据是否进行工具搜索的标志位,调用 _tool_history 方法或 _extract_observation 方法来获取查询字符串和更新历史记录。然后,使用模型的 chat 方法进行对话生成,并将结果传递给 _extract_tool 方法,提取工具信息。最后,将提示和响应添加到历史记录列表中,并返回响应。

总体而言,这段代码定义了一个名为ChatGLM3 的类,该类继承自 LLM 类。它包含了一些成员变量和方法,用于加载模型、提取工具历史记录、提取观察信息和执行对话。

在 load_model 方法中,模型的配置、tokenizer 和模型本身被加载,并存储在 tokenizer 和 model 成员变量中。

_tool_history 方法用于从提示字符串中提取工具历史记录。它首先将提示字符串按特定的分隔符切分,然后从切分结果中提取工具名称和相应的工具配置信息。这些信息被存储在 tool_names 和 tools_json 成员变量中,并作为字典添加到 ans 列表中。最后,返回 ans 列表和查询字符串。

_extract_observation 方法用于从提示字符串中提取观察信息。它将观察信息存储在 history 成员变量中。

_extract_tool 方法用于提取工具信息。它检查最后一条历史记录的元数据和内容,判断是否存在工具调用。如果存在工具调用,它将提取出工具名称和输入参数,并构造一个包含动作和输入参数的 JSON 对象。如果不存在工具调用,它将构造一个包含最终回答动作的 JSON 对象。最后,它将返回包含动作 JSON 的字符串。

_call 方法用于执行对话。它根据 has_search 标志位,选择调用 _tool_history 方法或 _extract_observation 方法,获取查询字符串和更新历史记录。然后,它使用模型的 chat 方法生成响应,并将结果传递给 _extract_tool 方法,提取工具信息。最后,它将提示和响应添加到历史记录列表中,并返回响应。

总体而言,这段代码定义了一个用于对话生成的类 ChatGLM3,它继承自 LLM 类,并提供了加载模型、提取工具历史记录、提取观察信息和执行对话的功能。

    def _tool_history(self, prompt: str):
        ans = []
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")

        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names
        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                ValueError(
                    f"Tool {tool} config not found! It's description is {tool_prompts[i]}"
                )

        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        query = f"""{prompt.split("Human: ")[-1].strip()}"""
        return ans, query

这段代码定义了一个名为 _tool_history 的方法,它接受一个参数 prompt,该参数是一个字符串。

以下是这段代码的详细解析:

  • ans = []:创建一个空列表 ans,用于存储返回结果。

  • tool_prompts = prompt.split("You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n"):从 prompt 字符串中提取工具提示信息。它使用特定的分隔符将 prompt 字符串进行切分,提取包含工具提示信息的部分,并将其存储在 tool_prompts 列表中。

  • tool_names = [tool.split(":")[0] for tool in tool_prompts]:从 tool_prompts 列表中提取工具名称。它使用冒号进行切分,并将每个工具的名称存储在 tool_names 列表中。

  • self.tool_names = tool_names:将 tool_names 列表赋值给类的成员变量 tool_names。

  • tools_json = []:创建一个空列表 tools_json,用于存储工具的配置信息。

  • for i, tool in enumerate(tool_names)::对 tool_names 列表进行遍历,循环变量 tool 表示当前遍历的工具名称,循环变量 i 表示当前遍历的索引。

    • tool_config = tool_config_from_file(tool):调用 tool_config_from_file 函数,根据工具名称获取工具的配置信息,并将结果存储在 tool_config 变量中。

    • if tool_config::检查 tool_config 是否存在。

      • tools_json.append(tool_config):如果 tool_config 存在,则将其添加到 tools_json 列表中。
    • else::否则,如果 tool_config 不存在。

      • ValueError(...):抛出一个 ValueError 异常,提示工具配置未找到,并包含工具的描述信息。
  • ans.append({...}):将一个字典对象添加到 ans 列表中。字典包含以下键值对:

    • "role": "system":角色为 "system",表示系统的角色。
    • "content": "Answer the following questions as best as you can. You have access to the following tools:":内容为提示信息。
    • "tools": tools_json:工具列表为 tools_json。
  • query = f"""{prompt.split("Human: ")[-1].strip()}""":从 prompt 字符串中提取查询字符串。它首先根据特定的分隔符将 prompt 字符串进行切分,然后从切分结果中选择最后一个元素,并去除首尾空格。

  • return ans, query:返回 ans 列表和 query 字符串作为结果。

总体而言,这段代码定义了一个方法 _tool_history,它从 prompt 字符串中提取工具提示信息,并构造一个包含提示信息、工具列表和查询字符串的字典对象。

    def _extract_observation(self, prompt: str):
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return

这段代码定义了一个名为 _extract_observation 的方法,它接受一个参数 prompt,该参数是一个字符串。

以下是这段代码的详细解析:

  • return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]:从 prompt 字符串中提取观察信息。它首先根据特定的分隔符将 prompt 字符串进行切分,然后选择切分结果中的最后一个元素,并再次根据特定的分隔符将其进行切分,最后选择切分结果中的第一个元素。

  • self.history.append({...}):将一个字典对象添加到类的成员变量 history 列表中。字典包含以下键值对:

    • "role": "observation":角色为 "observation",表示观察的角色。
    • "content": return_json:内容为观察信息。
  • return:没有指定返回值,默认返回 None。

总体而言,这段代码定义了一个方法 _extract_observation,它从 prompt 字符串中提取观察信息,并将观察信息添加到类的历史记录列表 history 中。

    def _extract_tool(self):
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action: 

{json.dumps(action_json, ensure_ascii=False)}

""" 复制代码
        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action: 

{json.dumps(final_answer_json, ensure_ascii=False)}

""" 复制代码
    def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
        print("======")
        print(prompt)
        print("======")
        if not self.has_search:
            self.history, query = self._tool_history(prompt)
        else:
            self._extract_observation(prompt)
            query = ""
        # print("======")
        # print(history)
        # print("======")
        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response

这段代码包含两个方法 _extract_tool 和 _call。

以下是这段代码的详细解析:

_extract_tool 方法:

  • if len(self.history[-1]["metadata"]) > 0::检查历史记录的最后一项是否包含元数据。如果包含元数据,则执行以下操作。

    • metadata = self.history[-1]["metadata"]:将元数据存储在变量 metadata 中。

    • content = self.history[-1]["content"]:将历史记录的最后一项的内容存储在变量 content 中。

    • if "tool_call" in content::检查内容中是否包含字符串 "tool_call"。如果包含,则执行以下操作。

      • for tool in self.tool_names::对工具名称列表进行遍历,循环变量 tool 表示当前遍历的工具名称。

        • if tool in metadata::检查工具名称是否存在于元数据中。如果存在,则执行以下操作。

          • input_para = content.split("='")[-1].split("'")[0]:从内容中提取输入参数。它首先根据特定的分隔符将内容进行切分,然后选择切分结果中的倒数第二个元素,并再次根据特定的分隔符将其进行切分,最后选择切分结果中的第一个元素。

          • action_json = {...}:构造一个包含动作和输入参数的字典对象。字典包含以下键值对:

            • "action": tool:动作为工具名称。
            • "action_input": input_para:动作输入参数为提取的输入参数。
          • self.has_search = True:将成员变量 has_search 设置为 True,表示存在搜索。

        • return f"...":返回一个包含动作 JSON 的字符串。字符串使用 Markdown 格式进行格式化,显示动作 JSON。

  • final_answer_json = {...}:构造一个包含最终回答动作的字典对象。字典包含以下键值对:

    • "action": "Final Answer":动作为 "Final Answer"。
    • "action_input": self.history[-1]["content"]:动作输入参数为历史记录的最后一项的内容。
  • self.has_search = False:将成员变量 has_search 设置为 False,表示不存在搜索。

  • return f"...":返回一个包含动作 JSON 的字符串。字符串使用 Markdown 格式进行格式化,显示动作 JSON。

_call 方法:

  • print("======"):打印分隔线。

  • print(prompt):打印传入的参数 prompt。

  • print("======"):打印分隔线。

  • if not self.has_search::检查成员变量 has_search 是否为 False。如果为 False,表示不存在搜索,执行以下操作。

    • self.history, query = self._tool_history(prompt):调用 _tool_history 方法,获取历史记录和查询字符串。将返回的历史记录赋值给类的成员变量 history,将返回的查询字符串赋值给变量 query。
  • else::如果存在搜索

    • self._extract_observation(prompt):调用 _extract_observation 方法,从 prompt 中提取观察信息,并将其添加到类的历史记录列表 history 中。

    • query = "":将查询字符串设置为空字符串。

  • _, self.history = self.model.chat(...):调用 model.chat 方法进行对话。它使用模型、分词器和其他参数来生成对话的响应。返回的结果包含生成的响应和更新后的历史记录。使用 _ 忽略生成的响应,将更新后的历史记录赋值给类的成员变量 history。

  • response = self._extract_tool():调用 _extract_tool 方法,从更新后的历史记录中提取工具动作并生成响应。

  • history.append((prompt, response)):将元组 (prompt, response) 添加到 history 列表中,用于记录对话历史。

  • return response:返回生成的响应作为结果。

总体而言,这段代码定义了两个方法 _extract_tool 和 _call。_extract_tool 方法用于从历史记录中提取工具动作,生成相应的响应。_call 方法用于处理对话流程,包括提取观察信息、生成对话响应和记录对话历史。

2-3. Tool/Calculator.py

import abc
import math
from typing import Any

from langchain.tools import BaseTool


class Calculator(BaseTool, abc.ABC):
    name = "Calculator"
    description = "Useful for when you need to answer questions about math"

    def __init__(self):
        super().__init__()

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        # 用例中没有用到 arun 不予具体实现
        pass


    def _run(self, para: str) -> str:
        para = para.replace("^", "**")
        if "sqrt" in para:
            para = para.replace("sqrt", "math.sqrt")
        elif "log" in para:
            para = para.replace("log", "math.log")
        return eval(para)


if __name__ == "__main__":
    calculator_tool = Calculator()
    result = calculator_tool.run("sqrt(2) + 3")
    print(result)

这段代码定义了一个名为 Calculator 的类,它继承自 BaseTool 类和 abc.ABC 抽象基类。Calculator 类是一个计算器工具,用于执行数学计算操作。

以下是这段代码的详细解析:

  • import abc:导入 abc 模块,用于定义抽象基类。

  • import math:导入 math 模块,用于执行数学计算操作。

  • from typing import Any:从 typing 模块导入 Any 类型,用于函数参数和返回值的类型注解。

  • from langchain.tools import BaseTool:从 langchain.tools 模块导入 BaseTool 类,用作 Calculator 类的父类。

  • class Calculator(BaseTool, abc.ABC)::定义了一个名为 Calculator 的类,它继承自 BaseTool 类和 abc.ABC 抽象基类。Calculator 类表示一个计算器工具。

    • name = "Calculator":类属性 name 被设置为字符串 "Calculator",表示工具的名称。

    • description = "Useful for when you need to answer questions about math":类属性 description 被设置为字符串 "Useful for when you need to answer questions about math",表示工具的描述信息。

    • def init (self)::构造函数,用于初始化 Calculator 类的实例。它调用父类的构造函数 super().init() 来完成初始化。

    • async def _arun(self, *args: Any, **kwargs: Any) -> Any::定义了一个异步方法 _arun,它接受任意数量的位置参数和关键字参数,并返回任意类型的值。在这段代码中,_arun 方法没有具体的实现,因为在示例中没有使用到它。

    • def _run(self, para: str) -> str::定义了一个方法 _run,它接受一个字符串参数 para,并返回一个字符串。在这个方法中,它首先对参数 para 进行替换操作,将字符串中的 "^" 替换为 "**"。然后,它检查参数 para 中是否包含 "sqrt",如果是,则将 "sqrt" 替换为 "math.sqrt";如果参数 para 中包含 "log",则将 "log" 替换为 "math.log"。最后,使用 eval 函数来执行参数 para 的计算操作,并返回计算结果。

请注意,这段代码中的 BaseTool 类没有给出具体的定义,因此我无法提供关于它的更多详细信息。如果您能提供 BaseTool 类的定义或相关代码,我将能够给出更准确的解释。

2-4. Tool/Weather.py

import os
from typing import Any

import requests
from langchain.tools import BaseTool


class Weather(BaseTool):
    name = "weather"
    description = "Use for searching weather at a specific location"

    def __init__(self):
        super().__init__()

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        # 用例中没有用到 arun 不予具体实现
        pass

    def get_weather(self, location):
        api_key = os.environ["SENIVERSE_KEY"]
        url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
        response = requests.get(url)
        if response.status_code == 200:
            data = response.json()
            weather = {
                "temperature": data["results"][0]["now"]["temperature"],
                "description": data["results"][0]["now"]["text"],
            }
            return weather
        else:
            raise Exception(
                f"Failed to retrieve weather: {response.status_code}")

    def _run(self, para: str) -> str:
        return self.get_weather(para)


if __name__ == "__main__":
    weather_tool = Weather()
    weather_info = weather_tool.run("成都")
    print(weather_info)

这段代码定义了一个名为 Weather 的类,它继承自 BaseTool 类。Weather 类是一个天气工具,用于查询特定位置的天气信息。

以下是代码的详细解析:

  • import os:导入 os 模块,用于访问操作系统的功能,例如环境变量。

  • from typing import Any:从 typing 模块导入 Any 类型,用于函数参数和返回值的类型注解。

  • import requests:导入 requests 模块,用于发送 HTTP 请求。

  • from langchain.tools import BaseTool:从 langchain.tools 模块导入 BaseTool 类,用作 Weather 类的父类。

  • class Weather(BaseTool)::定义了一个名为 Weather 的类,它继承自 BaseTool 类。Weather 类表示一个天气工具。

    • name = "weather":类属性 name 被设置为字符串 "weather",表示工具的名称。

    • description = "Use for searching weather at a specific location":类属性 description 被设置为字符串 "Use for searching weather at a specific location",表示工具的描述信息。

    • def init (self)::构造函数,用于初始化 Weather 类的实例。它调用父类的构造函数 super().init() 来完成初始化。

    • async def _arun(self, *args: Any, **kwargs: Any) -> Any::定义了一个异步方法 _arun,它接受任意数量的位置参数和关键字参数,并返回任意类型的值。在这段代码中,_arun 方法没有具体的实现,因为在示例中没有使用到它。

    • def get_weather(self, location)::定义了一个方法 get_weather,它接受一个参数 location,表示要查询的位置。在这个方法中,它使用 os.environ 字典获取名为 "SENIVERSE_KEY" 的环境变量作为 API 密钥。然后,它构建一个 URL,使用 requests.get 方法发送 GET 请求到该 URL,并获取响应结果。如果响应的状态码为 200(表示请求成功),则解析响应的 JSON 数据,并返回包含温度和天气描述的字典。如果响应的状态码不为 200,则抛出异常。

    • def _run(self, para: str) -> str::定义了一个方法 _run,它接受一个字符串参数 para,并返回一个字符串。在这个方法中,它调用 get_weather 方法,传入参数 para(表示要查询的位置),并返回查询到的天气信息。

2-5. main.py

from typing import List
from ChatGLM3 import ChatGLM3

from langchain.agents import load_tools
from Tool.Weather import Weather
from Tool.Calculator import Calculator

from langchain.agents import initialize_agent
from langchain.agents import AgentType


def run_tool(tools, llm, prompt_chain: List[str]):
    loaded_tolls = []
    for tool in tools:
        if isinstance(tool, str):
            loaded_tolls.append(load_tools([tool], llm=llm)[0])
        else:
            loaded_tolls.append(tool)
    agent = initialize_agent(
        loaded_tolls, llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True
    )
    for prompt in prompt_chain:
        agent.run(prompt)


if __name__ == "__main__":
    model_path = "THUDM/chatglm3-6b"
    llm = ChatGLM3()
    llm.load_model(model_name_or_path=model_path)

    # arxiv: 单个工具调用示例 1
    run_tool(["arxiv"], llm, [
        "帮我查询GLM-130B相关工作"
    ])

    # weather: 单个工具调用示例 2
    # run_tool([Weather()], llm, [
    #     "今天北京天气怎么样?",
    #     "What's the weather like in Shanghai today",
    # ])

    # calculator: 单个工具调用示例 3
    run_tool([Calculator()], llm, [
        "12345679乘以54等于多少?",
        "3.14的3.14次方等于多少?",
        "根号2加上根号三等于多少?",
    ]),

    # arxiv + weather + calculator: 多个工具结合调用
    # run_tool([Calculator(), "arxiv", Weather()], llm, [
    #     "帮我检索GLM-130B相关论文",
    #     "今天北京天气怎么样?",
    #     "根号3减去根号二再加上4等于多少?",
    # ])

这段代码是一个示例程序,演示了如何使用 langchain 库来调用不同的工具进行自然语言处理。

以下是代码的详细解析:

  • from typing import List:从 typing 模块导入 List 类型,用于函数参数的类型注解。

  • from ChatGLM3 import ChatGLM3:从 ChatGLM3 模块导入 ChatGLM3 类,用于创建语言模型。

  • from langchain.agents import load_tools:从 langchain.agents 模块导入 load_tools 函数,用于加载工具。

  • from Tool.Weather import Weather:从 Tool.Weather 模块导入 Weather 类,用于天气查询工具。

  • from Tool.Calculator import Calculator:从 Tool.Calculator 模块导入 Calculator 类,用于计算器工具。

  • from langchain.agents import initialize_agent:从 langchain.agents 模块导入 initialize_agent 函数,用于初始化代理。

  • from langchain.agents import AgentType:从 langchain.agents 模块导入 AgentType 枚举,用于指定代理类型。

  • def run_tool(tools, llm, prompt_chain: List[str])::定义了一个函数 run_tool,它接受三个参数:tools(要加载的工具列表),llm(语言模型实例),prompt_chain(要运行的提示列表)。

    • loaded_tolls = []:创建一个空列表 loaded_tolls,用于存储加载后的工具。
    • for tool in tools::对于 tools 列表中的每个工具:
      • if isinstance(tool, str)::如果工具是字符串类型,表示需要加载的是一个工具模块:
        • loaded_tolls.append(load_tools([tool], llm=llm)[0]):加载指定的工具模块,并将返回的工具实例添加到 loaded_tolls 列表中。
    • else::如果工具不是字符串类型,表示已经是一个工具实例:
      • loaded_tolls.append(tool):直接将工具实例添加到 loaded_tolls 列表中。
  • agent = initialize_agent(loaded_tolls, llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, handle_parsing_errors=True):使用加载后的工具和语言模型实例初始化代理。AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION 表示采用结构化对话的零次推理反应描述的代理类型。

    • for prompt in prompt_chain::对于 prompt_chain 列表中的每个提示:
    • agent.run(prompt):使用代理运行提示。
  • if name == "main"::如果当前模块被直接执行(而不是被导入为模块):

    • model_path = "THUDM/chatglm3-6b":设置语言模型的路径。
    • llm = ChatGLM3():创建 ChatGLM3 类的实例,用于创建语言模型。
  • llm.load_model(model_name_or_path=model_path):加载指定路径的语言模型。

  • run_tool(["arxiv"], llm, ["帮我查询GLM-130B相关工作"]):调用 run_tool 函数,使用 arxiv 工具和语言模型实例 llm,并传入一个提示列表 ["帮我查询GLM-130B相关工作"],以执行相关工作的查询。

  • run_tool([Calculator()], llm, ["12345679乘以54等于多少?", "3.14的3.14次方等于多少?", "根号2加上根号三等于多少?"]):调用 run_tool 函数,使用 Calculator 工具和语言模型实例 llm,并传入一个提示列表,以执行计算器工具的计算。

注释掉的代码块是其他工具的调用示例,包括天气查询工具和多个工具的结合调用。

这段代码演示了如何使用 langchain 库,加载不同的工具和语言模型,然后通过代理来运行自然语言提示,以执行不同的任务,例如查询相关工作、查询天气和进行计算等。

完结!

相关推荐
waiting不是违停1 天前
LangChain Ollama实战文献检索助手(二)少样本提示FewShotPromptTemplate示例选择器
langchain·llm·ollama
Y24834908911 天前
05LangChain实战课 - 提示工程与FewShotPromptTemplate的应用
人工智能·langchain
科研小达人2 天前
Langchain调用模型使用FAISS
python·chatgpt·langchain·faiss
小陈phd4 天前
大语言模型及LangChain介绍
人工智能·语言模型·langchain
写程序的小火箭5 天前
如何评估一个RAG系统(RAGas评测框架)-下篇
人工智能·gpt·语言模型·chatgpt·langchain
Stitch .5 天前
小北的字节跳动青训营与 LangChain 实战课:探索 AI 技术的新边界(持续更新中~~~)
人工智能·python·gpt·ai·语言模型·chatgpt·langchain
黑金IT5 天前
掌握AI Prompt的艺术:如何有效引导智能助手
人工智能·langchain·prompt·ai编程
科研小达人6 天前
langchain调用chatgpt对文本进行编码
服务器·langchain
智兔唯新6 天前
【AIGC】COT思维链:让AI学会拆解问题,像人一样思考
人工智能·python·langchain·prompt·aigc
wyh_1116 天前
windows下xinference无法加载本地大模型问题解决
langchain·xinference