用 Rust 构建高性能 LiteLLM 客户端:支持流式与非流式调用

在大模型应用开发中,LiteLLM 作为统一的 LLM 代理层,能让我们用标准化的接口调用不同厂商的大模型。今天我将分享如何用 Rust 构建一个高性能、类型安全的 LiteLLM 客户端,同时支持流式(Streaming)和非流式的聊天补全请求。

为什么选择 Rust 开发 LiteLLM 客户端?

  • 高性能:Rust 的零成本抽象和无 GC 特性,特别适合需要低延迟、高并发的大模型调用场景
  • 类型安全:编译期检查能提前发现大部分错误,避免线上运行时的 JSON 解析或参数错误
  • 异步生态:Tokio + reqwest 提供了强大的异步网络编程能力,完美适配流式响应处理

完整实现代码

1. 依赖配置(Cargo.toml)

toml 复制代码
[package]
name = "litellm-client"
version = "0.1.0"
edition = "2021"

[dependencies]
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.0", features = ["full"] }
futures = "0.3"

2. 客户端核心实现(src/main.rs)

rust 复制代码
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

/// 聊天客户端的错误类型定义
#[derive(Error, Debug)]
pub enum ChatClientError {
    #[error("未授权: {0}")]
    Unauthorized(String),
    #[error("请求错误: {0}")]
    RequestError(#[from] reqwest::Error),
    #[error("JSON 解析错误: {0}")]
    JsonError(#[from] serde_json::Error),
    #[error("无效的 UTF-8 编码")]
    InvalidUtf8,
}

impl From<std::string::FromUtf8Error> for ChatClientError {
    fn from(_: std::string::FromUtf8Error) -> Self {
        ChatClientError::InvalidUtf8
    }
}

/// 聊天补全的消息结构
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: String,
}

/// 聊天补全请求参数
#[derive(Debug, Clone, Serialize)]
pub struct CompletionRequest {
    pub model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

/// 聊天补全响应类型
pub type CompletionResponse = Value;

/// 用于与 LiteLLM 代理服务器交互的聊天客户端
pub struct ChatClient {
    base_url: String,
    api_key: Option<String>,
    client: Client,
}

impl ChatClient {
    /// 初始化 ChatClient 实例
    /// 
    /// # 参数
    /// * `base_url` - LiteLLM 代理服务器的基础 URL (例如: "http://localhost:8000")
    /// * `api_key` - 认证用的 API 密钥,如果提供会以 Bearer token 形式发送
    pub fn new(base_url: &str, api_key: Option<&str>) -> Self {
        ChatClient {
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.map(|s| s.to_string()),
            client: Client::new(),
        }
    }

    /// 自定义设置 reqwest Client
    pub fn set_client_val(&mut self, client: Client) -> &mut Self {
        self.client = client;
        self
    }

    /// 获取 API 请求的头部信息,包含认证信息(如果设置了 api_key)
    fn get_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        // 添加认证头部
        if let Some(api_key) = &self.api_key {
            headers.insert(
                reqwest::header::AUTHORIZATION,
                reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
                    .expect("无效的 API 密钥头部值"),
            );
        }

        headers
    }

    /// 创建非流式的聊天补全请求
    /// 
    /// # 参数
    /// * `model` - 用于生成补全的模型名称
    /// * `messages` - 用于生成补全的消息列表
    /// * `temperature` - 采样温度,范围 0-2
    /// * `top_p` - 核采样参数,范围 0-1
    /// * `n` - 要生成的补全结果数量
    /// * `max_tokens` - 生成的最大令牌数
    /// * `presence_penalty` - 存在惩罚,范围 -2.0 到 2.0
    /// * `frequency_penalty` - 频率惩罚,范围 -2.0 到 2.0
    /// * `user` - 终端用户的唯一标识符
    /// 
    /// # 返回值
    /// 服务器返回的补全响应
    /// 
    /// # 错误
    /// 请求失败时返回错误
    pub async fn completions(
        &self,
        model: &str,
        messages: Vec<Message>,
        temperature: Option<f64>,
        top_p: Option<f64>,
        n: Option<u32>,
        max_tokens: Option<u32>,
        presence_penalty: Option<f64>,
        frequency_penalty: Option<f64>,
        user: Option<&str>,
    ) -> Result<CompletionResponse, ChatClientError> {
        let url = format!("{}/chat/completions", self.base_url);

        // 构建请求数据
        let request_data = CompletionRequest {
            model: model.to_string(),
            messages,
            temperature,
            top_p,
            n,
            max_tokens,
            presence_penalty,
            frequency_penalty,
            user: user.map(|s| s.to_string()),
            stream: None,
        };

        // 发送 POST 请求
        let response = self
            .client
            .post(&url)
            .headers(self.get_headers())
            .json(&request_data)
            .send()
            .await?;

        // 处理未授权错误
        if response.status() == reqwest::StatusCode::UNAUTHORIZED {
            return Err(ChatClientError::Unauthorized(
                "未授权访问,请检查 API 密钥".to_string(),
            ));
        }

        // 解析 JSON 响应
        let json: Value = response.json().await?;
        Ok(json)
    }

