使用Microsoft.SemanticKernel基于本地运行的Ollama大语言模型实现Agent调用函数

大语言模型的发展日新月异,记得在去年这个时候,函数调用还是gpt-4的专属。到今年本地运行的大模型无论是推理能力还是文本的输出质量都已经非常接近gpt-4了。而在去年gpt-4尚未发布函数调用时,智能体框架的开发者们依赖构建精巧的提示词实现了gpt-3.5的函数调用。目前在本机运行的大模型,基于这一套逻辑也可以实现函数式调用,今天我们就是用本地运行的大模型来实现这个需求。从测试的效果来看,本地大模型对于简单的函数调用成功率已经非常高了,但是受限于本地机器的性能,调用的时间还是比较长。如果有NVIDIA显卡的CUDA环境,质量应该会好很多,今天就以大家都比较熟悉的LLAMA生态作为起点,基于阿里云开源的千问7B模型的量化版作为基座通过C#和SemanticKernel来实现函数调用的功能。

基本调用逻辑参考这张图:

首先我们需要在本机(windows系统)安装Ollama作为LLM的API后端。访问https://ollama.com/,选择Download。选择你需要的版本即可,windows用户请选择Download for Windows。下载完成后,无脑点击下一步下一步即可安装完毕。

安装完毕后,打开我们的PowerShell即可运行大模型,第一次加载会下载模型文件到本地磁盘,会比较慢。运行起来后就可以通过控制台和模型进行简单的对话,这里我们以阿里发布的千问2:7b为例。执行以下命令即可运行起来:

复制代码
ollama run qwen2:7b

接着我们使用ctrl+D退出对话框,并执行ollama serve,看看服务器是否运行起来了,正常情况下会看到11434这个端口已经运行起来了。接下来我们就可以进入到编码阶段

首先我们创建一个.net8.0的的控制台,接着我们引入三个必要的包

复制代码
dotnet add package Microsoft.SemanticKernel --version 1.15.0
dotnet add package Newtonsoft.Json --version 13.0.3
dotnet add package OllamaSharp --version 2.0.1

SemanticKernel是我们主要的代理运行框架,OllamaSharp是一个简单的面向Ollama本地API服务的请求封装。避免我们手写httpclient来与本地服务器交互。我这里安装了Newtonsoft.Json来替代system.text.json,主要是用于后期需要一些序列化模型回调来使用,因为模型的回调json可能不是特别标准,使用system.text.json容易导致转义失败。

接下来就是编码阶段,首先我们定义一个函数,这个函数是后面LLM会用到的函数,简单的定义如下:

复制代码
public class FunctionTest
{
    [KernelFunction, Description("获取城市的天气状况")]
    public object GetWeather([Description("城市名称")] string CityName, [Description("查询时段,值可以是[白天,夜晚]")] string DayPart)
    {
        return new { CityName, DayPart, CurrentCondition = "多云", LaterCondition = "阴", MinTemperature = 19, MaxTemperature = 23 };
    }
}

这里的KernelFunction和Description特性都是必要的,用于SemanticKernel查询到对应的函数并封装处对应的元数据。

接着我们需要自定义一个继承自接口IChatCompletionService的实现,因为SemanticKernel是基于openai的gpt系列设计的框架,所以要和本地模型调用,我们需要设置独立的ChatCompletionService来让SemanticKernel和本机模型API交互。这里我们主要需要实现的函数是GetChatMessageContentsAsync。因为函数调用我们需要接收到模型完整的回调用于转换json,所以流式传输这里用不上。

复制代码
public class CustomChatCompletionService : IChatCompletionService
{
    public IReadOnlyDictionary<string, object?> Attributes => throw new NotImplementedException();

    public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
        throw new NotImplementedException();
    }

    public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
        throw new NotImplementedException();
    }
}

接下来我们需要定义一个SemanticKernel的实例,这个实例会伴随本次调用贯穿全程。SemanticKernel使用了简单的链式构建。基本代码如下:

