Python：AutoGen 接入开源模型 xLAM-7b-fc-r，测试 function calling 功能

AutoGen 主打多智能体协作、对话和代码编写，但教程资源不如 LangChain 丰富。这里抛砖引玉，提供一个 AutoGen 接入开源 function calling 模型的教程，使用的开源仓库是：https://github.com/SalesforceAIResearch/xLAM

开源模型是:https://huggingface.co/Salesforce/xLAM-7b-fc-r

1B 模型的效果较差，推荐使用 7B 模型。首先用 vLLM 启动模型服务：

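vllm serve Salesforce/xLAM-7b-fc-r --host 0.0.0.0 --port 8000 --tensor-parallel-size 4

（`--tensor-parallel-size` 请按实际 GPU 数量设置。如果使用本地下载的模型，可把模型名换成本地路径，但必须与下面代码中 `model` 字段的值完全一致，否则 vLLM 会找不到对应模型。）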

然后是 AutoGen 代码示例：

import json
import random
import re
import time
import uuid
from typing import Literal

import autogen
import openai
from autogen.cache import Cache
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from xLAM.client import xLAMChatCompletion, xLAMConfig

# AutoGen LLM configuration for the locally served xLAM model.
local_llm_config = dict(
    config_list=[
        dict(
            # Must be exactly the model name/path passed to `vllm serve`.
            model="/<your_path>/xLAM-7b-fc-r",
            # Placeholder; no real key is needed for this local server.
            api_key="NotRequired",
            # Name of the custom client class registered on the agent via
            # `register_model_client` (CustomModelClient in this script).
            model_client_cls="CustomModelClient",
            # vLLM's OpenAI-compatible endpoint, including the "/v1" suffix.
            base_url="http://localhost:8000/v1",
            # Local model: zero prompt and completion price.
            price=[0, 0],
        )
    ],
    # None turns off AutoGen's response cache, handy when comparing models.
    cache_seed=None,
)

# Read by CustomModelClient.message_retrieval: when True, choices carrying
# tool calls (not just legacy function calls) are returned as full messages.
TOOL_ENABLED = True

class CustomModelClient:
    """AutoGen custom model client that sends chat requests to an xLAM model.

    Requests go through ``xLAMChatCompletion``. The xLAM response dict is
    converted into an OpenAI ``ChatCompletion`` object, including tool calls.
    Usage and cost are always reported as zero.
    """

    def __init__(self, config, **kwargs):
        """Build the xLAM client from one AutoGen ``config_list`` entry.

        Args:
            config: The config entry. ``model`` and ``base_url`` are required.
                An optional ``params`` dict may carry ``max_length``.
            **kwargs: Extra arguments supplied by AutoGen; unused.
        """
        print(f"CustomModelClient config: {config}")
        gen_config_params = config.get("params", {})
        # Stored only; nothing in this class forwards it to the model yet.
        self.max_length = gen_config_params.get("max_length", 256)
        print(f"Loaded model {config['model']}")
        xlam_config = xLAMConfig(base_url=config["base_url"], model=config["model"])
        self.llm = xLAMChatCompletion.from_config(xlam_config)

    @staticmethod
    def _tools_from_params(params):
        """Return the ``function`` specs of OpenAI-style ``tools``, or ``[]`` if there are none."""
        return [item["function"] for item in params.get("tools") or []]

    @staticmethod
    def _extract_raw_tool_calls(raw_message):
        """Return the raw tool-call dicts from an xLAM response message.

        A missing or ``None`` ``tool_calls`` entry yields ``[]``. A list wrapped
        in one extra list (``[[call, ...]]``) is unwrapped one level.
        """
        tool_results = raw_message.get("tool_calls") or []
        if isinstance(tool_results, list) and tool_results and isinstance(tool_results[0], list):
            tool_results = tool_results[0]
        return tool_results

    def create(self, params):
        """Run one chat completion and return it as an OpenAI ``ChatCompletion``.

        Args:
            params: AutoGen request parameters. ``messages`` and ``model`` are
                read; ``tools`` (OpenAI tool specs) is optional.

        Returns:
            A ``ChatCompletion`` with a single choice. ``finish_reason`` is
            ``"tool_calls"`` when the model produced parseable tool calls, and
            ``"stop"`` otherwise.

        Raises:
            NotImplementedError: If streaming is requested.
        """
        if params.get("stream", False) and "messages" in params:
            raise NotImplementedError("Local models do not support streaming.")

        # Always bound: previously `tools` was undefined for tool-less requests.
        tools = self._tools_from_params(params)
        response = self.llm.completion(params["messages"], tools=tools)
        raw_message = response["choices"][0]["message"]

        tool_calls = None
        finish_reason = "stop"
        tool_results = self._extract_raw_tool_calls(raw_message)
        if tool_results:
            try:
                tool_calls = [
                    ChatCompletionMessageToolCall(
                        # uuid-based ids stay unique across several calls in one
                        # reply (small random ints could collide).
                        id=f"call_{uuid.uuid4().hex}",
                        function={
                            "name": tool_call["name"],
                            "arguments": json.dumps(tool_call["arguments"]),
                        },
                        type="function",
                    )
                    for tool_call in tool_results
                ]
                finish_reason = "tool_calls"
            except (KeyError, TypeError, ValueError):
                # Malformed tool call from the model: fall back to a plain reply.
                print(f"Tool parse error: {tool_results}")
                tool_calls = None

        message = ChatCompletionMessage(
            role="assistant",
            content=raw_message["content"],
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=finish_reason, index=0, message=message)]
        return ChatCompletion(
            id=f"chatcmpl-{uuid.uuid4().hex}",
            model=params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            # Token usage is not reported by this client, so zeros are used.
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            cost=0.0,
        )

    def message_retrieval(self, response):
        """Retrieve the messages from the response.

        For a legacy ``Completion``, this is each choice's text. For a chat
        completion, a choice carrying a function call is returned as the full
        message object, and otherwise as its text content. When the
        module-level ``TOOL_ENABLED`` flag is true, tool calls also count.
        """
        choices = response.choices
        if isinstance(response, Completion):
            return [choice.text for choice in choices]
        if TOOL_ENABLED:
            return [  # type: ignore [return-value]
                choice.message  # type: ignore [union-attr]
                if choice.message.function_call is not None or choice.message.tool_calls is not None  # type: ignore [union-attr]
                else choice.message.content
                for choice in choices
            ]
        return [  # type: ignore [return-value]
            choice.message if choice.message.function_call is not None else choice.message.content  # type: ignore [union-attr]
            for choice in choices
        ]

    def cost(self, response) -> float:
        """Return the cost of ``response``: always 0 for this local model.

        Also sets ``response.cost`` to 0.
        """
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        """Return a usage summary dict for ``response``.

        Keys are ``prompt_tokens``, ``completion_tokens``, ``total_tokens``,
        ``cost`` and ``model``. Token counts are 0 when ``response.usage`` is
        None, and ``cost`` is 0 when the response has no ``cost`` attribute.
        """
        usage = response.usage
        return {
            "prompt_tokens": usage.prompt_tokens if usage is not None else 0,
            "completion_tokens": usage.completion_tokens if usage is not None else 0,
            "total_tokens": (usage.prompt_tokens + usage.completion_tokens if usage is not None else 0),
            "cost": getattr(response, "cost", 0),
            "model": response.model,
        }


# Assistant agent whose replies come from the local xLAM model
# (local_llm_config routes it to CustomModelClient).
chatbot = autogen.AssistantAgent(
    name="chatbot",
    llm_config=local_llm_config,
    system_message="For currency exchange tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
)


# Proxy for the user: never asks a human for input, executes registered tools,
# and stops once the chatbot's reply ends with "TERMINATE" (or after 5 automatic
# replies in a row).
user_proxy = autogen.UserProxyAgent(
    name="user_proxy",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=5,
    is_termination_msg=lambda msg: (text := msg.get("content", "")) and text.rstrip().endswith("TERMINATE"),
)

# The only currency codes the example tools accept.
CurrencySymbol = Literal["USD", "EUR"]

def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:
    """Return how many units of ``quote_currency`` one unit of ``base_currency`` buys.

    Identical currencies give 1.0. Otherwise a fixed demo rate of 1 EUR = 1.1 USD
    is used.

    Raises:
        ValueError: For any other pair of distinct currencies.
    """
    if base_currency == quote_currency:
        return 1.0
    # Factors keyed by the ordered (base, quote) pair.
    factors = {("USD", "EUR"): 1 / 1.1, ("EUR", "USD"): 1.1}
    factor = factors.get((base_currency, quote_currency))
    if factor is None:
        raise ValueError(f"Unknown currencies {base_currency}, {quote_currency}")
    return factor


@user_proxy.register_for_execution()
@chatbot.register_for_llm(description="Currency exchange calculator.")
def currency_calculator(
    base_amount: Annotated[float, "Amount of currency in base_currency"],
    base_currency: Annotated[CurrencySymbol, "Base currency"] = "USD",
    quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
) -> str:
    # Offered to the LLM by `chatbot` and executed by `user_proxy` (decorators
    # above). Converts base_amount and returns e.g. "112.2... EUR".
    rate = exchange_rate(base_currency, quote_currency)
    converted = rate * base_amount
    return f"{converted} {quote_currency}"


# Show the tool schemas that register_for_llm added to the chatbot's config.
print(chatbot.llm_config["tools"])
# Supply the class named by "model_client_cls" in local_llm_config; this must
# happen before the chat starts sending requests.
chatbot.register_model_client(model_client_cls=CustomModelClient)
query = "How much is 123.45 USD in EUR?"
# Alternative query; it would need a weather tool, which this example does not register.
# query = "What's the weather like in New York in fahrenheit?"
res = user_proxy.initiate_chat(
    chatbot,
    message=query,
    # `max_turns` is initiate_chat's turn limit. The previous `max_round` is a
    # GroupChat option, not an initiate_chat parameter, so it had no effect.
    max_turns=5,
)
print("Chat history:", res.chat_history)

运行结果示例：

user_proxy (to chatbot):

How much is 123.45 USD in EUR?

--------------------------------------------------------------------------------
chatbot (to user_proxy):


***** Suggested tool call (507): currency_calculator *****
Arguments:
{"base_amount": 123.45, "base_currency": "USD", "quote_currency": "EUR"}
**********************************************************

--------------------------------------------------------------------------------

>>>>>>>> EXECUTING FUNCTION currency_calculator...
user_proxy (to chatbot):

user_proxy (to chatbot):

***** Response from calling tool (507) *****
112.22727272727272 EUR
********************************************

--------------------------------------------------------------------------------
chatbot (to user_proxy):

The currency calculator returned 112.23 EUR.

--------------------------------------------------------------------------------
user_proxy (to chatbot):
相关推荐
mqiqe21 分钟前
Python MySQL通过Binlog 获取变更记录 恢复数据
开发语言·python·mysql
AttackingLin23 分钟前
2024强网杯--babyheap house of apple2解法
linux·开发语言·python
哭泣的眼泪40837 分钟前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
Ysjt | 深1 小时前
C++多线程编程入门教程(优质版)
java·开发语言·jvm·c++
ephemerals__1 小时前
【c++丨STL】list模拟实现(附源码)
开发语言·c++·list
码农飞飞1 小时前
深入理解Rust的模式匹配
开发语言·后端·rust·模式匹配·解构·结构体和枚举
一个小坑货1 小时前
Rust 的简介
开发语言·后端·rust
湫ccc1 小时前
《Python基础》之基本数据类型
开发语言·python