    /// 创建流式的聊天补全请求
    /// 
    /// # 参数
    /// 与 completions 方法参数相同
    /// 
    /// # 返回值
    /// 服务器返回的流式响应块
    /// 
    /// # 错误
    /// 请求失败时返回错误
    pub async fn completions_stream(
        &self,
        model: &str,
        messages: Vec<Message>,
        temperature: Option<f64>,
        top_p: Option<f64>,
        n: Option<u32>,
        max_tokens: Option<u32>,
        presence_penalty: Option<f64>,
        frequency_penalty: Option<f64>,
        user: Option<&str>,
    ) -> Result<impl futures::Stream<Item = Result<Value, ChatClientError>>, ChatClientError> {
        let url = format!("{}/chat/completions", self.base_url);

        // 构建请求数据(开启流式)
        let request_data = CompletionRequest {
            model: model.to_string(),
            messages,
            temperature,
            top_p,
            n,
            max_tokens,
            presence_penalty,
            frequency_penalty,
            user: user.map(|s| s.to_string()),
            stream: Some(true),
        };

        // 发送 POST 请求
        let response = self
            .client
            .post(&url)
            .headers(self.get_headers())
            .json(&request_data)
            .send()
            .await?;

        // 处理未授权错误
        if response.status() == reqwest::StatusCode::UNAUTHORIZED {
            return Err(ChatClientError::Unauthorized(
                "未授权访问,请检查 API 密钥".to_string(),
            ));
        }

        use futures::StreamExt;
        
        // 处理流式响应(SSE 格式)
        Ok(response.bytes_stream().flat_map(|chunk_result| -> futures::stream::BoxStream<'static, Result<Value, ChatClientError>> {
            use futures::stream::iter;
            
            match chunk_result {
                Ok(bytes) => {
                    // 将字节转换为字符串
                    match String::from_utf8(bytes.to_vec()) {
                        Ok(text) => {
                            // 解析 SSE (Server-Sent Events) 格式
                            let chunks: Vec<Value> = text
                                .lines()
                                // 过滤出以 "data: " 开头的行
                                .filter(|line| line.starts_with("data: "))
                                .filter_map(|line| {
                                    // 移除 "data: " 前缀
                                    let data_str = &line[6..];
                                    // 忽略结束标记
                                    if data_str.trim() == "[DONE]" {
                                        None
                                    } else {
                                        // 解析 JSON 数据
                                        serde_json::from_str(data_str).ok()
                                    }
                                })
                                .collect();

                            // 转换为流式输出
                            Box::pin(iter(chunks.into_iter().map(Ok)))
                        }
                        Err(_) => Box::pin(iter(std::iter::once(Err(ChatClientError::InvalidUtf8)))),
                    }
                }
                Err(e) => Box::pin(iter(std::iter::once(Err(ChatClientError::RequestError(e))))),
            }
        }))
    }
}