复制代码
var builder = Kernel.CreateBuilder();
//这里我们需要增加刚才我们定义的实例CustomChatCompletionService,有点类似IOC的设计
builder.Services.AddKeyedSingleton<IChatCompletionService>("ollamaChat", new CustomChatCompletionService());
//这里我们需要插入之前定义的插件
builder.Plugins.AddFromType<FunctionTest>();
var kernel = builder.Build();

可以看到基本的构建链式调用代码部分还是比较简单的,接下来就是调用的部分,这里主要的部分就是将LLM可用的函数插入到系统提示词,来引导LLM去调用特定函数:

复制代码
//定义一个对话历史
ChatHistory history = [];
//获取刚才定义的插件函数的元数据,用于后续创建prompt
var plugins = kernel.Plugins.GetFunctionsMetadata();
//生成函数调用提示词,引导模型根据用户请求去调用函数
var functionsPrompt = CreateFunctionsMetaObject(plugins);
//创建系统提示词,插入刚才生成的提示词
var prompt = $"""
                  You have access to the following functions. Use them if required:
                  {functionsPrompt}
                  If function calls are used, ensure the output is in JSON format; otherwise, output should be in text format.
                  """;
//添加系统提示词
history.AddSystemMessage(prompt);
//创建一个对话服务实例
var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
//添加用户的提问
history.AddUserMessage(question);
//链式执行kernel
var result = await chatCompletionService.GetChatMessageContentAsync(
    history,
    executionSettings: null,
    kernel: kernel);
//打印回调内容
Console.WriteLine($"Assistant> {result}");

在这里我们可以debug看看生成的系统提示词细节:

当代码执行到GetChatMessageContentAsync这里时,就会跳转到我们的CustomChatCompletionService的GetChatMessageContentsAsync函数,在这里我们需要进行ollama的调用来达成目的。

这里比较核心的部分就是将LLM回调的内容使用JSON序列化来检测是否涉及到函数调用,简单来讲由于类似qwen这样没有专门针对function calling专项微调过的(glm-4-9b原生支持function calling)模型,其function calling并不是每次都能准确的回调,所以这里我们需要对回调的内容进行反序列化和信息抽取,确保模型的调用符合回调函数的格式标准。具体代码如下

复制代码
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
    GetDicSearchResult(kernel);
    var prompt = HistoryToText(chatHistory);
    var ollama = new OllamaApiClient("http://127.0.0.1:11434", "qwen2:7b");
    var chat = new Chat(ollama, _ => { });
    sw.Start();
    var history = (await chat.Send(prompt, CancellationToken.None)).ToArray();
    sw.Stop();
    Console.WriteLine($"调用耗时:{Math.Round(sw.Elapsed.TotalSeconds,2)}秒");
    var last = history.Last();
    var chatResponse = last.Content;
    try
    {
        JToken jToken = JToken.Parse(chatResponse);
        jToken = ConvertStringToJson(jToken);
        var searchs = DicSearchResult.Values.ToList();
        if (TryFindValues(jToken, ref searchs))
        {
            var firstFunc = searchs.First();
            var funcCallResult = await firstFunc.KernelFunction.InvokeAsync(kernel, firstFunc.FunctionParams);
            chatHistory.AddMessage(AuthorRole.Assistant, chatResponse);
            chatHistory.AddMessage(AuthorRole.Tool, funcCallResult.ToString());
            return await GetChatMessageContentsAsync(chatHistory, kernel: kernel);
        }
        else
        {

        }
    }
    catch(Exception e)
    {

    }
    return new List<ChatMessageContent> { new ChatMessageContent(AuthorRole.Assistant, chatResponse) };
}

这里我们首先使用SemanticKernel的kernel的函数元数据通过GetDicSearchResult构建了一个字典,这部分代码如下:

