用 Rust 构建高性能 LiteLLM 客户端:支持流式与非流式调用

在大模型应用开发中,LiteLLM 作为统一的 LLM 代理层,能让我们用标准化的接口调用不同厂商的大模型。今天我将分享如何用 Rust 构建一个高性能、类型安全的 LiteLLM 客户端,同时支持流式(Streaming)和非流式的聊天补全请求。

为什么选择 Rust 开发 LiteLLM 客户端?

  • 高性能:Rust 的零成本抽象和无 GC 特性,特别适合需要低延迟、高并发的大模型调用场景
  • 类型安全:编译期检查能提前发现大部分错误,避免线上运行时的 JSON 解析或参数错误
  • 异步生态:Tokio + reqwest 提供了强大的异步网络编程能力,完美适配流式响应处理

完整实现代码

1. 依赖配置(Cargo.toml)

toml 复制代码
[package]
name = "litellm-client"
version = "0.1.0"
edition = "2021"

[dependencies]
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.0", features = ["full"] }
futures = "0.3"

2. 客户端核心实现(src/main.rs)

rust 复制代码
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

/// 聊天客户端的错误类型定义
#[derive(Error, Debug)]
pub enum ChatClientError {
    #[error("未授权: {0}")]
    Unauthorized(String),
    #[error("请求错误: {0}")]
    RequestError(#[from] reqwest::Error),
    #[error("JSON 解析错误: {0}")]
    JsonError(#[from] serde_json::Error),
    #[error("无效的 UTF-8 编码")]
    InvalidUtf8,
}

impl From<std::string::FromUtf8Error> for ChatClientError {
    fn from(_: std::string::FromUtf8Error) -> Self {
        ChatClientError::InvalidUtf8
    }
}

/// 聊天补全的消息结构
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: String,
}

/// 聊天补全请求参数
#[derive(Debug, Clone, Serialize)]
pub struct CompletionRequest {
    pub model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

/// 聊天补全响应类型
pub type CompletionResponse = Value;

/// 用于与 LiteLLM 代理服务器交互的聊天客户端
pub struct ChatClient {
    base_url: String,
    api_key: Option<String>,
    client: Client,
}

impl ChatClient {
    /// 初始化 ChatClient 实例
    /// 
    /// # 参数
    /// * `base_url` - LiteLLM 代理服务器的基础 URL (例如: "http://localhost:8000")
    /// * `api_key` - 认证用的 API 密钥,如果提供会以 Bearer token 形式发送
    pub fn new(base_url: &str, api_key: Option<&str>) -> Self {
        ChatClient {
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.map(|s| s.to_string()),
            client: Client::new(),
        }
    }

    /// 自定义设置 reqwest Client
    pub fn set_client_val(&mut self, client: Client) -> &mut Self {
        self.client = client;
        self
    }

    /// 获取 API 请求的头部信息,包含认证信息(如果设置了 api_key)
    fn get_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        // 添加认证头部
        if let Some(api_key) = &self.api_key {
            headers.insert(
                reqwest::header::AUTHORIZATION,
                reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
                    .expect("无效的 API 密钥头部值"),
            );
        }

        headers
    }

    /// 创建非流式的聊天补全请求
    /// 
    /// # 参数
    /// * `model` - 用于生成补全的模型名称
    /// * `messages` - 用于生成补全的消息列表
    /// * `temperature` - 采样温度,范围 0-2
    /// * `top_p` - 核采样参数,范围 0-1
    /// * `n` - 要生成的补全结果数量
    /// * `max_tokens` - 生成的最大令牌数
    /// * `presence_penalty` - 存在惩罚,范围 -2.0 到 2.0
    /// * `frequency_penalty` - 频率惩罚,范围 -2.0 到 2.0
    /// * `user` - 终端用户的唯一标识符
    /// 
    /// # 返回值
    /// 服务器返回的补全响应
    /// 
    /// # 错误
    /// 请求失败时返回错误
    pub async fn completions(
        &self,
        model: &str,
        messages: Vec<Message>,
        temperature: Option<f64>,
        top_p: Option<f64>,
        n: Option<u32>,
        max_tokens: Option<u32>,
        presence_penalty: Option<f64>,
        frequency_penalty: Option<f64>,
        user: Option<&str>,
    ) -> Result<CompletionResponse, ChatClientError> {
        let url = format!("{}/chat/completions", self.base_url);

        // 构建请求数据
        let request_data = CompletionRequest {
            model: model.to_string(),
            messages,
            temperature,
            top_p,
            n,
            max_tokens,
            presence_penalty,
            frequency_penalty,
            user: user.map(|s| s.to_string()),
            stream: None,
        };

        // 发送 POST 请求
        let response = self
            .client
            .post(&url)
            .headers(self.get_headers())
            .json(&request_data)
            .send()
            .await?;

        // 处理未授权错误
        if response.status() == reqwest::StatusCode::UNAUTHORIZED {
            return Err(ChatClientError::Unauthorized(
                "未授权访问,请检查 API 密钥".to_string(),
            ));
        }

        // 解析 JSON 响应
        let json: Value = response.json().await?;
        Ok(json)
    }

