ChatGLM3 langchain_demo 代码解析
- [0. 背景](#0. 背景)
- [1. 项目代码结构](#1. 项目代码结构)
- [2. 代码解析](#2. 代码解析)
-
- [2-1. utils.py](#2-1. utils.py)
- [2-2. ChatGLM3.py](#2-2. ChatGLM3.py)
- [2-3. Tool/Calculator.py](#2-3. Tool/Calculator.py)
- [2-4. Tool/Weather.py](#2-4. Tool/Weather.py)
- [2-5. main.py](#2-5. main.py)
0. 背景
学习 ChatGLM3 的项目内容,过程中使用 AI 代码工具,对代码进行解释,帮助自己快速理解代码。这篇文章记录 ChatGLM3 langchain_demo 的代码解析内容。
1. 项目代码结构
2. 代码解析
2-1. utils.py
import os
import yaml
def tool_config_from_file(tool_name, directory="Tool/"):
"""search tool yaml and return json format"""
for filename in os.listdir(directory):
if filename.endswith('.yaml') and tool_name in filename:
file_path = os.path.join(directory, filename)
with open(file_path, encoding='utf-8') as f:
return yaml.safe_load(f)
return None
这段代码定义了一个函数 tool_config_from_file,用于从文件中加载工具的配置信息。
该函数接受两个参数:tool_name 表示要加载的工具名称,directory 表示存储工具配置文件的目录,默认为 "Tool/"。
在函数中,首先使用 os.listdir 函数获取指定目录下的所有文件名。然后,通过遍历文件名列表,找到以 ".yaml" 结尾且包含指定工具名称的文件。如果找到了匹配的文件,就构造文件的完整路径,并使用 open 函数打开文件。接着,使用 yaml.safe_load 函数加载文件内容,并将其转换为 JSON 格式的数据返回。
如果遍历完所有文件后仍未找到匹配的工具配置文件,则返回 None。
总体而言,这段代码定义了一个函数 tool_config_from_file,用于根据工具名称和文件目录获取工具的配置信息,并将其转换为 JSON 格式的数据返回。
2-2. ChatGLM3.py
import json
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Optional
from utils import tool_config_from_file
这段代码导入了一些模块和函数,并定义了一些类型注解。
首先,导入了 json 模块,用于处理 JSON 数据。 然后,导入了 LLM 类和 AutoTokenizer、AutoModel、AutoConfig 类,这些来自 langchain.llms.base 和 transformers 模块,用于构建和配置语言模型。 接下来,导入了 List 和 Optional 类型,用于类型注解。 最后,导入了 tool_config_from_file 函数,该函数来自 utils 模块,用于加载工具的配置信息。
总体而言,这段代码导入了所需的模块、类和函数,以及定义了一些类型注解。
class ChatGLM3(LLM):
max_token: int = 8192
do_sample: bool = False
temperature: float = 0.8
top_p = 0.8
tokenizer: object = None
model: object = None
history: List = []
tool_names: List = []
has_search: bool = False
def __init__(self):
super().__init__()
@property
def _llm_type(self) -> str:
return "ChatGLM3"
def load_model(self, model_name_or_path=None):
model_config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=True
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True
)
self.model = AutoModel.from_pretrained(
model_name_or_path, config=model_config, trust_remote_code=True
).half().cuda()
def _tool_history(self, prompt: str):
ans = []
tool_prompts = prompt.split(
"You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")
tool_names = [tool.split(":")[0] for tool in tool_prompts]
self.tool_names = tool_names
tools_json = []
for i, tool in enumerate(tool_names):
tool_config = tool_config_from_file(tool)
if tool_config:
tools_json.append(tool_config)
else:
ValueError(
f"Tool {tool} config not found! It's description is {tool_prompts[i]}"
)
ans.append({
"role": "system",
"content": "Answer the following questions as best as you can. You have access to the following tools:",
"tools": tools_json
})
query = f"""{prompt.split("Human: ")[-1].strip()}"""
return ans, query
def _extract_observation(self, prompt: str):
return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
self.history.append({
"role": "observation",
"content": return_json
})
return
def _extract_tool(self):
if len(self.history[-1]["metadata"]) > 0:
metadata = self.history[-1]["metadata"]
content = self.history[-1]["content"]
if "tool_call" in content:
for tool in self.tool_names:
if tool in metadata:
input_para = content.split("='")[-1].split("'")[0]
action_json = {
"action": tool,
"action_input": input_para
}
self.has_search = True
return f"""
Action:
{json.dumps(action_json, ensure_ascii=False)}
"""
final_answer_json = {
"action": "Final Answer",
"action_input": self.history[-1]["content"]
}
self.has_search = False
return f"""
Action:
{json.dumps(final_answer_json, ensure_ascii=False)}
"""
def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
print("======")
print(prompt)
print("======")
if not self.has_search:
self.history, query = self._tool_history(prompt)
else:
self._extract_observation(prompt)
query = ""
# print("======")
# print(history)
# print("======")
_, self.history = self.model.chat(
self.tokenizer,
query,
history=self.history,
do_sample=self.do_sample,
max_length=self.max_token,
temperature=self.temperature,
)
response = self._extract_tool()
history.append((prompt, response))
return response
这段代码定义了一个名为 ChatGLM3 的类,该类继承自 LLM 类。
以下是这个类的成员变量和方法的详细解析:
-
max_token: int = 8192:最大令牌数,默认为 8192。
-
do_sample: bool = False:是否进行采样,默认为 False。
-
temperature: float = 0.8:采样温度,默认为 0.8。
-
top_p = 0.8:top-p 采样的概率阈值,默认为 0.8。
-
tokenizer: object = None:tokenizer 对象,默认为 None。
-
model: object = None:模型对象,默认为 None。
-
history: List = []:对话历史记录列表,默认为空列表。
-
tool_names: List = []:工具名称列表,默认为空列表。
-
has_search: bool = False:是否进行工具搜索的标志位,默认为 False。
-
init(self):类的构造函数,调用父类 LLM 的构造函数。
-
_llm_type(self) -> str:类属性,返回字符串 "ChatGLM3"。
-
load_model(self, model_name_or_path=None):加载模型的方法。根据模型的名称或路径,使用 AutoConfig、AutoTokenizer 和 AutoModel 类从预训练模型中加载配置、tokenizer 和模型,并将模型转化为半精度浮点数并放置在 CUDA 设备上。
-
_tool_history(self, prompt: str):提取工具历史记录的方法。根据提示字符串,从中提取出工具名称和工具配置信息,并将其存储到 tool_names 和 tools_json 中,最后将结果作为字典添加到 ans 列表中,并返回 ans 列表和查询字符串。
-
_extract_observation(self, prompt: str):提取观察信息的方法。从提示字符串中提取观察信息,将其添加到历史记录列表 history 中。
-
_extract_tool(self):提取工具信息的方法。根据最后一条历史记录的元数据和内容,判断是否存在工具调用。如果存在,遍历工具名称列表,如果某个工具名称出现在元数据中,提取出输入参数,并构造一个动作 JSON 对象。同时,将 has_search 设置为 True。如果不存在工具调用,构造一个最终回答的动作 JSON 对象,并将 has_search 设置为 False。最后,返回包含动作 JSON 的字符串。
-
_call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):执行对话的方法。根据是否进行工具搜索的标志位,调用 _tool_history 方法或 _extract_observation 方法来获取查询字符串和更新历史记录。然后,使用模型的 chat 方法进行对话生成,并将结果传递给 _extract_tool 方法,提取工具信息。最后,将提示和响应添加到历史记录列表中,并返回响应。
总体而言,这段代码定义了一个名为ChatGLM3 的类,该类继承自 LLM 类。它包含了一些成员变量和方法,用于加载模型、提取工具历史记录、提取观察信息和执行对话。
在 load_model 方法中,模型的配置、tokenizer 和模型本身被加载,并存储在 tokenizer 和 model 成员变量中。
_tool_history 方法用于从提示字符串中提取工具历史记录。它首先将提示字符串按特定的分隔符切分,然后从切分结果中提取工具名称和相应的工具配置信息。这些信息被存储在 tool_names 和 tools_json 成员变量中,并作为字典添加到 ans 列表中。最后,返回 ans 列表和查询字符串。
_extract_observation 方法用于从提示字符串中提取观察信息。它将观察信息存储在 history 成员变量中。
_extract_tool 方法用于提取工具信息。它检查最后一条历史记录的元数据和内容,判断是否存在工具调用。如果存在工具调用,它将提取出工具名称和输入参数,并构造一个包含动作和输入参数的 JSON 对象。如果不存在工具调用,它将构造一个包含最终回答动作的 JSON 对象。最后,它将返回包含动作 JSON 的字符串。
_call 方法用于执行对话。它根据 has_search 标志位,选择调用 _tool_history 方法或 _extract_observation 方法,获取查询字符串和更新历史记录。然后,它使用模型的 chat 方法生成响应,并将结果传递给 _extract_tool 方法,提取工具信息。最后,它将提示和响应添加到历史记录列表中,并返回响应。
总体而言,这段代码定义了一个用于对话生成的类 ChatGLM3,它继承自 LLM 类,并提供了加载模型、提取工具历史记录、提取观察信息和执行对话的功能。
def _tool_history(self, prompt: str):
ans = []
tool_prompts = prompt.split(
"You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")
tool_names = [tool.split(":")[0] for tool in tool_prompts]
self.tool_names = tool_names
tools_json = []
for i, tool in enumerate(tool_names):
tool_config = tool_config_from_file(tool)
if tool_config:
tools_json.append(tool_config)
else:
ValueError(
f"Tool {tool} config not found! It's description is {tool_prompts[i]}"
)
ans.append({
"role": "system",
"content": "Answer the following questions as best as you can. You have access to the following tools:",
"tools": tools_json
})
query = f"""{prompt.split("Human: ")[-1].strip()}"""
return ans, query
这段代码定义了一个名为 _tool_history 的方法,它接受一个参数 prompt,该参数是一个字符串。
以下是这段代码的详细解析:
-
ans = []:创建一个空列表 ans,用于存储返回结果。
-
tool_prompts = prompt.split("You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n"):从 prompt 字符串中提取工具提示信息。它使用特定的分隔符将 prompt 字符串进行切分,提取包含工具提示信息的部分,并将其存储在 tool_prompts 列表中。
-
tool_names = [tool.split(":")[0] for tool in tool_prompts]:从 tool_prompts 列表中提取工具名称。它使用冒号进行切分,并将每个工具的名称存储在 tool_names 列表中。
-
self.tool_names = tool_names:将 tool_names 列表赋值给类的成员变量 tool_names。
-
tools_json = []:创建一个空列表 tools_json,用于存储工具的配置信息。
-
for i, tool in enumerate(tool_names)::对 tool_names 列表进行遍历,循环变量 tool 表示当前遍历的工具名称,循环变量 i 表示当前遍历的索引。
-
tool_config = tool_config_from_file(tool):调用 tool_config_from_file 函数,根据工具名称获取工具的配置信息,并将结果存储在 tool_config 变量中。
-
if tool_config::检查 tool_config 是否存在。
- tools_json.append(tool_config):如果 tool_config 存在,则将其添加到 tools_json 列表中。
-
else::否则,如果 tool_config 不存在。
- ValueError(...):抛出一个 ValueError 异常,提示工具配置未找到,并包含工具的描述信息。
-
-
ans.append({...}):将一个字典对象添加到 ans 列表中。字典包含以下键值对:
- "role": "system":角色为 "system",表示系统的角色。
- "content": "Answer the following questions as best as you can. You have access to the following tools:":内容为提示信息。
- "tools": tools_json:工具列表为 tools_json。
-
query = f"""{prompt.split("Human: ")[-1].strip()}""":从 prompt 字符串中提取查询字符串。它首先根据特定的分隔符将 prompt 字符串进行切分,然后从切分结果中选择最后一个元素,并去除首尾空格。
-
return ans, query:返回 ans 列表和 query 字符串作为结果。
总体而言,这段代码定义了一个方法 _tool_history,它从 prompt 字符串中提取工具提示信息,并构造一个包含提示信息、工具列表和查询字符串的字典对象。
def _extract_observation(self, prompt: str):
return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
self.history.append({
"role": "observation",
"content": return_json
})
return
这段代码定义了一个名为 _extract_observation 的方法,它接受一个参数 prompt,该参数是一个字符串。
以下是这段代码的详细解析:
-
return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]:从 prompt 字符串中提取观察信息。它首先根据特定的分隔符将 prompt 字符串进行切分,然后选择切分结果中的最后一个元素,并再次根据特定的分隔符将其进行切分,最后选择切分结果中的第一个元素。
-
self.history.append({...}):将一个字典对象添加到类的成员变量 history 列表中。字典包含以下键值对:
- "role": "observation":角色为 "observation",表示观察的角色。
- "content": return_json:内容为观察信息。
-
return:没有指定返回值,默认返回 None。
总体而言,这段代码定义了一个方法 _extract_observation,它从 prompt 字符串中提取观察信息,并将观察信息添加到类的历史记录列表 history 中。
def _extract_tool(self):
if len(self.history[-1]["metadata"]) > 0:
metadata = self.history[-1]["metadata"]
content = self.history[-1]["content"]
if "tool_call" in content:
for tool in self.tool_names:
if tool in metadata:
input_para = content.split("='")[-1].split("'")[0]
action_json = {
"action": tool,
"action_input": input_para
}
self.has_search = True
return f"""
Action:
{json.dumps(action_json, ensure_ascii=False)}
"""
final_answer_json = {
"action": "Final Answer",
"action_input": self.history[-1]["content"]
}
self.has_search = False
return f"""
Action:
{json.dumps(final_answer_json, ensure_ascii=False)}
"""
def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
print("======")
print(prompt)
print("======")
if not self.has_search:
self.history, query = self._tool_history(prompt)
else:
self._extract_observation(prompt)
query = ""
# print("======")
# print(history)
# print("======")
_, self.history = self.model.chat(
self.tokenizer,
query,
history=self.history,
do_sample=self.do_sample,
max_length=self.max_token,
temperature=self.temperature,
)
response = self._extract_tool()
history.append((prompt, response))
return response
这段代码包含两个方法 _extract_tool 和 _call。
以下是这段代码的详细解析:
_extract_tool 方法:
-
if len(self.history[-1]["metadata"]) > 0::检查历史记录的最后一项是否包含元数据。如果包含元数据,则执行以下操作。
-
metadata = self.history[-1]["metadata"]:将元数据存储在变量 metadata 中。
-
content = self.history[-1]["content"]:将历史记录的最后一项的内容存储在变量 content 中。
-
if "tool_call" in content::检查内容中是否包含字符串 "tool_call"。如果包含,则执行以下操作。
-
for tool in self.tool_names::对工具名称列表进行遍历,循环变量 tool 表示当前遍历的工具名称。
-
if tool in metadata::检查工具名称是否存在于元数据中。如果存在,则执行以下操作。
-
input_para = content.split("='")[-1].split("'")[0]:从内容中提取输入参数。它首先根据特定的分隔符将内容进行切分,然后选择切分结果中的倒数第二个元素,并再次根据特定的分隔符将其进行切分,最后选择切分结果中的第一个元素。
-
action_json = {...}:构造一个包含动作和输入参数的字典对象。字典包含以下键值对:
- "action": tool:动作为工具名称。
- "action_input": input_para:动作输入参数为提取的输入参数。
-
self.has_search = True:将成员变量 has_search 设置为 True,表示存在搜索。
-
-
return f"...":返回一个包含动作 JSON 的字符串。字符串使用 Markdown 格式进行格式化,显示动作 JSON。
-
-
-
-
final_answer_json = {...}:构造一个包含最终回答动作的字典对象。字典包含以下键值对:
- "action": "Final Answer":动作为 "Final Answer"。
- "action_input": self.history[-1]["content"]:动作输入参数为历史记录的最后一项的内容。
-
self.has_search = False:将成员变量 has_search 设置为 False,表示不存在搜索。
-
return f"...":返回一个包含动作 JSON 的字符串。字符串使用 Markdown 格式进行格式化,显示动作 JSON。
_call 方法:
-
print("======"):打印分隔线。
-
print(prompt):打印传入的参数 prompt。
-
print("======"):打印分隔线。
-
if not self.has_search::检查成员变量 has_search 是否为 False。如果为 False,表示不存在搜索,执行以下操作。
- self.history, query = self._tool_history(prompt):调用 _tool_history 方法,获取历史记录和查询字符串。将返回的历史记录赋值给类的成员变量 history,将返回的查询字符串赋值给变量 query。
-
else::如果存在搜索
-
self._extract_observation(prompt):调用 _extract_observation 方法,从 prompt 中提取观察信息,并将其添加到类的历史记录列表 history 中。
-
query = "":将查询字符串设置为空字符串。
-
-
_, self.history = self.model.chat(...):调用 model.chat 方法进行对话。它使用模型、分词器和其他参数来生成对话的响应。返回的结果包含生成的响应和更新后的历史记录。使用 _ 忽略生成的响应,将更新后的历史记录赋值给类的成员变量 history。
-
response = self._extract_tool():调用 _extract_tool 方法,从更新后的历史记录中提取工具动作并生成响应。
-
history.append((prompt, response)):将元组 (prompt, response) 添加到 history 列表中,用于记录对话历史。
-
return response:返回生成的响应作为结果。
总体而言,这段代码定义了两个方法 _extract_tool 和 _call。_extract_tool 方法用于从历史记录中提取工具动作,生成相应的响应。_call 方法用于处理对话流程,包括提取观察信息、生成对话响应和记录对话历史。
2-3. Tool/Calculator.py
import abc
import math
from typing import Any
from langchain.tools import BaseTool
class Calculator(BaseTool, abc.ABC):
name = "Calculator"
description = "Useful for when you need to answer questions about math"
def __init__(self):
super().__init__()
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
# 用例中没有用到 arun 不予具体实现
pass
def _run(self, para: str) -> str:
para = para.replace("^", "**")
if "sqrt" in para:
para = para.replace("sqrt", "math.sqrt")
elif "log" in para:
para = para.replace("log", "math.log")
return eval(para)
if __name__ == "__main__":
calculator_tool = Calculator()
result = calculator_tool.run("sqrt(2) + 3")
print(result)
这段代码定义了一个名为 Calculator 的类,它继承自 BaseTool 类和 abc.ABC 抽象基类。Calculator 类是一个计算器工具,用于执行数学计算操作。
以下是这段代码的详细解析:
-
import abc:导入 abc 模块,用于定义抽象基类。
-
import math:导入 math 模块,用于执行数学计算操作。
-
from typing import Any:从 typing 模块导入 Any 类型,用于函数参数和返回值的类型注解。
-
from langchain.tools import BaseTool:从 langchain.tools 模块导入 BaseTool 类,用作 Calculator 类的父类。
-
class Calculator(BaseTool, abc.ABC)::定义了一个名为 Calculator 的类,它继承自 BaseTool 类和 abc.ABC 抽象基类。Calculator 类表示一个计算器工具。
-
name = "Calculator":类属性 name 被设置为字符串 "Calculator",表示工具的名称。
-
description = "Useful for when you need to answer questions about math":类属性 description 被设置为字符串 "Useful for when you need to answer questions about math",表示工具的描述信息。
-
def init (self)::构造函数,用于初始化 Calculator 类的实例。它调用父类的构造函数 super().init() 来完成初始化。
-
async def _arun(self, *args: Any, **kwargs: Any) -> Any::定义了一个异步方法 _arun,它接受任意数量的位置参数和关键字参数,并返回任意类型的值。在这段代码中,_arun 方法没有具体的实现,因为在示例中没有使用到它。
-
def _run(self, para: str) -> str::定义了一个方法 _run,它接受一个字符串参数 para,并返回一个字符串。在这个方法中,它首先对参数 para 进行替换操作,将字符串中的 "^" 替换为 "**"。然后,它检查参数 para 中是否包含 "sqrt",如果是,则将 "sqrt" 替换为 "math.sqrt";如果参数 para 中包含 "log",则将 "log" 替换为 "math.log"。最后,使用 eval 函数来执行参数 para 的计算操作,并返回计算结果。
-
请注意,这段代码中的 BaseTool 类没有给出具体的定义,因此我无法提供关于它的更多详细信息。如果您能提供 BaseTool 类的定义或相关代码,我将能够给出更准确的解释。
2-4. Tool/Weather.py
import os
from typing import Any
import requests
from langchain.tools import BaseTool
class Weather(BaseTool):
name = "weather"
description = "Use for searching weather at a specific location"
def __init__(self):
super().__init__()
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
# 用例中没有用到 arun 不予具体实现
pass
def get_weather(self, location):
api_key = os.environ["SENIVERSE_KEY"]
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
def _run(self, para: str) -> str:
return self.get_weather(para)
if __name__ == "__main__":
weather_tool = Weather()
weather_info = weather_tool.run("成都")
print(weather_info)
这段代码定义了一个名为 Weather 的类,它继承自 BaseTool 类。Weather 类是一个天气工具,用于查询特定位置的天气信息。
以下是代码的详细解析:
-
import os:导入 os 模块,用于访问操作系统的功能,例如环境变量。
-
from typing import Any:从 typing 模块导入 Any 类型,用于函数参数和返回值的类型注解。
-
import requests:导入 requests 模块,用于发送 HTTP 请求。
-
from langchain.tools import BaseTool:从 langchain.tools 模块导入 BaseTool 类,用作 Weather 类的父类。
-
class Weather(BaseTool)::定义了一个名为 Weather 的类,它继承自 BaseTool 类。Weather 类表示一个天气工具。
-
name = "weather":类属性 name 被设置为字符串 "weather",表示工具的名称。
-
description = "Use for searching weather at a specific location":类属性 description 被设置为字符串 "Use for searching weather at a specific location",表示工具的描述信息。
-
def init (self)::构造函数,用于初始化 Weather 类的实例。它调用父类的构造函数 super().init() 来完成初始化。
-
async def _arun(self, *args: Any, **kwargs: Any) -> Any::定义了一个异步方法 _arun,它接受任意数量的位置参数和关键字参数,并返回任意类型的值。在这段代码中,_arun 方法没有具体的实现,因为在示例中没有使用到它。
-
def get_weather(self, location)::定义了一个方法 get_weather,它接受一个参数 location,表示要查询的位置。在这个方法中,它使用 os.environ 字典获取名为 "SENIVERSE_KEY" 的环境变量作为 API 密钥。然后,它构建一个 URL,使用 requests.get 方法发送 GET 请求到该 URL,并获取响应结果。如果响应的状态码为 200(表示请求成功),则解析响应的 JSON 数据,并返回包含温度和天气描述的字典。如果响应的状态码不为 200,则抛出异常。
-
def _run(self, para: str) -> str::定义了一个方法 _run,它接受一个字符串参数 para,并返回一个字符串。在这个方法中,它调用 get_weather 方法,传入参数 para(表示要查询的位置),并返回查询到的天气信息。
-
2-5. main.py
from typing import List
from ChatGLM3 import ChatGLM3
from langchain.agents import load_tools
from Tool.Weather import Weather
from Tool.Calculator import Calculator
from langchain.agents import initialize_agent
from langchain.agents import AgentType
def run_tool(tools, llm, prompt_chain: List[str]):
loaded_tolls = []
for tool in tools:
if isinstance(tool, str):
loaded_tolls.append(load_tools([tool], llm=llm)[0])
else:
loaded_tolls.append(tool)
agent = initialize_agent(
loaded_tolls, llm,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
handle_parsing_errors=True
)
for prompt in prompt_chain:
agent.run(prompt)
if __name__ == "__main__":
model_path = "THUDM/chatglm3-6b"
llm = ChatGLM3()
llm.load_model(model_name_or_path=model_path)
# arxiv: 单个工具调用示例 1
run_tool(["arxiv"], llm, [
"帮我查询GLM-130B相关工作"
])
# weather: 单个工具调用示例 2
# run_tool([Weather()], llm, [
# "今天北京天气怎么样?",
# "What's the weather like in Shanghai today",
# ])
# calculator: 单个工具调用示例 3
run_tool([Calculator()], llm, [
"12345679乘以54等于多少?",
"3.14的3.14次方等于多少?",
"根号2加上根号三等于多少?",
]),
# arxiv + weather + calculator: 多个工具结合调用
# run_tool([Calculator(), "arxiv", Weather()], llm, [
# "帮我检索GLM-130B相关论文",
# "今天北京天气怎么样?",
# "根号3减去根号二再加上4等于多少?",
# ])
这段代码是一个示例程序,演示了如何使用 langchain 库来调用不同的工具进行自然语言处理。
以下是代码的详细解析:
-
from typing import List:从 typing 模块导入 List 类型,用于函数参数的类型注解。
-
from ChatGLM3 import ChatGLM3:从 ChatGLM3 模块导入 ChatGLM3 类,用于创建语言模型。
-
from langchain.agents import load_tools:从 langchain.agents 模块导入 load_tools 函数,用于加载工具。
-
from Tool.Weather import Weather:从 Tool.Weather 模块导入 Weather 类,用于天气查询工具。
-
from Tool.Calculator import Calculator:从 Tool.Calculator 模块导入 Calculator 类,用于计算器工具。
-
from langchain.agents import initialize_agent:从 langchain.agents 模块导入 initialize_agent 函数,用于初始化代理。
-
from langchain.agents import AgentType:从 langchain.agents 模块导入 AgentType 枚举,用于指定代理类型。
-
def run_tool(tools, llm, prompt_chain: List[str])::定义了一个函数 run_tool,它接受三个参数:tools(要加载的工具列表),llm(语言模型实例),prompt_chain(要运行的提示列表)。
- loaded_tolls = []:创建一个空列表 loaded_tolls,用于存储加载后的工具。
- for tool in tools::对于 tools 列表中的每个工具:
- if isinstance(tool, str)::如果工具是字符串类型,表示需要加载的是一个工具模块:
- loaded_tolls.append(load_tools([tool], llm=llm)[0]):加载指定的工具模块,并将返回的工具实例添加到 loaded_tolls 列表中。
- if isinstance(tool, str)::如果工具是字符串类型,表示需要加载的是一个工具模块:
- else::如果工具不是字符串类型,表示已经是一个工具实例:
- loaded_tolls.append(tool):直接将工具实例添加到 loaded_tolls 列表中。
-
agent = initialize_agent(loaded_tolls, llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, handle_parsing_errors=True):使用加载后的工具和语言模型实例初始化代理。AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION 表示采用结构化对话的零次推理反应描述的代理类型。
- for prompt in prompt_chain::对于 prompt_chain 列表中的每个提示:
- agent.run(prompt):使用代理运行提示。
-
if name == "main"::如果当前模块被直接执行(而不是被导入为模块):
- model_path = "THUDM/chatglm3-6b":设置语言模型的路径。
- llm = ChatGLM3():创建 ChatGLM3 类的实例,用于创建语言模型。
-
llm.load_model(model_name_or_path=model_path):加载指定路径的语言模型。
-
run_tool(["arxiv"], llm, ["帮我查询GLM-130B相关工作"]):调用 run_tool 函数,使用 arxiv 工具和语言模型实例 llm,并传入一个提示列表 ["帮我查询GLM-130B相关工作"],以执行相关工作的查询。
-
run_tool([Calculator()], llm, ["12345679乘以54等于多少?", "3.14的3.14次方等于多少?", "根号2加上根号三等于多少?"]):调用 run_tool 函数,使用 Calculator 工具和语言模型实例 llm,并传入一个提示列表,以执行计算器工具的计算。
注释掉的代码块是其他工具的调用示例,包括天气查询工具和多个工具的结合调用。
这段代码演示了如何使用 langchain 库,加载不同的工具和语言模型,然后通过代理来运行自然语言提示,以执行不同的任务,例如查询相关工作、查询天气和进行计算等。
完结!