复制代码
public static Dictionary<string, SearchResult> DicSearchResult = new Dictionary<string, SearchResult>();
public static void GetDicSearchResult(Kernel kernel)
{
    DicSearchResult = new Dictionary<string, SearchResult>();
    foreach (var functionMetaData in kernel.Plugins.GetFunctionsMetadata())
    {
        string functionName = functionMetaData.Name;
        if (DicSearchResult.ContainsKey(functionName))
            continue;
        var searchResult = new SearchResult
        {
            FunctionName = functionName,
            KernelFunction = kernel.Plugins.GetFunction(null, functionName)
        };
        functionMetaData.Parameters.ToList().ForEach(x => searchResult.FunctionParams.Add(x.Name, null));
        DicSearchResult.Add(functionName, searchResult);
    }
}

接着使用HistoryToText将历史对话信息组装成一个单一的prompt发送给模型,大概会组装成如下内容,其实就是系统提示词+用户提示词组合成一个单一文本:

接着我们使用OllamaSharp的SDK提供的OllamaApiClient发送信息给模型,等待模型回调后,从模型回调的内容中抽取chatResponse,接着我们需要通过一个try catch来处理,当chatResponse可以被正确的解析成标准JToken后,说明模型的回调是一段json,否则会抛出异常,代表模型输出的是一段文本。如果是文本,我们就直接返回模型输出的内容,如果是json则继续向下处理,通过一个TryFindValues函数从模型中抽取我们所需要的回调函数名、参数,并赋值到一个临时变量中。最后通过SemanticKernel的KernelFunction的InvokeAsync进行真正的函数调用,获取到函数的回调内容,接着我们需要将模型的原始输出和回调内容一同添加到chatHistory后,再度递归发起GetChatMessageContentsAsync调用,这一次模型就会拿到前一次回调的城市天气内容来进行回答了。

第二次回调前的prompt如下,可以看到模型的输出虽然是json,但是并没有规范的格式,不过使用我们的抽取函数还是获取到了需要的信息,从而正确的构建了底部的回调:

通过这一轮回调再次喂给llm,llm就可以正确的输出结果了:

以上就是整个文章的内容了,可以看到在这个过程中我们主要做的工作就是通过系统提示词诱导模型输出回调函数json,解析json获取参数,调用本地的函数后再次回调给模型,这个过程其实有点类似的RAG,只不过RAG是通过用户的提示词直接进行近似度搜索获取到近似度相关的文本组合到系统提示词,而函数调用给了模型更大的自由度,可以让模型自行决策是否调用函数,从而使本地Agent代理可以实现诸如帮你操控电脑,打印文件,编写邮件等等助手性质的功能。

下面是核心部分的代码,请大家自取

program.cs:

复制代码
using ConsoleApp4;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using Microsoft.Extensions.DependencyInjection;
using System.ComponentModel;
using Newtonsoft.Json.Linq;




await Ollama("我想知道西安今天晚上的天气情况");