    /// 创建流式的聊天补全请求
    /// 
    /// # 参数
    /// 与 completions 方法参数相同
    /// 
    /// # 返回值
    /// 服务器返回的流式响应块
    /// 
    /// # 错误
    /// 请求失败时返回错误
    pub async fn completions_stream(
        &self,
        model: &str,
        messages: Vec<Message>,
        temperature: Option<f64>,
        top_p: Option<f64>,
        n: Option<u32>,
        max_tokens: Option<u32>,
        presence_penalty: Option<f64>,
        frequency_penalty: Option<f64>,
        user: Option<&str>,
    ) -> Result<impl futures::Stream<Item = Result<Value, ChatClientError>>, ChatClientError> {
        let url = format!("{}/chat/completions", self.base_url);

        // 构建请求数据(开启流式)
        let request_data = CompletionRequest {
            model: model.to_string(),
            messages,
            temperature,
            top_p,
            n,
            max_tokens,
            presence_penalty,
            frequency_penalty,
            user: user.map(|s| s.to_string()),
            stream: Some(true),
        };

        // 发送 POST 请求
        let response = self
            .client
            .post(&url)
            .headers(self.get_headers())
            .json(&request_data)
            .send()
            .await?;

        // 处理未授权错误
        if response.status() == reqwest::StatusCode::UNAUTHORIZED {
            return Err(ChatClientError::Unauthorized(
                "未授权访问,请检查 API 密钥".to_string(),
            ));
        }

        use futures::StreamExt;
        
        // 处理流式响应(SSE 格式)
        Ok(response.bytes_stream().flat_map(|chunk_result| -> futures::stream::BoxStream<'static, Result<Value, ChatClientError>> {
            use futures::stream::iter;
            
            match chunk_result {
                Ok(bytes) => {
                    // 将字节转换为字符串
                    match String::from_utf8(bytes.to_vec()) {
                        Ok(text) => {
                            // 解析 SSE (Server-Sent Events) 格式
                            let chunks: Vec<Value> = text
                                .lines()
                                // 过滤出以 "data: " 开头的行
                                .filter(|line| line.starts_with("data: "))
                                .filter_map(|line| {
                                    // 移除 "data: " 前缀
                                    let data_str = &line[6..];
                                    // 忽略结束标记
                                    if data_str.trim() == "[DONE]" {
                                        None
                                    } else {
                                        // 解析 JSON 数据
                                        serde_json::from_str(data_str).ok()
                                    }
                                })
                                .collect();

                            // 转换为流式输出
                            Box::pin(iter(chunks.into_iter().map(Ok)))
                        }
                        Err(_) => Box::pin(iter(std::iter::once(Err(ChatClientError::InvalidUtf8)))),
                    }
                }
                Err(e) => Box::pin(iter(std::iter::once(Err(ChatClientError::RequestError(e))))),
            }
        }))
    }
}