/// 示例代码
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. 创建聊天客户端实例
    let client = ChatClient::new("https://open.bigmodel.cn/api/coding/paas/v4", Some("your-api-key-here"));

    // 2. 示例:非流式补全请求
    let messages = vec![
        Message {
            role: "user".to_string(),
            content: "你好,请介绍一下 Rust 语言的优势".to_string(),
        },
    ];

    match client
        .completions("glm-4.7", messages, Some(0.7), Some(0.9), None, Some(1024), None, None, None)
        .await
    {
        Ok(response) => println!("非流式响应: {}", serde_json::to_string_pretty(&response)?),
        Err(e) => println!("非流式请求错误: {}", e),
    }

    // 3. 示例:流式补全请求
    let messages = vec![
        Message {
            role: "user".to_string(),
            content: "用 Rust 写一个简单的 HTTP 服务器示例".to_string(),
        },
    ];

    match client
        .completions_stream("glm-4.7", messages, Some(0.7), Some(0.9), None, Some(1024), None, None, None)
        .await
    {
        Ok(stream) => {
            use futures::StreamExt;
            tokio::pin!(stream);
            println!("开始接收流式响应:");
            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        // 提取并打印内容(根据实际响应结构调整)
                        if let Some(content) = chunk["choices"][0]["delta"]["content"].as_str() {
                            print!("{}", content);
                        }
                    }
                    Err(e) => println!("\n流式响应错误: {}", e),
                }
            }
            println!("\n流式响应接收完成");
        }
        Err(e) => println!("流式请求错误: {}", e),
    }

    Ok(())
}

核心功能解析

1. 错误处理

使用 thiserror 定义了清晰的错误类型,涵盖了:

  • 未授权错误(Unauthorized)
  • 请求错误(RequestError)
  • JSON 解析错误(JsonError)
  • UTF-8 编码错误(InvalidUtf8)

通过 From trait 实现了自动转换,让错误处理更加优雅。

2. 数据结构设计

  • Message:表示单条聊天消息,包含角色(role)和内容(content)
  • CompletionRequest:封装了所有聊天补全请求参数,使用 serdeskip_serializing_if 特性,只序列化非空参数
  • ChatClient:核心客户端结构体,包含基础 URL、API 密钥和 HTTP 客户端

3. 核心方法

非流式调用(completions)

  • 构建完整的请求参数
  • 设置正确的请求头(包含认证信息)
  • 发送 POST 请求并处理响应
  • 解析 JSON 响应并返回

流式调用(completions_stream)

  • 开启 stream: true 参数
  • 使用 bytes_stream() 处理流式响应
  • 解析 SSE (Server-Sent Events) 格式数据
  • 将响应块转换为异步流,方便消费

使用注意事项

  1. 替换 API 密钥 :将示例中的 your-api-key-here 替换为实际的 API 密钥
  2. 模型名称适配 :根据 LiteLLM 配置的模型名称调整 model 参数
  3. 错误处理:实际使用时建议完善错误处理逻辑
  4. 超时设置 :可以通过 set_client_val 自定义 Client,添加超时、代理等配置

总结

这个 Rust 客户端实现了 LiteLLM 的核心调用能力,主要特点包括:

  1. 类型安全:通过 Rust 的类型系统和 serde 序列化框架,确保请求参数的正确性
  2. 异步高效:基于 Tokio 和 reqwest 实现异步请求,支持高并发
  3. 完整的流式处理:正确解析 SSE 格式的流式响应,便于实时展示大模型输出
  4. 优雅的错误处理:使用 thiserror 定义清晰的错误类型,方便调试和处理

该客户端可以直接集成到你的 Rust 项目中,用于与 LiteLLM 代理服务器交互,调用各种大模型服务。

相关推荐
魔力军9 小时前
Rust学习Day3: 3个小demo实现
java·学习·rust
Smart-Space9 小时前
htmlbuilder - rust灵活构建html
rust·html
魔力军10 小时前
Rust学习Day2: 变量与可变性、数据类型和函数和控制流
开发语言·学习·rust
暴躁小师兄数据学院1 天前
【WEB3.0零基础转行笔记】Rust编程篇-第一讲:课程简介
rust·web3·区块链·智能合约
Hello.Reader1 天前
Rocket Fairings 实战把全局能力做成“结构化中间件”
中间件·rust·rocket
Andrew_Ryan1 天前
rust arena 内存分配
rust
Andrew_Ryan1 天前
深入理解 Rust 内存管理:基于 typed_arena 的指针操作实践
rust
微小冷2 天前
Rust异步编程详解
开发语言·rust·async·await·异步编程·tokio
鸿乃江边鸟2 天前
Spark Datafusion Comet 向量化Rust Native--CometShuffleExchangeExec怎么控制读写
大数据·rust·spark·native