async Task Ollama(string question)
{
    Console.WriteLine($"User> {question}");
    var builder = Kernel.CreateBuilder();
    //这里我们需要增加刚才我们定义的实例CustomChatCompletionService,有点类似IOC的设计
    builder.Services.AddKeyedSingleton<IChatCompletionService>("ollamaChat", new CustomChatCompletionService());
    //这里我们需要插入之前定义的插件
    builder.Plugins.AddFromType<FunctionTest>();
    var kernel = builder.Build();
    //定义一个对话历史
    ChatHistory history = [];
    //获取刚才定义的插件函数的元数据,用于后续创建prompt
    var plugins = kernel.Plugins.GetFunctionsMetadata();
    //生成函数调用提示词,引导模型根据用户请求去调用函数
    var functionsPrompt = CreateFunctionsMetaObject(plugins);
    //创建系统提示词,插入刚才生成的提示词
    var prompt = $"""
                      You have access to the following functions. Use them if required:
                      {functionsPrompt}
                      If function calls are used, ensure the output is in JSON format; otherwise, output should be in text format.
                      """;
    //添加系统提示词
    history.AddSystemMessage(prompt);
    //创建一个对话服务实例
    var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
    //添加用户的提问
    history.AddUserMessage(question);
    //链式执行kernel
    var result = await chatCompletionService.GetChatMessageContentAsync(
        history,
        executionSettings: null,
        kernel: kernel);
    //打印回调内容
    Console.WriteLine($"Assistant> {result}");
}
static JToken? CreateFunctionsMetaObject(IList<KernelFunctionMetadata> plugins)
{
    if (plugins.Count < 1) return null;
    if (plugins.Count == 1) return CreateFunctionMetaObject(plugins[0]);

    JArray promptFunctions = [];
    foreach (var plugin in plugins)
    {
        var pluginFunctionWrapper = CreateFunctionMetaObject(plugin);
        promptFunctions.Add(pluginFunctionWrapper);
    }

    return promptFunctions;
}
static JObject CreateFunctionMetaObject(KernelFunctionMetadata plugin)
{
    var pluginFunctionWrapper = new JObject()
        {
            { "type", "function" },
        };

    var pluginFunction = new JObject()
        {
            { "name", plugin.Name },
            { "description", plugin.Description },
        };

    var pluginFunctionParameters = new JObject()
        {
            { "type", "object" },
        };
    var pluginProperties = new JObject();
    foreach (var parameter in plugin.Parameters)
    {
        var property = new JObject()
            {
                { "type", parameter.ParameterType?.ToString() },
                { "description", parameter.Description },
            };

        pluginProperties.Add(parameter.Name, property);
    }

    pluginFunctionParameters.Add("properties", pluginProperties);
    pluginFunction.Add("parameters", pluginFunctionParameters);
    pluginFunctionWrapper.Add("function", pluginFunction);

    return pluginFunctionWrapper;
}
public class FunctionTest
{
    [KernelFunction, Description("获取城市的天气状况")]
    public object GetWeather([Description("城市名称")] string CityName, [Description("查询时段,值可以是[白天,夜晚]")] string DayPart)
    {
        return new { CityName, DayPart, CurrentCondition = "多云", LaterCondition = "阴", MinTemperature = 19, MaxTemperature = 23 };
    }
}

CustomChatCompletionService.cs:

复制代码
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Newtonsoft.Json.Linq;
using OllamaSharp;
using System.Diagnostics;
using System.Text;