/// 示例代码
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. 创建聊天客户端实例
    let client = ChatClient::new("https://open.bigmodel.cn/api/coding/paas/v4", Some("your-api-key-here"));

    // 2. 示例:非流式补全请求
    let messages = vec![
        Message {
            role: "user".to_string(),
            content: "你好,请介绍一下 Rust 语言的优势".to_string(),
        },
    ];

    match client
        .completions("glm-4.7", messages, Some(0.7), Some(0.9), None, Some(1024), None, None, None)
        .await
    {
        Ok(response) => println!("非流式响应: {}", serde_json::to_string_pretty(&response)?),
        Err(e) => println!("非流式请求错误: {}", e),
    }

    // 3. 示例:流式补全请求
    let messages = vec![
        Message {
            role: "user".to_string(),
            content: "用 Rust 写一个简单的 HTTP 服务器示例".to_string(),
        },
    ];

    match client
        .completions_stream("glm-4.7", messages, Some(0.7), Some(0.9), None, Some(1024), None, None, None)
        .await
    {
        Ok(stream) => {
            use futures::StreamExt;
            tokio::pin!(stream);
            println!("开始接收流式响应:");
            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        // 提取并打印内容(根据实际响应结构调整)
                        if let Some(content) = chunk["choices"][0]["delta"]["content"].as_str() {
                            print!("{}", content);
                        }
                    }
                    Err(e) => println!("\n流式响应错误: {}", e),
                }
            }
            println!("\n流式响应接收完成");
        }
        Err(e) => println!("流式请求错误: {}", e),
    }

    Ok(())
}

核心功能解析

1. 错误处理

使用 thiserror 定义了清晰的错误类型,涵盖了:

  • 未授权错误(Unauthorized)
  • 请求错误(RequestError)
  • JSON 解析错误(JsonError)
  • UTF-8 编码错误(InvalidUtf8)

通过 From trait 实现了自动转换,让错误处理更加优雅。

2. 数据结构设计

  • Message:表示单条聊天消息,包含角色(role)和内容(content)
  • CompletionRequest:封装了所有聊天补全请求参数,使用 serdeskip_serializing_if 特性,只序列化非空参数
  • ChatClient:核心客户端结构体,包含基础 URL、API 密钥和 HTTP 客户端

3. 核心方法

非流式调用(completions)

  • 构建完整的请求参数
  • 设置正确的请求头(包含认证信息)
  • 发送 POST 请求并处理响应
  • 解析 JSON 响应并返回

流式调用(completions_stream)

  • 开启 stream: true 参数
  • 使用 bytes_stream() 处理流式响应
  • 解析 SSE (Server-Sent Events) 格式数据
  • 将响应块转换为异步流,方便消费

使用注意事项

  1. 替换 API 密钥 :将示例中的 your-api-key-here 替换为实际的 API 密钥
  2. 模型名称适配 :根据 LiteLLM 配置的模型名称调整 model 参数
  3. 错误处理:实际使用时建议完善错误处理逻辑
  4. 超时设置 :可以通过 set_client_val 自定义 Client,添加超时、代理等配置

总结

这个 Rust 客户端实现了 LiteLLM 的核心调用能力,主要特点包括:

  1. 类型安全:通过 Rust 的类型系统和 serde 序列化框架,确保请求参数的正确性
  2. 异步高效:基于 Tokio 和 reqwest 实现异步请求,支持高并发
  3. 完整的流式处理:正确解析 SSE 格式的流式响应,便于实时展示大模型输出
  4. 优雅的错误处理:使用 thiserror 定义清晰的错误类型,方便调试和处理

该客户端可以直接集成到你的 Rust 项目中,用于与 LiteLLM 代理服务器交互,调用各种大模型服务。

相关推荐
DongLi013 天前
rustlings 学习笔记 -- exercises/05_vecs
rust
番茄灭世神3 天前
Rust学习笔记第2篇
rust·编程语言
shimly1234564 天前
(done) 速通 rustlings(20) 错误处理1 --- 不涉及Traits
rust
shimly1234564 天前
(done) 速通 rustlings(19) Option
rust
@atweiwei4 天前
rust所有权机制详解
开发语言·数据结构·后端·rust·内存·所有权
shimly1234564 天前
(done) 速通 rustlings(24) 错误处理2 --- 涉及Traits
rust
shimly1234564 天前
(done) 速通 rustlings(23) 特性 Traits
rust
shimly1234564 天前
(done) 速通 rustlings(17) 哈希表
rust
shimly1234564 天前
(done) 速通 rustlings(15) 字符串
rust
shimly1234564 天前
(done) 速通 rustlings(22) 泛型
rust