namespace ConsoleApp4
{
    public class CustomChatCompletionService : IChatCompletionService
    {
        public static Dictionary<string, SearchResult> DicSearchResult = new Dictionary<string, SearchResult>();
        public static void GetDicSearchResult(Kernel kernel)
        {
            DicSearchResult = new Dictionary<string, SearchResult>();
            foreach (var functionMetaData in kernel.Plugins.GetFunctionsMetadata())
            {
                string functionName = functionMetaData.Name;
                if (DicSearchResult.ContainsKey(functionName))
                    continue;
                var searchResult = new SearchResult
                {
                    FunctionName = functionName,
                    KernelFunction = kernel.Plugins.GetFunction(null, functionName)
                };
                functionMetaData.Parameters.ToList().ForEach(x => searchResult.FunctionParams.Add(x.Name, null));
                DicSearchResult.Add(functionName, searchResult);
            }
        }
        public IReadOnlyDictionary<string, object?> Attributes => throw new NotImplementedException();
        static Stopwatch sw = new Stopwatch();
        public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
        {
            GetDicSearchResult(kernel);
            var prompt = HistoryToText(chatHistory);
            var ollama = new OllamaApiClient("http://127.0.0.1:11434", "qwen2:7b");
            var chat = new Chat(ollama, _ => { });
            sw.Start();
            var history = (await chat.Send(prompt, CancellationToken.None)).ToArray();
            sw.Stop();
            Console.WriteLine($"调用耗时:{Math.Round(sw.Elapsed.TotalSeconds,2)}秒");
            var last = history.Last();
            var chatResponse = last.Content;
            try
            {
                JToken jToken = JToken.Parse(chatResponse);
                jToken = ConvertStringToJson(jToken);
                var searchs = DicSearchResult.Values.ToList();
                if (TryFindValues(jToken, ref searchs))
                {
                    var firstFunc = searchs.First();
                    var funcCallResult = await firstFunc.KernelFunction.InvokeAsync(kernel, firstFunc.FunctionParams);
                    chatHistory.AddMessage(AuthorRole.Assistant, chatResponse);
                    chatHistory.AddMessage(AuthorRole.Tool, funcCallResult.ToString());
                    return await GetChatMessageContentsAsync(chatHistory, kernel: kernel);
                }
                else
                {

                }
            }
            catch(Exception e)
            {

            }
            return new List<ChatMessageContent> { new ChatMessageContent(AuthorRole.Assistant, chatResponse) };
        }
        JToken ConvertStringToJson(JToken token)
        {
            if (token.Type == JTokenType.Object)
            {
                // 遍历对象的每个属性
                JObject obj = new JObject();
                foreach (JProperty prop in token.Children<JProperty>())
                {
                    obj.Add(prop.Name, ConvertStringToJson(prop.Value));
                }
                return obj;
            }
            else if (token.Type == JTokenType.Array)
            {
                // 遍历数组的每个元素
                JArray array = new JArray();
                foreach (JToken item in token.Children())
                {
                    array.Add(ConvertStringToJson(item));
                }
                return array;
            }
            else if (token.Type == JTokenType.String)
            {
                // 尝试将字符串解析为 JSON
                string value = token.ToString();
                try
                {
                    return JToken.Parse(value);
                }
                catch (Exception)
                {
                    // 解析失败时返回原始字符串
                    return token;
                }
            }
            else
            {
                // 其他类型直接返回
                return token;
            }
        }
        bool TryFindValues(JToken token, ref List<SearchResult> searches)
        {
            if (token.Type == JTokenType.Object)
            {
                foreach (var child in token.Children<JProperty>())
                {
                    foreach (var search in searches)
                    {
                        if (child.Value.ToString().ToLower().Equals(search.FunctionName.ToLower()) && search.SearchFunctionNameSucc != true)
                            search.SearchFunctionNameSucc = true;
                        foreach (var par in search.FunctionParams)
                        {
                            if (child.Name.ToLower().Equals(par.Key.ToLower()) && par.Value == null)
                                search.FunctionParams[par.Key] = child.Value.ToString().ToLower();
                        }
                    }
                    if (searches.Any(x => x.SearchFunctionNameSucc == false || x.FunctionParams.Any(x => x.Value == null)))
                        TryFindValues(child.Value, ref searches);
                }
            }
            else if (token.Type == JTokenType.Array)
            {
                foreach (var item in token.Children())
                {
                    if (searches.Any(x => x.SearchFunctionNameSucc == false || x.FunctionParams.Any(x => x.Value == null)))
                        TryFindValues(item, ref searches);
                }
            }
            return searches.Any(x => x.SearchFunctionNameSucc && x.FunctionParams.All(x => x.Value != null));
        }
        public virtual string HistoryToText(ChatHistory history)
        {
            StringBuilder sb = new();
            foreach (var message in history)
            {
                if (message.Role == AuthorRole.User)
                {
                    sb.AppendLine($"User: {message.Content}");
                }
                else if (message.Role == AuthorRole.System)
                {
                    sb.AppendLine($"System: {message.Content}");
                }
                else if (message.Role == AuthorRole.Assistant)
                {
                    sb.AppendLine($"Assistant: {message.Content}");
                }
                else if (message.Role == AuthorRole.Tool)
                {
                    sb.AppendLine($"Tool: {message.Content}");
                }
            }
            return sb.ToString();
        }
        public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
        {
            throw new NotImplementedException();
        }
    }
    public class SearchResult
    {
        public string FunctionName { get; set; }
        public bool SearchFunctionNameSucc { get; set; }
        public KernelArguments FunctionParams { get; set; } = new KernelArguments();
        public KernelFunction KernelFunction { get; set; }
    }
}