1. 项目说明与整体架构
1.1 项目说明
本项目是一套基于 C++17 开发的工业级跨厂商大模型接入 SDK + 完整前后端对话系统,从底层大模型 API 适配、数据持久化、会话全生命周期管理,到上层 HTTP 服务封装、前端交互落地,形成了完整的端到端业务闭环。项目全程遵循面向对象设计原则与工业级 C++ 开发规范,兼顾了扩展性、稳定性、并发性能与可维护性,同时配套完整的测试体系与部署方案,既可以作为 SDK 嵌入业务系统,也可以直接部署为独立的对话服务使用。
1.2 项目结构图

1.3 实现思路
- 分层设计:采用四层解耦模块化架构,落地门面 + 策略模式,层间职责单一,支持新厂商无侵入扩展,严格遵循开闭原则。
- 核心实现:基于 C++17 实现跨厂商大模型接入、会话管理等核心能力,通过双层存储、细化锁控、RAII 机制优化性能与内存安全。
- 全链路流程:覆盖 SDK 初始化到会话管理全场景,形成端到端闭环,极简接口屏蔽底层差异,大幅降低大模型接入成本。
- 稳定性保障:通过全链路校验、幂等设计、异常兜底保障运行稳定,基于配置驱动与模块化设计保障项目可扩展、可维护、易测试。
2. SDK层的实现(本项目核心)
2.1 接口定义层
2.1.1 common.h
这个文件是整个项目的根基,定义了所有模块共用的数据结构以及类型规范,设计核心是:用最精简的结构,承载全链路的核心数据,同时保证扩展性。
代码总览
cpp
#pragma once
#include <string>
#include <ctime>
#include <vector>
namespace ai_chat_sdk
{
// 消息结构
struct Message
{
std::string _messageId; // 消息ID
std::string _role; // 角色,如user、assistant等
std::string _content; // 消息内容
std::time_t _timestamp; // 消息发送时间戳
// 构造函数
Message(const std::string &role = "", const std::string &content = "")
: _role(role), _content(content), _timestamp(0)
{
}
};
// 模型的公共配置信息
struct Config
{
std::string _modelName; // 模型名称
double _temperature = 0.7; // 温度参数,用于控制生成文本的随机性
int _maxTokens = 2048; // 最大生成令牌数
virtual ~Config() = default; // 添加虚函数主要是为了实现:向下转型时的安全性
};
// 通过API方式接入云端模型
struct APIConfig : public Config
{
std::string _apiKey; // API密钥
std::string _endpoint; // 【新增】模型API基础地址(可选,不填则使用Provider默认地址)
};
// LLM信息
struct ModelInfo
{
std::string _modelName; // 模型名称
std::string _modelDesc; // 模型描述
std::string _provider; // 模型提供者
std::string _endpoint; // 模型API endpoint base url
bool _isAvailable = false; // 模型是否可用
ModelInfo(const std::string &modelName = "", const std::string &modelDesc = "", const std::string &provider = "", const std::string &endpoint = "")
: _modelName(modelName), _modelDesc(modelDesc), _provider(provider), _endpoint(endpoint)
{
}
};
// 会话信息
struct Session
{
std::string _sessionId; // 会话ID
std::string _modelName; // 会话使用的模型名称
std::vector<Message> _messages; // 会话中的消息列表
std::time_t _createdAt; // 会话创建时间戳
std::time_t _updatedAt; // 会话最后更新时间戳
// 构造函数
Session(const std::string &modelName = "")
: _modelName(modelName)
{
}
};
} // end ai_chat_sdk
1. Message 结构体:对话消息的最小单元
cpp
struct Message
{
std::string _messageId; // 消息ID
std::string _role; // 角色,如user、assistant等
std::string _content; // 消息内容
std::time_t _timestamp; // 消息发送时间戳
// 构造函数
Message(const std::string &role = "", const std::string &content = "")
: _role(role), _content(content), _timestamp(0)
{
}
};
设计思路:
_role(角色)、_content(内容)字段是主流大模型API的通用字段,保证了数据结构的通用性,后续新增厂商无需修改消息结构。_messageId(消息唯一 ID)、_timestamp(时间戳)是业务层常用的扩展字段,用于消息追踪、会话历史排序、数据持久化,虽然当前可能没完全用到,但提前预留了扩展性。。- 提供了带默认参数的构造函数,既可以方便地创建空消息,也可以快速初始化
role和content,提升开发效率。
2. Config 与 APIConfig:配置的分层设计,支持多类型模型接入
cpp
// 模型的公共配置信息
struct Config
{
std::string _modelName; // 模型名称
double _temperature = 0.7; // 温度参数,用于控制生成文本的随机性
int _maxTokens = 2048; // 最大生成令牌数
virtual ~Config() = default; // 添加虚函数主要是为了实现:向下转型时的安全性
};
// 通过API方式接入云端模型
struct APIConfig : public Config
{
std::string _apiKey; // API密钥
std::string _endpoint; // 【新增】模型API基础地址(可选,不填则使用Provider默认地址)
};
设计思路:
1. 分层继承设计,支持多类型模型扩展:基类Config定义了所有模型通用的核心配置,子类APIConfig专门针对云端API模型,扩展了_apikey、_endpoint两个成员变量。
2. 虚函数的关键作用:这里将基类Config中的析构函数设计为虚函数 ,就是为了当使用基类指针删除子类对象的时候,能够正确地调用子类的析构函数,以避免内存泄漏的问题。
3. ModelInfo 结构体:对外暴露的模型信息
cpp
struct ModelInfo
{
std::string _modelName; // 模型名称
std::string _modelDesc; // 模型描述
std::string _provider; // 模型提供者
std::string _endpoint; // 模型API endpoint base url
bool _isAvailable = false; // 模型是否可用
ModelInfo(const std::string &modelName = "", const std::string &modelDesc = "", const std::string &provider = "", const std::string &endpoint = "")
: _modelName(modelName), _modelDesc(modelDesc), _provider(provider), _endpoint(endpoint)
{
}
};
设计思路:
**1. 面向调用方的信息封装:**这个结构体式专门给SDK调用方看的,让调用方知道模型的具体信息,而不需要暴露内部的Config等敏感信息。
**2. _isAvailable状态管理:**标记模型是否可用(比如初始化失败、API Key 无效时设为 false),调用方可以根据这个字段做容错处理。
4. Session 结构体:会话的完整生命周期数据
设计思路:
1. 会话的核心数据承载:
_sessionId:会话的唯一标识,用于创建、获取、删除会话,是调用方操作会话的 "句柄"。
_modelName:会话绑定的模型,保证同一个会话里的所有消息都用同一个模型处理,避免模型切换导致的上下文混乱。
_messages:完整的对话历史,是实现多轮对话的核心 ------ 每次发送新消息时,会把历史消息一起发给大模型,保证上下文连贯性。
2. 时间戳的业务价值
_createdAt、_updatedAt用于会话列表的排序、会话的生命周期管理(比如清理长期未使用的会话),也是数据持久化时的必要字段。
2.1.2 LLMProvider.h
这是整个 SDK扩展性设计的灵魂,用抽象基类定义了所有大模型厂商的统一接口规范,不同厂商只需要实现这个接口,就能无缝接入 SDK,完全符合开闭原则(对扩展开放,对修改关闭)。
代码总览
cpp
#pragma once
#include <functional>
#include <string>
#include <map>
#include <vector>
#include "common.h"
namespace ai_chat_sdk
{
// LLMProvider 类(抽象类)
class LLMProvider
{
public:
// 初始化模型
virtual bool initModel(const std::map<std::string, std::string> &modelConfig) = 0;
// 检测模型是否有效
virtual bool isAvailable() const = 0;
// 获取模型名称
virtual std::string getModelName() const = 0;
// 获取模型描述
virtual std::string getModelDesc() const = 0;
// 发送消息 - 全量返回
virtual std::string sendMessage(const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam) = 0;
// 发送消息 - 增量返回 - 流式响应
virtual std::string sendMessageStream(const std::vector<Message> &messages,
const std::map<std::string, std::string> &requestParam,
std::function<void(const std::string &, bool)> callback) = 0; // callback: 对模型返回的增量数据如何处理,第一个参数为增量数据,第二个参数为是否为最后一个增量数据
protected:
bool _isAvailable = false; // 标记模型是否有效
std::string _apiKey; // API密钥
std::string _endpoint; // 模型API endpoint base url
};
} // end ai_chat_sdk
抽象基类的核心设计
cpp
class LLMProvider
{
public:
// 初始化模型
virtual bool initModel(const std::map<std::string, std::string> &modelConfig) = 0;
// 检测模型是否有效
virtual bool isAvailable() const = 0;
// 获取模型名称
virtual std::string getModelName() const = 0;
// 获取模型描述
virtual std::string getModelDesc() const = 0;
// 发送消息 - 全量返回
virtual std::string sendMessage(const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam) = 0;
// 发送消息 - 增量返回 - 流式响应
virtual std::string sendMessageStream(const std::vector<Message> &messages,
const std::map<std::string, std::string> &requestParam,
std::function<void(const std::string &, bool)> callback) = 0;
protected:
bool _isAvailable = false; // 标记模型是否有效
std::string _apiKey; // API密钥
std::string _endpoint; // 模型API endpoint base url
};
设计思路与相关亮点
1. 为什么用抽象基类(纯虚函数),而不是普通基类?
这里其实就是策略模式的核心所在了,LLManager是一个规范定义者,而不是一个具体实现者,这里设计所有函数都是纯虚函数就是为了让子类去实现这些函数接口,否则无法实例化,所以这就是在语法层面上保证了规范的强制执行,避免遗漏实现。
2. 接口设计的统一性:屏蔽厂商差异,对内提供一致体验
所有接口的参数和返回值都是统一的,上层调用方只需要和LLMProvider抽象基类交互,不需要知道底层是豆包、DeepSeek 还是其他厂商,完全解耦。
3. 流式返回的核心设计
sendMessageStream的第三个参数std::function<void(const std::string &, bool)> callback,是实现流式对话的关键:
这里就有一个问题,为什么不是直接返回,而必须使用回调机制呢?
这里关键就在于我们需要实现的是流式返回,流式返回是边生成边返回,不是一次性返回完整结果,用回调函数可以在每收到一段增量数据的时候,就立即通知上层处理,而不需要等待完整结果的生成,大大提升了用户体验。
回调参数的设计:第一个参数表示增量数据("你好"、"我是"、"豆包"),第二个参数bool是"是否为最后一段数据",上层可以根据这个标志判断流式返回是否结束,做收尾处理。
2.1.3 ChatSDK.h
这是SDK对外暴露的唯一统一接口,用门面模式吧内部所有复杂的模块封装起来,对外只提供极简的接口,让调用方"用最少的代码,就能完成所有操作"。
代码总览
cpp
#pragma once
#include <memory>
#include <string>
#include <map>
#include <vector>
#include <functional>
#include <unordered_map>
#include "LLMManager.h"
#include "SessionManager.h"
#include "common.h"
namespace ai_chat_sdk {
// ChatSDK 类
class ChatSDK {
public:
// 初始化模型
bool initModels(const std::vector<std::shared_ptr<Config>>& configs);
// 创建会话
std::string createSession(const std::string& modelName);
// 获取指定会话
std::shared_ptr<Session> getSession(const std::string& sessionId);
// 获取所有会话列表
std::vector<std::string> getSessionLists() const;
// 删除指定会话
bool deleteSession(const std::string& sessionId);
// 获取可用的模型信息
std::vector<ModelInfo> getAvailableModels() const;
// 发送消息 - 全量返回
std::string sendMessage(const std::string& sessionId, const std::string& message);
// 发送消息 - 流式返回
std::string sendMessageStream(
const std::string& sessionId,
const std::string& message,
std::function<void(const std::string&, bool)> callback);
private:
// 注册支持的模型(改成无参数)
void registerAllProvider();
// 初始化模型
void initProviders(const std::vector<std::shared_ptr<Config>>& configs);
// 初始化 API 模型
bool initAPIModelProviders(
const std::string& modelName,
const std::shared_ptr<APIConfig>& apiConfig);
private:
bool _initialized = false;
// key: 模型名 value: 配置
std::unordered_map<std::string, std::shared_ptr<Config>> _modelConfigs;
LLMManager _llmManager;
public:
SessionManager _sessionManager;
};
} // namespace ai_chat_sdk
门面模式的核心设计:对外极简,对内复杂
cpp
class ChatSDK {
public:
// 初始化模型
bool initModels(const std::vector<std::shared_ptr<Config>>& configs);
// 创建会话
std::string createSession(const std::string& modelName);
// 获取指定会话
std::shared_ptr<Session> getSession(const std::string& sessionId);
// 获取所有会话列表
std::vector<std::string> getSessionLists() const;
// 删除指定会话
bool deleteSession(const std::string& sessionId);
// 获取可用的模型信息
std::vector<ModelInfo> getAvailableModels() const;
// 发送消息 - 全量返回
std::string sendMessage(const std::string& sessionId, const std::string& message);
// 发送消息 - 流式返回
std::string sendMessageStream(
const std::string& sessionId,
const std::string& message,
std::function<void(const std::string&, bool)> callback);
private:
// 注册支持的模型(改成无参数)
void registerAllProvider();
// 初始化模型
void initProviders(const std::vector<std::shared_ptr<Config>>& configs);
// 初始化 API 模型
bool initAPIModelProviders(
const std::string& modelName,
const std::shared_ptr<APIConfig>& apiConfig);
private:
bool _initialized = false;
// key: 模型名 value: 配置
std::unordered_map<std::string, std::shared_ptr<Config>> _modelConfigs;
LLMManager _llmManager;
public:
SessionManager _sessionManager;
};
设计思路与相关亮点
1. 门面模式的应用:为什么需要这个类?
如果没有ChatSDK,调用方需要直接与LLManager、SessionManager、多个LLMProvider子类进行交互,这样一来,使用门槛就会提高很多。
而ChatSDK作为门面,可以屏蔽底层的复杂逻辑,调用方只需要:
调用initModels初始化模型。
调用createSession创建会话。
调用sendMessage或sendMessageStream发送消息。
三步就可以完成完整的对话,大大降低了使用的门槛,这就是门面模式的核心价值。
2. public 接口的设计:完全贴合调用方的使用流程
初始化阶段:initModels(初始化所有配置的模型)、getAvailableModels(查询有哪些模型可用)。
会话管理阶段:createSession(创建会话)、getSession(获取会话)、getSessionLists(获取所有会话)、deleteSession(删除会话)。
sendMessage(全量返回)、sendMessageStream(流式返回)。
3. private 成员函数的设计:内部逻辑的封装与解耦
registerAllProvider、initProviders、initAPIModelProviders是内部实现细节,对外隐藏,保证了接口的稳定性 ------ 后续内部初始化逻辑变了,只要 public 接口不变,调用方的代码就不需要修改。
同时,这三个函数的职责拆分也很清晰:
registerAllProvider:注册所有支持的厂商Provider,这里就是DeepSeek和豆包,这个函数会把这两个Provider注册到LLManager中。
initProviders:遍历调用方传入的配置,逐个初始化对应的 Provider。
initAPIModelProviders:专门初始化 API 类型的模型,和后续可能的本地模型初始化逻辑解耦,这里说明一下,受限于硬件条件,本项目没有接入本地模型。
4. 成员变量的设计:核心模块的组合与状态管理
_initialized:初始化状态标志,防止重复初始化,保证SDK的状态安全。
_modelConfigs:用unorder_map来存储模型名到配置信息的映射,可以快速查找某个模型的配置,哈希表的查询效率是O(1)。
_llmManager:大模型管理模块,负责管理所有 Provider 的生命周期、请求分发。
_sessionManager:会话管理模块,负责会话的创建、获取、删除、历史消息维护。
注意:这里用的组合而不是继承,这是一个比较关键的思路,也符合"优先使用组合而不是继承"的设计原则,让ChatSDK可以灵活地组合不同的管理模块,避免了被继承关系所束缚。
小结
我们把这三个文件串起来,理清楚整个 SDK 的核心交互链路:
调用方传入Config配置->ChatSDK::initModels->内部调用LLMManager初始化对应的LLMProvider子类。
调用方调用ChatSDK::createSession->内部调用SessionManager创建会话,返回sessionId。
调用方调用ChatSDK::sendMessage->内部通过sessionId从SessionManager获取会话历史->调用LLMManager分发请求给对应的LLMProvider->返回结果给调用方。
2.2 核心实现层
2.2.1 ChatSDK.cpp
整体代码结构概览:
cpp
#include "../include/ChatSDK.h"
#include "../include/DeepSeekProvider.h"
#include "../include/DouBaoProvider.h"
#include "../include/util/myLog.h"
#include <memory>
namespace ai_chat_sdk {
// =============================
// 初始化模型
// =============================
bool ChatSDK::initModels(const std::vector<std::shared_ptr<Config>>& configs)
{
if (configs.empty()) {
ERR("ChatSDK::initModels: configs 为空,初始化失败");
return false;
}
registerAllProvider();
initProviders(configs);
_initialized = true;
return true;
}
// =============================
// 注册 Provider(只做一次)
// =============================
void ChatSDK::registerAllProvider()
{
// DeepSeek
if (!_llmManager.isModelAvailable("deepseek-chat")) {
auto deepseekProvider = std::make_unique<DeepSeekProvider>();
_llmManager.registerProvider("deepseek-chat", std::move(deepseekProvider));
INFO("deepseek-chat provider registered succeeded");
}
// DouBao
if (!_llmManager.isModelAvailable("doubao-seed-2-0-lite-260215")) {
auto doubaoProvider = std::make_unique<DouBaoProvider>();
_llmManager.registerProvider("doubao-seed-2-0-lite-260215", std::move(doubaoProvider));
INFO("doubao-seed-2-0-lite provider registered succeeded");
}
}
// =============================
// 初始化 Provider
// =============================
void ChatSDK::initProviders(const std::vector<std::shared_ptr<Config>>& configs)
{
for (const auto& config : configs)
{
if (!config) {
WARN("ChatSDK::initProviders: 跳过空的配置指针");
continue;
}
if (config->_modelName.empty()) {
WARN("ChatSDK::initProviders: 跳过模型名为空的配置");
continue;
}
auto apiConfig = std::dynamic_pointer_cast<APIConfig>(config);
if (!apiConfig) {
ERR("Config for model {} is not APIConfig", config->_modelName);
continue;
}
if (apiConfig->_modelName != "deepseek-chat" &&
apiConfig->_modelName != "doubao-seed-2-0-lite-260215") {
ERR("Model {} is not supported", apiConfig->_modelName);
continue;
}
initAPIModelProviders(apiConfig->_modelName, apiConfig);
}
}
// =============================
// 初始化 API 模型
// =============================
bool ChatSDK::initAPIModelProviders(const std::string& modelName,
const std::shared_ptr<APIConfig>& apiConfig)
{
if (modelName.empty()) {
ERR("ChatSDK::initAPIModelProviders: modelName is empty");
return false;
}
if (!apiConfig || apiConfig->_apiKey.empty()) {
ERR("ChatSDK::initAPIModelProviders: apiKey is empty");
return false;
}
if (_modelConfigs.find(modelName) != _modelConfigs.end()) {
INFO("Model {} already initialized", modelName);
return true;
}
std::map<std::string, std::string> modelParams;
modelParams["api_key"] = apiConfig->_apiKey;
modelParams["endpoint_id"] = "ep-m-20260218004248-cf9qd";
if (!apiConfig->_endpoint.empty()) {
modelParams["endpoint"] = apiConfig->_endpoint;
}
if (!_llmManager.initModel(modelName, modelParams)) {
ERR("Init model {} failed", modelName);
return false;
}
_modelConfigs[modelName] = apiConfig;
INFO("Model {} init succeeded", modelName);
return true;
}
// =============================
// 创建会话
// =============================
std::string ChatSDK::createSession(const std::string& modelName)
{
if (!_initialized) {
ERR("SDK not initialized");
return "";
}
if (modelName.empty()) {
ERR("modelName is empty");
return "";
}
return _sessionManager.createSession(modelName);
}
// =============================
// 获取会话
// =============================
std::shared_ptr<Session> ChatSDK::getSession(const std::string& sessionId)
{
if (!_initialized || sessionId.empty())
return nullptr;
return _sessionManager.getSession(sessionId);
}
// =============================
// 删除会话
// =============================
bool ChatSDK::deleteSession(const std::string& sessionId)
{
if (!_initialized || sessionId.empty())
return false;
return _sessionManager.deleteSession(sessionId);
}
// =============================
// 会话列表
// =============================
std::vector<std::string> ChatSDK::getSessionLists() const
{
if (!_initialized)
return {};
return _sessionManager.getSessionLists();
}
// =============================
// 可用模型
// =============================
std::vector<ModelInfo> ChatSDK::getAvailableModels() const
{
return _llmManager.getAvailableModels();
}
// =============================
// 普通发送
// =============================
std::string ChatSDK::sendMessage(const std::string& sessionId,
const std::string& message)
{
if (!_initialized || sessionId.empty() || message.empty())
return "";
auto session = _sessionManager.getSession(sessionId);
if (!session)
return "";
Message userMsg("user", message);
_sessionManager.addMessage(sessionId, userMsg);
auto history = _sessionManager.getHistroyMessages(sessionId);
auto it = _modelConfigs.find(session->_modelName);
if (it == _modelConfigs.end())
return "";
std::map<std::string, std::string> requestParam;
requestParam["temperature"] = std::to_string(it->second->_temperature);
requestParam["max_tokens"] = std::to_string(it->second->_maxTokens);
auto response = _llmManager.sendMessage(session->_modelName,
history,
requestParam);
if (response.empty())
return "";
Message assistantMsg("assistant", response);
_sessionManager.addMessage(sessionId, assistantMsg);
_sessionManager.updateSessionTimestamp(sessionId);
return response;
}
// =============================
// 流式发送
// =============================
std::string ChatSDK::sendMessageStream(
const std::string& sessionId,
const std::string& message,
std::function<void(const std::string&, bool)> callback)
{
if (!_initialized || sessionId.empty() || message.empty())
return "";
auto session = _sessionManager.getSession(sessionId);
if (!session)
return "";
Message userMsg("user", message);
_sessionManager.addMessage(sessionId, userMsg);
auto history = _sessionManager.getHistroyMessages(sessionId);
auto it = _modelConfigs.find(session->_modelName);
if (it == _modelConfigs.end())
return "";
std::map<std::string, std::string> requestParam;
requestParam["temperature"] = std::to_string(it->second->_temperature);
requestParam["max_tokens"] = std::to_string(it->second->_maxTokens);
auto response = _llmManager.sendMessageStream(
session->_modelName,
history,
requestParam,
callback);
// 只存储有效的助手回复(非空且不包含错误关键词)
if (!response.empty() &&
response.find("error") == std::string::npos &&
response != "(模型返回空回复)") {
Message assistantMsg("assistant", response);
_sessionManager.addMessage(sessionId, assistantMsg);
_sessionManager.updateSessionTimestamp(sessionId);
} else {
WARN("ChatSDK: 丢弃无效的助手回复: {}", response);
}
return response;
}
} // namespace ai_chat_sdk
这部分按照功能可以分为4个核心模块:
- **初始化模块:**initModels、registerAllProvider、initProviders、initAPIModelProviders
- **会话管理模块:**createSession、getSession、deleteSession、getSessionLists
- **对话交互模块:**sendMessage(全量)、sendMessageStream(流式)
- **辅助查询模块:**getAvailableModels
下面我们来依次分析每个模块;
初始化模块:SDK 的启动流程,工程化细节的集中体现
设计思路:
1. initModels:对外统一初始化入口
cpp
bool ChatSDK::initModels(const std::vector<std::shared_ptr<Config>>& configs)
{
// 【工程化亮点1:前置参数校验】
// 入口处先做最基础的参数检查,快速失败,避免后续无效逻辑
if (configs.empty()) {
ERR("ChatSDK::initModels: configs 为空,初始化失败"); // 【工程化亮点2:日志分级记录】
return false;
}
// 【内部调用1:注册所有支持的厂商Provider】
registerAllProvider();
// 【内部调用2:根据配置初始化对应模型】
initProviders(configs);
// 【工程化亮点3:状态标记】
_initialized = true;
return true;
}
这里涉及了前置参数的校验,是防御性编程的核心思想,将错误的参数在函数入口处就被拦截;其次就是日志分级,这里项目中统一封装了日志功能,出问题时便于快速检索问题。
2. registerAllProvider:注册所有支持的厂商(策略模式的落地)
cpp
void ChatSDK::registerAllProvider()
{
// DeepSeek 注册
// 【工程化亮点1:幂等性检查】
// 先检查模型是否已经注册过,避免重复注册,保证函数的幂等性
if (!_llmManager.isModelAvailable("deepseek-chat")) {
// 【工程化亮点2:智能指针 + 移动语义】
// 1. 用 std::make_unique 创建独占智能指针,自动管理内存,避免内存泄漏
auto deepseekProvider = std::make_unique<DeepSeekProvider>();
// 2. 用 std::move 把智能指针的所有权转移给 LLMManager,避免拷贝,提升性能
_llmManager.registerProvider("deepseek-chat", std::move(deepseekProvider));
INFO("deepseek-chat provider registered succeeded");
}
// DouBao 注册(逻辑完全一致)
if (!_llmManager.isModelAvailable("doubao-seed-2-0-lite-260215")) {
auto doubaoProvider = std::make_unique<DouBaoProvider>();
_llmManager.registerProvider("doubao-seed-2-0-lite-260215", std::move(doubaoProvider));
INFO("doubao-seed-2-0-lite provider registered succeeded");
}
}
- 函数幂等性检查
这里做幂等性检查,就是为了防止调用方重复调用initModels,导致重复注册 Provider,浪费内存,甚至出现状态混乱。
- unique_ptr+move的常见组合
首先是unique_ptr,它是独占智能指针,是RAII(资源获取即初始化)的典型应用,可以很好地避免内存泄漏的问题,同一时间智能有一个unique_ptr指向同一个对象,避免了多个指针共享所有权带来的混乱。
其次就是move函数的使用,它的作用在于告诉编译器,某个资源是可以被移动的,因为unique_ptr是禁止拷贝的,所以这里只能进行所有权的转移,移动语义只转移对象的 "控制权"(比如指针的指向),不存在任何拷贝,效率是非常理想的。
3. initProviders:遍历配置,逐个初始化模型
cpp
void ChatSDK::initProviders(const std::vector<std::shared_ptr<Config>>& configs)
{
// 遍历调用方传入的所有配置
for (const auto& config : configs)
{
// 【工程化亮点1:多层级的参数校验】
// 每一步都做检查,跳过无效配置,保证程序的健壮性
if (!config) {
WARN("ChatSDK::initProviders: 跳过空的配置指针");
continue;
}
if (config->_modelName.empty()) {
WARN("ChatSDK::initProviders: 跳过模型名为空的配置");
continue;
}
// 【核心:安全的向下转型】
// 用 std::dynamic_pointer_cast 把基类 Config 安全地向下转型为子类 APIConfig
auto apiConfig = std::dynamic_pointer_cast<APIConfig>(config);
if (!apiConfig) {
ERR("Config for model {} is not APIConfig", config->_modelName);
continue;
}
// 【工程化亮点2:白名单校验】
// 只支持注册过的模型,避免非法模型名,保证安全性
if (apiConfig->_modelName != "deepseek-chat" &&
apiConfig->_modelName != "doubao-seed-2-0-lite-260215") {
ERR("Model {} is not supported", apiConfig->_modelName);
continue;
}
// 【内部调用:初始化具体的API模型】
initAPIModelProviders(apiConfig->_modelName, apiConfig);
}
}
这里值得注意的一个点就是std::dynamic_pointer_cast的使用,这是C++智能指针的安全向下转型工具,它会在运行时检查类型是否匹配,如果匹配则返回转型后的智能指针;否则返回nullptr,整体是很安全的。
4. initAPIModelProviders:初始化具体的 API 模型
cpp
bool ChatSDK::initAPIModelProviders(const std::string& modelName,
const std::shared_ptr<APIConfig>& apiConfig)
{
// 【工程化亮点:继续前置参数校验】
if (modelName.empty()) {
ERR("ChatSDK::initAPIModelProviders: modelName is empty");
return false;
}
if (!apiConfig || apiConfig->_apiKey.empty()) {
ERR("ChatSDK::initAPIModelProviders: apiKey is empty");
return false;
}
// 【幂等性检查:避免重复初始化同一个模型】
if (_modelConfigs.find(modelName) != _modelConfigs.end()) {
INFO("Model {} already initialized", modelName);
return true;
}
// 【核心:把结构化的APIConfig,转成Provider需要的map格式】
// 这里是为了适配LLMProvider::initModel的接口(用map传递配置,更灵活)
std::map<std::string, std::string> modelParams;
modelParams["api_key"] = apiConfig->_apiKey;
modelParams["endpoint_id"] = "ep-m-20260218004248-cf9qd"; // 豆包的endpoint_id
// 【扩展性设计:支持自定义endpoint】
// 如果调用方传了自定义的endpoint,就用调用方的;否则用Provider的默认值
if (!apiConfig->_endpoint.empty()) {
modelParams["endpoint"] = apiConfig->_endpoint;
}
// 【内部调用:让LLMManager完成具体的模型初始化】
if (!_llmManager.initModel(modelName, modelParams)) {
ERR("Init model {} failed", modelName);
return false;
}
// 【核心:缓存模型配置,方便后续快速查找】
// 用std::unordered_map存储,查找时间复杂度O(1)
_modelConfigs[modelName] = apiConfig;
INFO("Model {} init succeeded", modelName);
return true;
}
- 为什么要把APIConfig转成std::map<std::string, std::string>?
这里是为了灵活性与扩展性,如果后续有新的厂商,那么直接添加键值对即可,不需要修改LLMProvider的接口,符合开闭原则。
会话管理模块:门面模式的典型应用,完全复用 SessionManager 的能力
核心代码示例:
cpp
// 创建会话
std::string ChatSDK::createSession(const std::string& modelName)
{
// 【工程化亮点:所有接口都先检查初始化状态】
if (!_initialized) {
ERR("SDK not initialized");
return "";
}
if (modelName.empty()) {
ERR("modelName is empty");
return "";
}
// 【内部调用:直接复用SessionManager的能力】
return _sessionManager.createSession(modelName);
}
// 发送消息(全量)里的会话操作
std::string ChatSDK::sendMessage(const std::string& sessionId, const std::string& message)
{
// 前置校验
if (!_initialized || sessionId.empty() || message.empty())
return "";
// 【内部调用1:获取会话】
auto session = _sessionManager.getSession(sessionId);
if (!session)
return "";
// 【内部调用2:添加用户消息到会话历史】
Message userMsg("user", message);
_sessionManager.addMessage(sessionId, userMsg);
// 【内部调用3:获取完整的历史消息】
auto history = _sessionManager.getHistroyMessages(sessionId);
// ... 中间是LLM调用逻辑 ...
// 【内部调用4:添加助手回复到会话历史】
Message assistantMsg("assistant", response);
_sessionManager.addMessage(sessionId, assistantMsg);
// 【内部调用5:更新会话的最后活跃时间】
_sessionManager.updateSessionTimestamp(sessionId);
return response;
}
这部分代码非常简洁诶,完全体现了门面模式与组合的核心价值:ChatSDK不会去做会话管理的逻辑,只是**做前置状态检查,**然后直接调用_sessionManager的方法,对外提供统一的入口。
设计思路总结:
ChatSDK在会话管理上,完全是一个"转发者+校验者"的角色:
先做统一的前置校验,然后调用_sessionManager封装好的方法,完成具体的对话操作,对外只暴露自己的极简接口,完全屏蔽了SessionManager的存在。
对话交互模块:全链路的核心,组合两个管理模块的能力
这是SDK对外提供的核心能力,代码逻辑是比较清晰的,完全体现了组合的思想:ChatSDK组合了SessionManager和LLMManager的能力,完成端到端的对话。
1. sendMessage:全量返回(非流式)
cpp
std::string ChatSDK::sendMessage(const std::string& sessionId, const std::string& message)
{
// 【步骤1:前置统一校验】
if (!_initialized || sessionId.empty() || message.empty())
return "";
// 【步骤2:从SessionManager获取会话和历史消息】
auto session = _sessionManager.getSession(sessionId);
if (!session)
return "";
// 【步骤3:把用户的新消息加入会话历史】
Message userMsg("user", message);
_sessionManager.addMessage(sessionId, userMsg);
// 【步骤4:获取完整的对话历史,用于多轮对话】
auto history = _sessionManager.getHistroyMessages(sessionId);
// 【步骤5:从缓存里获取模型的配置(temperature、max_tokens)】
auto it = _modelConfigs.find(session->_modelName);
if (it == _modelConfigs.end())
return "";
// 【步骤6:把结构化配置转成map,适配LLMProvider的接口】
std::map<std::string, std::string> requestParam;
requestParam["temperature"] = std::to_string(it->second->_temperature);
requestParam["max_tokens"] = std::to_string(it->second->_maxTokens);
// 【步骤7:内部调用LLMManager,分发请求给对应厂商】
auto response = _llmManager.sendMessage(session->_modelName, history, requestParam);
// 【步骤8:把大模型的回复加入会话历史,更新上下文】
if (response.empty())
return "";
Message assistantMsg("assistant", response);
_sessionManager.addMessage(sessionId, assistantMsg);
_sessionManager.updateSessionTimestamp(sessionId);
// 【步骤9:返回结果给调用方】
return response;
}
2. sendMessageStream:流式返回(核心亮点在于对大模型特性的理解)
cpp
std::string ChatSDK::sendMessageStream(
const std::string& sessionId,
const std::string& message,
std::function<void(const std::string&, bool)> callback)
{
// 【步骤1-7:和全量返回完全一致,前置校验、获取历史、准备参数】
if (!_initialized || sessionId.empty() || message.empty())
return "";
auto session = _sessionManager.getSession(sessionId);
if (!session)
return "";
Message userMsg("user", message);
_sessionManager.addMessage(sessionId, userMsg);
auto history = _sessionManager.getHistroyMessages(sessionId);
auto it = _modelConfigs.find(session->_modelName);
if (it == _modelConfigs.end())
return "";
std::map<std::string, std::string> requestParam;
requestParam["temperature"] = std::to_string(it->second->_temperature);
requestParam["max_tokens"] = std::to_string(it->second->_maxTokens);
// 【核心亮点:流式调用,传入回调函数】
// 把调用方传入的callback,直接透传给LLMManager,再透传给具体的Provider
// Provider每收到一段增量数据,就调用callback通知上层,实现打字机效果
auto response = _llmManager.sendMessageStream(
session->_modelName,
history,
requestParam,
callback);
// 【工程化亮点:对最终结果做有效性校验,只存储有效回复】
// 避免把错误信息、空回复加入会话历史,污染上下文
if (!response.empty() &&
response.find("error") == std::string::npos &&
response != "(模型返回空回复)") {
Message assistantMsg("assistant", response);
_sessionManager.addMessage(sessionId, assistantMsg);
_sessionManager.updateSessionTimestamp(sessionId);
} else {
WARN("ChatSDK: 丢弃无效的助手回复: {}", response);
}
return response;
}
- 为什么流式返回要用std::function回调?
这个问题我们在前面提到过,其实本质就是流式返回的时候是"边生成边返回",这也是现在所有主流大模型的统一回复方式,用回调函数,就是可以在每收到一段增量数据的时候,就立即通知上层处理,而不需要等到完整结果生成,大大提升了用户的体验。
2.2.2 SessionManager.h/.cpp
代码总览
- SessionManager.h
cpp
#pragma once
#include <atomic>
#include <unordered_map>
#include <mutex>
#include <memory>
#include "DataManager.h"
#include "common.h"
namespace ai_chat_sdk {
class SessionManager{
public:
SessionManager(const std::string& dbName = "chatDB.db");
// 创建会话,提供模型名称
std::string createSession(const std::string& modelName);
// 通过会话ID获取会话信息
std::shared_ptr<Session> getSession(const std::string& sessionId);
// 往某个会话中添加消息
bool addMessage(const std::string& sessionId, const Message& message);
// 获取某个会话的所有历史消息
std::vector<Message> getHistroyMessages(const std::string& sessionId)const;
// 更新会话时间戳
void updateSessionTimestamp(const std::string& sessionId);
// 获取会话所有会话列表
std::vector<std::string> getSessionLists()const;
// 删除某个会话
bool deleteSession(const std::string& sessionId);
// 清空所有会话
void clearAllSessions();
// 获取会话总数
size_t getSessionCount()const;
private:
std::string generateSessionId();
std::string generateMessageId(size_t messageCounter);
private:
// 管理所有会话信息,key: 会话ID,value: 会话信息
std::unordered_map<std::string, std::shared_ptr<Session>> _sessions;
mutable std::mutex _mutex;
std::atomic<int64_t> _sessionCounter = {0}; // 记录所有会话总数
DataManager _dataManager;
};
} // end ai_chat_sdk
- SessionManager.cpp
cpp
#include "../include/SessionManager.h"
#include <iomanip>
#include <sstream>
#include <vector>
#include "../include/util/myLog.h"
namespace ai_chat_sdk {
SessionManager::SessionManager(const std::string& dbName)
: _dataManager(dbName)
{
auto sessions = _dataManager.getAllSessions();
for(auto& session : sessions){
_sessions[session->_sessionId] = session;
}
}
std::string SessionManager::generateSessionId(){
_sessionCounter.fetch_add(1);
std::time_t time = std::time(nullptr);
std::ostringstream os;
os<<"session_"<<time<<"_"<<std::setw(8)<<std::setfill('0')<<_sessionCounter;
return os.str();
}
std::string SessionManager::generateMessageId(size_t messageCounter){
messageCounter++;
std::time_t time = std::time(nullptr);
std::ostringstream os;
os<<"msg_"<<time<<"_"<<std::setw(8)<<std::setfill('0')<<messageCounter;
return os.str();
}
std::string SessionManager::createSession(const std::string& modelName){
std::string sessionId;
{
std::lock_guard<std::mutex> lock(_mutex);
sessionId = generateSessionId();
auto session = std::make_shared<Session>(modelName);
session->_sessionId = sessionId;
session->_createdAt = std::time(nullptr);
session->_updatedAt = session->_createdAt;
_sessions[sessionId] = session;
}
// 数据库操作不加锁
auto session = std::make_shared<Session>(modelName);
session->_sessionId = sessionId;
session->_createdAt = std::time(nullptr);
session->_updatedAt = session->_createdAt;
_dataManager.insertSession(*session);
return sessionId;
}
std::shared_ptr<Session> SessionManager::getSession(const std::string& sessionId){
// 先在内存中查找
{
std::unique_lock<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it != _sessions.end()){
auto session = it->second;
lock.unlock();
session->_messages = _dataManager.getSessionMessages(sessionId);
return session;
}
}
auto session = _dataManager.getSession(sessionId);
if(session){
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it == _sessions.end()){
_sessions[sessionId] = session;
}
}
session->_messages = _dataManager.getSessionMessages(sessionId);
return session;
}
WARN("sessionId = {} not found", sessionId);
return nullptr;
}
bool SessionManager::addMessage(const std::string& sessionId, const Message& message){
std::shared_ptr<Session> session;
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it == _sessions.end()){
return false;
}
session = it->second;
}
Message msg(message._role, message._content);
msg._messageId = generateMessageId(session->_messages.size());
msg._timestamp = std::time(nullptr);
INFO("message Info: content {} timestamap {}", msg._content, msg._timestamp);
{
std::lock_guard<std::mutex> lock(_mutex);
session->_messages.push_back(msg);
session->_updatedAt = std::time(nullptr);
}
INFO("add message success, sessionId = {}, message.content = {}", sessionId, msg._content);
_dataManager.insertMessage(sessionId, msg);
return true;
}
// ==================== 修改点:过滤错误消息 ====================
std::vector<Message> SessionManager::getHistroyMessages(const std::string& sessionId) const {
std::vector<Message> rawMessages;
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it != _sessions.end()){
rawMessages = it->second->_messages;
} else {
rawMessages = _dataManager.getSessionMessages(sessionId);
}
}
std::vector<Message> filtered;
for (const auto& msg : rawMessages) {
// 如果是用户消息,无条件保留
if (msg._role == "user") {
filtered.push_back(msg);
continue;
}
// 如果是助手消息,进行过滤
if (msg._role == "assistant") {
// 检查内容是否为空或错误
if (msg._content.empty() ||
msg._content == "(模型返回空回复)" ||
msg._content.find("Failed to parse") != std::string::npos ||
msg._content.find("control character") != std::string::npos ||
msg._content.find("error") != std::string::npos ||
msg._content.find("fail") != std::string::npos ||
msg._content.find("invalid") != std::string::npos ||
msg._content.find("denied") != std::string::npos ||
msg._content.find("unavailable") != std::string::npos ||
msg._content.find("Authentication Fails") != std::string::npos ||
msg._content.find("rate limit") != std::string::npos) {
WARN("过滤掉错误消息: {}", msg._content);
continue;
}
// 如果内容以 '{' 开头(可能是 JSON 错误),也过滤
if (!msg._content.empty() && msg._content[0] == '{') {
WARN("过滤掉疑似 JSON 错误的消息: {}", msg._content);
continue;
}
}
filtered.push_back(msg);
}
return filtered;
}
// ==================== 结束修改 ====================
void SessionManager::updateSessionTimestamp(const std::string& sessionId){
std::time_t timestamp;
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it != _sessions.end()){
it->second->_updatedAt = std::time(nullptr);
timestamp = it->second->_updatedAt;
} else {
return;
}
}
_dataManager.updateSessionTimestamp(sessionId, timestamp);
}
std::vector<std::string> SessionManager::getSessionLists() const {
auto sessions = _dataManager.getAllSessions();
std::lock_guard<std::mutex> lock(_mutex);
std::vector<std::pair<std::time_t, std::shared_ptr<Session>>> temp;
temp.reserve(_sessions.size() + sessions.size());
for(const auto& pair : _sessions){
temp.emplace_back(pair.second->_updatedAt, pair.second);
}
for(const auto& session : sessions){
if(_sessions.find(session->_sessionId) == _sessions.end()){
temp.emplace_back(session->_updatedAt, session);
}
}
std::sort(temp.begin(), temp.end(), [](const auto& a, const auto& b){
return a.first > b.first;
});
std::vector<std::string> sessionIds;
for(const auto& pair : temp){
sessionIds.push_back(pair.second->_sessionId);
}
return sessionIds;
}
bool SessionManager::deleteSession(const std::string& sessionId){
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it == _sessions.end()){
return false;
}
_sessions.erase(it);
}
_dataManager.deleteSession(sessionId);
return true;
}
void SessionManager::clearAllSessions(){
{
std::lock_guard<std::mutex> lock(_mutex);
_sessions.clear();
}
_dataManager.clearAllSessions();
}
size_t SessionManager::getSessionCount() const {
std::lock_guard<std::mutex> lock(_mutex);
return _sessions.size();
}
} // end ai_chat_sdk
头文件(SessionManager.h):核心设计亮点先看
头文件中定义的成员变量,直接体现了这个模块的核心实现思路;
cpp
class SessionManager{
private:
// 【核心设计1:内存+数据库双层存储】
// 内存层:std::unordered_map,O(1)查找,高性能
std::unordered_map<std::string, std::shared_ptr<Session>> _sessions;
// 【核心设计2:线程安全三件套】
mutable std::mutex _mutex; // 互斥锁,保护内存操作
std::atomic<int64_t> _sessionCounter; // 原子计数器,生成唯一ID
// 【核心设计3:数据持久化】
DataManager _dataManager; // 数据库管理模块,负责落地存储
};
1. 为什么使用"内存+数据库"双层存储?
- 内存层(_sessions):使用unordered_map,增删查改的效率都是非常高的,用于高频热数据访问。
- 数据库层(_dataManager):用于持久化存储,保证重启后数据不会丢失,用于数据备份。
2. mutable的作用
-
首先需要明确一点,就是const成员函数里是不允许修改成员变量的;但是在这个场景下,后面的成员函数都是需要加锁的(修改_mutex的状态),而它们又是const成员函数,所以这里就产生了冲突。
-
那么mutable的作用就是**允许const成员函数去修改被mutable修饰的成员变量,**这样一来就可以解决这个矛盾了。
3. std::atomic的作用
- std::atomic是C++11中的原子类型,对它的操作是原子的,不可竞争的,不需要加锁,atomic的效率比锁要高不少,_sessionCounter只是一个简单的计数器,完全没必要去加锁保护。
构造函数:程序启动时的数据恢复
cpp
SessionManager::SessionManager(const std::string& dbName)
: _dataManager(dbName) // 【初始化列表:先初始化数据库管理模块】
{
// 【核心:启动时从数据库恢复所有会话到内存】
auto sessions = _dataManager.getAllSessions();
for(auto& session : sessions){
_sessions[session->_sessionId] = session;
}
}
设计思路:
- 程序启动时,先把数据库中所有会话加载到内存中,这样后续访问历史会话就非常高效,因为都是内存操作,不涉及访问数据库的操作。
- 这是"冷启动预热"的典型代表。
ID 生成器:唯一 ID 的生成逻辑
cpp
std::string SessionManager::generateSessionId(){
// 【原子操作:无锁自增,线程安全】
_sessionCounter.fetch_add(1);
// 【ID格式:session_时间戳_8位计数器】
std::time_t time = std::time(nullptr);
std::ostringstream os;
os<<"session_"<<time<<"_"<<std::setw(8)<<std::setfill('0')<<_sessionCounter;
return os.str();
}
- fetch_add(1)是atomic的原子自增操作,线程安全、无锁、性能拉满,这里就需要区别于++_sessionCounter,我们要清楚一个点,++操作并不是原子的,在多线程的情况下会存在数据不一致问题,而atomic这里提供的自增方法在底层规避了这个问题。
核心功能函数:线程安全 + 锁粒度控制的典范
createSession:创建会话(锁粒度控制的典范)
cpp
std::string SessionManager::createSession(const std::string& modelName){
std::string sessionId;
// 【锁的范围1:只保护内存操作,锁粒度极小】
{
std::lock_guard<std::mutex> lock(_mutex); // RAII锁,离开作用域自动释放
sessionId = generateSessionId();
auto session = std::make_shared<Session>(modelName);
session->_sessionId = sessionId;
session->_createdAt = std::time(nullptr);
session->_updatedAt = session->_createdAt;
_sessions[sessionId] = session;
} // 锁在这里自动释放了!
// 【关键:数据库操作不加锁!】
// 数据库操作是慢IO,不应该在锁里做,否则会严重影响并发性能
auto session = std::make_shared<Session>(modelName);
session->_sessionId = sessionId;
session->_createdAt = std::time(nullptr);
session->_updatedAt = session->_createdAt;
_dataManager.insertSession(*session);
return sessionId;
}
1. 为什么用std::lock_guard?
这是一个经典的RAII(资源获取即初始化)应用,std::lock_guard在构造的时候自动加锁,在离开作用域时自动解锁,这样就可以避免忘记解锁导致死锁的问题。
2. 为什么锁的范围这么小?为什么数据库操作不加锁?
这里就是锁粒度优化的核心思想:
- **锁的范围越小,并发性能越高:**只把必须保护的内存操作放在锁里,其他操作都放在锁外面。
- **数据库操作是慢IO:**一次数据库写入需要几毫秒到几十毫秒,这样的操作如果放在锁里面,其他线程都会被阻塞,并发性能就会被严重影响。
addMessage:添加消息(同样的锁粒度控制)
cpp
bool SessionManager::addMessage(const std::string& sessionId, const Message& message){
std::shared_ptr<Session> session;
// 【锁1:只保护查找会话的内存操作】
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it == _sessions.end()){
return false;
}
session = it->second; // 拿到shared_ptr,引用计数+1,延长生命周期
} // 锁释放
// 【无锁区域:生成消息ID、打日志,都不需要加锁】
Message msg(message._role, message._content);
msg._messageId = generateMessageId(session->_messages.size());
msg._timestamp = std::time(nullptr);
INFO("message Info: content {} timestamap {}", msg._content, msg._timestamp);
// 【锁2:只保护修改内存消息列表的操作】
{
std::lock_guard<std::mutex> lock(_mutex);
session->_messages.push_back(msg);
session->_updatedAt = std::time(nullptr);
} // 锁释放
// 【无锁区域:数据库慢IO操作,不加锁】
INFO("add message success, sessionId = {}, message.content = {}", sessionId, msg._content);
_dataManager.insertMessage(sessionId, msg);
return true;
}
设计思路总结:
这个函数将锁粒度的控制做得比较到位;
- 锁1:只保护查找会话的操作,拿到shared_ptr就释放锁;
- 无锁:生成消息 ID、打日志,这些操作不需要保护;
- 锁2:只保护修改内存消息列表的操作。
- 无锁:数据库慢IO,完全不加锁。
getHistroyMessages:获取历史消息(带错误过滤的业务逻辑)
cpp
std::vector<Message> SessionManager::getHistroyMessages(const std::string& sessionId) const {
std::vector<Message> rawMessages;
// 【锁:保护内存查找操作】
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _sessions.find(sessionId);
if(it != _sessions.end()){
rawMessages = it->second->_messages;
} else {
rawMessages = _dataManager.getSessionMessages(sessionId);
}
}
// 【核心业务逻辑:过滤错误消息,保证多轮对话质量】
std::vector<Message> filtered;
for (const auto& msg : rawMessages) {
// 用户消息:无条件保留
if (msg._role == "user") {
filtered.push_back(msg);
continue;
}
// 助手消息:严格过滤
if (msg._role == "assistant") {
// 过滤空回复、错误关键词、JSON错误
if (msg._content.empty() ||
msg._content == "(模型返回空回复)" ||
msg._content.find("error") != std::string::npos ||
msg._content[0] == '{') {
WARN("过滤掉错误消息: {}", msg._content);
continue;
}
}
filtered.push_back(msg);
}
return filtered;
}
这个函数是业务质量保障的核心:
- 大模型可能返回错误信息、空回复、JSON格式的错误响应。
- 如果这些错误信息被加入对话历史,会污染多轮对话的上下文,导致后续对话质量无法保证。
- 这里消息经过过滤,只会留下有效的、正常的消息,保证多轮对话的上下文质量。
2.2.3 DataManager.h/.cpp
代码总览:
- DataManager.h
cpp
#pragma once
#include <memory>
#include <sqlite3.h>
#include <string>
#include <mutex>
#include "common.h"
namespace ai_chat_sdk {
class DataManager{
public:
DataManager(const std::string& dbName);
~DataManager();
// Session相关操作
// 插入新会话
bool insertSession(const Session& session);
// 获取指定会话信息
std::shared_ptr<Session> getSession(const std::string& sessionId)const;
// 更新指定会话的时间戳
bool updateSessionTimestamp(const std::string& sessionId, std::time_t timestamp);
// 删除指定会话--注意:删除会话时,也需要删除该会话中管理的所有的消息
bool deleteSession(const std::string& sessionId);
// 获取所有会话id
std::vector<std::string> getAllSessionIds()const;
// 获取所有会话信息
std::vector<std::shared_ptr<Session>> getAllSessions()const;
// 删除所有会话
bool clearAllSessions();
// 获取会话总数
size_t getSessionCount()const;
////////////////////////////////////////////////////////////////////
// Message相关操作
// 插入新消息--注意:插入消息时,需要更新会话的时间戳
bool insertMessage(const std::string& sessionId, const Message& message);
// 获取指定会话的历史消息
std::vector<Message> getSessionMessages(const std::string& sessionId)const;
// 删除指定会话的所有消息
bool deleteSessionMessages(const std::string& sessionId);
private:
// 初始化数据库 -- 创建数据库表
bool initDataBase();
// 执行SQL语句的工具函数
bool executeSQL(const std::string& sql);
private:
// 【核心修改】交换了 _dbName 和 _db 的声明顺序
// 让声明顺序和构造函数初始化列表的顺序保持一致
std::string _dbName;
sqlite3* _db = nullptr;
mutable std::mutex _mutex;
};
}// end ai_chat_sdk
- DataManager.cpp
cpp
#include "../include/DataManager.h"
#include "../include/util/myLog.h"
#include <memory>
#include <mutex>
#include <vector>
namespace ai_chat_sdk {
DataManager::DataManager(const std::string& dbName)
: _dbName(dbName)
, _db(nullptr)
{
// 创建并打开数据库
int rc = sqlite3_open(dbName.c_str(), &_db);
if(rc != SQLITE_OK){
ERR("打开数据库失败:{}", sqlite3_errmsg(_db));
}
INFO("打开数据库成功:{}", dbName);
// 初始化数据库表 - 创建会话表和消息表
if(!initDataBase()){
sqlite3_close(_db);
_db = nullptr;
ERR("初始化数据库表失败");
}
}
DataManager::~DataManager(){
if(_db){
sqlite3_close(_db);
}
}
// 初始化数据库
bool DataManager::initDataBase(){
// 创建会话表
std::string createSessionTable = R"(
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
model_name TEXT NOT NULL,
create_time INTEGER NOT NULL,
update_time INTEGER NOT NULL
);
)";
// 执行创建Sessions表的SQL语句
if(!executeSQL(createSessionTable)){
return false;
}
// 创建消息表
std::string createMessageTable = R"(
CREATE TABLE IF NOT EXISTS messages (
message_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp INTEGER NOT NULL,
FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE
);
)";
// 执行创建Messages表的SQL语句
if(!executeSQL(createMessageTable)){
return false;
}
return true;
}
bool DataManager::executeSQL(const std::string& sql){
if(!_db){
ERR("数据库未初始化");
return false;
}
char* errMsg = nullptr;
int rc = sqlite3_exec(_db, sql.c_str(), nullptr, nullptr, &errMsg);
if(rc != SQLITE_OK){
ERR("执行SQL语句失败: {}", errMsg);
sqlite3_free(errMsg);
return false;
}
return true;
}
// 插入会话
bool DataManager::insertSession(const Session& session){
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string insertSQL = R"(
INSERT INTO sessions (session_id, model_name, create_time, update_time)
VALUES (?, ?, ?, ?);
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, insertSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("insertSession - 准备语句失败:{}", sqlite3_errmsg(_db));
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, session._sessionId.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 2, session._modelName.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_int64(stmt, 3, static_cast<int64_t>(session._createdAt));
sqlite3_bind_int64(stmt, 4, static_cast<int64_t>(session._updatedAt));
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
ERR("insertSession - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return false;
}
// 释放语句
sqlite3_finalize(stmt);
INFO("insertSession - 插入会话成功:{}", session._sessionId);
return true;
}
// 获取指定sessionId的会话信息
std::shared_ptr<Session> DataManager::getSession(const std::string& sessionId)const{
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string selectSQL = R"(
SELECT model_name, create_time, update_time FROM sessions WHERE session_id = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, selectSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("getSession - 准备语句失败:{}", sqlite3_errmsg(_db));
return nullptr;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
ERR("getSession - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return nullptr;
}
// 从结果集中提取数据
std::string modelName = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
int64_t createTime = sqlite3_column_int64(stmt, 1);
int64_t updateTime = sqlite3_column_int64(stmt, 2);
// 创建会话对象
auto session = std::make_shared<Session>(modelName);
session->_sessionId = sessionId;
session->_createdAt = static_cast<std::time_t>(createTime);
session->_updatedAt = static_cast<std::time_t>(updateTime);
// 释放语句
sqlite3_finalize(stmt);
INFO("getSession - 获取会话成功:{}", sessionId);
// 获取该会话的所有消息
session->_messages = getSessionMessages(sessionId);
return session;
}
// 更新指定会话的时间戳
bool DataManager::updateSessionTimestamp(const std::string& sessionId, std::time_t timestamp){
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string updateSQL = R"(
UPDATE sessions SET update_time = ? WHERE session_id = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, updateSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("updateSessionTimestamp - 准备语句失败:{}", sqlite3_errmsg(_db));
return false;
}
// 绑定参数
sqlite3_bind_int64(stmt, 1, static_cast<int64_t>(timestamp));
sqlite3_bind_text(stmt, 2, sessionId.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
ERR("updateSessionTimestamp - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return false;
}
// 释放语句
sqlite3_finalize(stmt);
INFO("updateSessionTimestamp - 更新会话时间戳成功:{}", sessionId);
return true;
}
// 删除指定会话
bool DataManager::deleteSession(const std::string& sessionId){
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string deleteSQL = R"(
DELETE FROM sessions WHERE session_id = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, deleteSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("deleteSession - 准备语句失败:{}", sqlite3_errmsg(_db));
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
ERR("deleteSession - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return false;
}
// 释放语句
sqlite3_finalize(stmt);
INFO("deleteSession - 删除会话成功:{}", sessionId);
return true;
}
// 获取所有会话id
std::vector<std::string> DataManager::getAllSessionIds()const{
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string selectSQL = R"(
SELECT session_id FROM sessions ORDER BY update_time DESC;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, selectSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("getAllSessionIds - 准备语句失败:{}", sqlite3_errmsg(_db));
return {};
}
std::vector<std::string> sessionIds;
while(sqlite3_step(stmt) == SQLITE_ROW){
std::string sessionId = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
sessionIds.push_back(sessionId);
}
// 释放语句
sqlite3_finalize(stmt);
INFO("getAllSessionIds - 获取所有会话id成功, 会话总数:{}", sessionIds.size());
return sessionIds;
}
// 获取所有session信息,并按照更新时间降序排列
std::vector<std::shared_ptr<Session>> DataManager::getAllSessions()const{
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string selectSQL = R"(
SELECT session_id, model_name, create_time, update_time FROM sessions ORDER BY update_time DESC;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, selectSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("getAllSessionIds - 准备语句失败:{}", sqlite3_errmsg(_db));
return {};
}
std::vector<std::shared_ptr<Session>> sessions;
while(sqlite3_step(stmt) == SQLITE_ROW){
std::string sessionId = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
std::string modelName = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
int64_t createTime = sqlite3_column_int64(stmt, 2);
int64_t updateTime = sqlite3_column_int64(stmt, 3);
auto session = std::make_shared<Session>(modelName);
session->_sessionId = sessionId;
session->_createdAt = static_cast<std::time_t>(createTime);
session->_updatedAt = static_cast<std::time_t>(updateTime);
sessions.push_back(session);
// 历史消息暂时不获取,需要时再通过会话id来进行获取
}
// 释放语句
sqlite3_finalize(stmt);
INFO("getAllSessions - 获取所有会话信息成功, 会话总数:{}", sessions.size());
return sessions;
}
// 获取会话总数
size_t DataManager::getSessionCount()const
{
std::lock_guard<std::mutex> lock(_mutex);
// 准备SQL语句
std::string selectSQL = R"(
SELECT COUNT(*) FROM sessions;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, selectSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("getSessionCount - 准备语句失败:{}", sqlite3_errmsg(_db));
return 0;
}
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_ROW){
ERR("getSessionCount - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return 0;
}
// 获取会话总数
size_t count = sqlite3_column_int64(stmt, 0);
// 释放语句
sqlite3_finalize(stmt);
INFO("getSessionCount - 获取会话总数成功:{}", count);
return count;
}
// 删除所有会话
bool DataManager::clearAllSessions(){
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string deleteSQL = R"(
DELETE FROM sessions;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, deleteSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("clearAllSessions - 准备语句失败:{}", sqlite3_errmsg(_db));
return false;
}
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
ERR("clearAllSessions - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return false;
}
// 释放语句
sqlite3_finalize(stmt);
INFO("clearAllSessions - 删除所有会话成功");
return true;
}
/////////////////////////////////////////////////////////////Messages///////////////////////////////////////
// 插入新消息--注意:插入消息时,需要更新会话的时间戳
bool DataManager::insertMessage(const std::string& sessionId, const Message& message){
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string insertSQL = R"(
INSERT INTO messages (message_id, session_id, role, content, timestamp)
VALUES (?, ?, ?, ?, ?);
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, insertSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("insertMessage - 准备语句失败:{}", sqlite3_errmsg(_db));
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, message._messageId.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 2, sessionId.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 3, message._role.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 4, message._content.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_int64(stmt, 5, static_cast<int64_t>(message._timestamp));
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
ERR("insertMessage - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return false;
}
// 同时更新session的update_time
std::string updateSQL = R"(
UPDATE sessions SET update_time = ? WHERE session_id = ?;
)";
// 准备SQL语句
sqlite3_stmt* updateStmt;
rc = sqlite3_prepare_v2(_db, updateSQL.c_str(), -1, &updateStmt, nullptr);
if(rc != SQLITE_OK){
ERR("insertMessage - 准备语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return false;
}
// 绑定参数
sqlite3_bind_int64(updateStmt, 1, static_cast<int64_t>(message._timestamp));
sqlite3_bind_text(updateStmt, 2, sessionId.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(updateStmt);
if(rc != SQLITE_DONE){
ERR("insertMessage - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(updateStmt);
return false;
}
// 释放语句
sqlite3_finalize(stmt);
sqlite3_finalize(updateStmt);
INFO("insertMessage - 插入消息成功:{}", message._messageId);
return true;
}
// 获取会话中的所有消息
std::vector<Message> DataManager::getSessionMessages(const std::string& sessionId)const
{
std::lock_guard<std::mutex> lock(_mutex);
// 准备SQL语句
std::string selectSQL = R"(
SELECT message_id, role, content, timestamp FROM messages WHERE session_id = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, selectSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("getSessionMessages - 准备语句失败:{}", sqlite3_errmsg(_db));
return {};
}
// 绑定参数
sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
std::vector<Message> messages;
while((rc = sqlite3_step(stmt)) == SQLITE_ROW){
Message message;
message._messageId = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
message._role = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
message._content = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
message._timestamp = static_cast<std::time_t>(sqlite3_column_int64(stmt, 3));
messages.push_back(message);
}
if(rc != SQLITE_DONE){
ERR("getSessionMessages - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return {};
}
// 释放语句
sqlite3_finalize(stmt);
return messages;
}
// 删除制定会话的历史消息
bool DataManager::deleteSessionMessages(const std::string& sessionId){
std::lock_guard<std::mutex> lock(_mutex);
// 构建SQL语句
std::string deleteSQL = R"(
DELETE FROM messages WHERE session_id = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, deleteSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
ERR("deleteSessionMessages - 准备语句失败:{}", sqlite3_errmsg(_db));
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, sessionId.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
ERR("deleteSessionMessages - 执行语句失败:{}", sqlite3_errmsg(_db));
sqlite3_finalize(stmt);
return false;
}
// 释放语句
sqlite3_finalize(stmt);
INFO("deleteSessionMessages - 删除会话消息成功:{}", sessionId);
return true;
}
} // end ai_chat_sdk
这个模块的核心作用就在于消息的持久化存储,是整个SDK层的数据底座(大家可以看一下上面的项目结构图)。提供完整的 CRUD 操作,支持会话 / 消息的生命周期管理,保证程序重启后数据不丢失;同时通过外键约束、参数化查询、线程安全控制,保障数据的一致性、安全性与并发访问稳定性。
头文件设计:职责清晰的接口分层
cpp
class DataManager{
public:
// 【Session相关操作】:会话元信息的增删改查
bool insertSession(const Session& session);
std::shared_ptr<Session> getSession(const std::string& sessionId)const;
bool updateSessionTimestamp(const std::string& sessionId, std::time_t timestamp);
bool deleteSession(const std::string& sessionId);
std::vector<std::string> getAllSessionIds()const;
std::vector<std::shared_ptr<Session>> getAllSessions()const;
bool clearAllSessions();
size_t getSessionCount()const;
// 【Message相关操作】:会话消息的增删改查
bool insertMessage(const std::string& sessionId, const Message& message);
std::vector<Message> getSessionMessages(const std::string& sessionId)const;
bool deleteSessionMessages(const std::string& sessionId);
private:
// 【工具函数】:初始化表、执行SQL
bool initDataBase();
bool executeSQL(const std::string& sql);
private:
// 【核心成员】
std::string _dbName; // 数据库文件名
sqlite3* _db = nullptr; // SQLite数据库连接指针
mutable std::mutex _mutex; // 互斥锁(mutable保证const函数也能加锁)
};
核心实现拆解
构造与析构:RAII 管理数据库连接
cpp
// 构造函数:打开数据库 + 初始化表
DataManager::DataManager(const std::string& dbName)
: _dbName(dbName), _db(nullptr)
{
// 打开数据库
int rc = sqlite3_open(dbName.c_str(), &_db);
if(rc != SQLITE_OK){
ERR("打开数据库失败:{}", sqlite3_errmsg(_db));
}
// 初始化表(如果不存在则创建)
if(!initDataBase()){
sqlite3_close(_db);
_db = nullptr;
ERR("初始化数据库表失败");
}
}
// 析构函数:自动关闭数据库连接
DataManager::~DataManager(){
if(_db){
sqlite3_close(_db);
}
}
**设计亮点:**用 RAII(资源获取即初始化)管理数据库连接,构造时打开,析构时自动关闭,完全避免资源泄漏。
数据库表设计:外键约束保证数据一致性
cpp
// 会话表:存储会话元信息
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY, // 会话唯一标识
model_name TEXT NOT NULL, // 绑定的模型名
create_time INTEGER NOT NULL, // 创建时间戳
update_time INTEGER NOT NULL // 更新时间戳
);
// 消息表:存储会话消息
CREATE TABLE IF NOT EXISTS messages (
message_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp INTEGER NOT NULL,
// 【核心亮点】外键约束 + 级联删除
FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE
);
ON DELETE CASCADE:删除会话时,自动删除该会话下的所有消息,无需手动清理,保证数据一致性;
时间戳字段:支持会话按活跃时间排序,提升用户体验。
核心函数:参数化查询防 SQL 注入
以insertSession为例,所有 SQL 操作都使用参数化查询,而非字符串拼接:
cpp
bool DataManager::insertSession(const Session& session){
std::lock_guard<std::mutex> lock(_mutex);
// 【参数化查询】用?占位符,而非字符串拼接
std::string insertSQL = R"(
INSERT INTO sessions (session_id, model_name, create_time, update_time)
VALUES (?, ?, ?, ?);
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, insertSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK) return false;
// 【绑定参数】安全替换占位符
sqlite3_bind_text(stmt, 1, session._sessionId.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 2, session._modelName.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_int64(stmt, 3, static_cast<int64_t>(session._createdAt));
sqlite3_bind_int64(stmt, 4, static_cast<int64_t>(session._updatedAt));
// 执行SQL + 释放语句
rc = sqlite3_step(stmt);
sqlite3_finalize(stmt);
return rc == SQLITE_DONE;
}
设计思路:
- 参数化查询:完全防止了SQL注入,保证数据安全。
- SQLITE_TRANSIENT:SQLite 自动管理绑定的字符串内存,无需手动释放。
线程安全:mutable mutex 保证 const 正确性
所有数据库操作都用std::lock_guard加锁,且_mutex声明为mutable:
cpp
// const成员函数也能加锁
std::vector<std::string> DataManager::getAllSessionIds()const{
std::lock_guard<std::mutex> lock(_mutex); // mutable mutex允许在const函数中修改
// ... 执行SQL ...
}
设计思路:
- mutable:可以让const成员函数修改成员变量_mutex(进行加锁),保证线程安全。
- std::lock_guard:RAII自动管理锁,自动释放,避免死锁。
2.2.4 LLMManager.h/.cpp
代码总览:
LLMManager.h
cpp
#pragma once
#include <map>
#include <memory>
#include <mutex>
#include "LLMProvider.h"
namespace ai_chat_sdk
{
class LLMManager
{
public:
// 注册LLM提供者
bool registerProvider(const std::string &modelName, std::unique_ptr<LLMProvider> provider);
// 初始化指定模型
bool initModel(const std::string &modelName, const std::map<std::string, std::string> &modelParam);
// 获取可用模型
std::vector<ModelInfo> getAvailableModels() const;
// 检查模型是否可用
bool isModelAvailable(const std::string &modelName) const;
// 发送消息给指定模型
std::string sendMessage(const std::string &modelName, const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam);
// 发送消息流给指定模型
std::string sendMessageStream(const std::string &modelName, const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam, std::function<void(const std::string &, bool)> &callback);
private:
mutable std::mutex _mutex; // 新增互斥锁
// key: 模型名称 value: 模型提供器
std::map<std::string, std::unique_ptr<LLMProvider>> _providers;
// key: 模型名称 value: 模型信息
std::map<std::string, ModelInfo> _modelInfos;
};
}
LLMManager.cpp
cpp
#include "../include/LLMManager.h"
#include "../include/util/myLog.h"
#include "../include/common.h"
namespace ai_chat_sdk
{
bool LLMManager::registerProvider(const std::string &modelName, std::unique_ptr<LLMProvider> provider)
{
if (!provider) {
ERR("cannot register nullptr provider, modelName = {}", modelName);
return false;
}
std::lock_guard<std::mutex> lock(_mutex);
_providers[modelName] = std::move(provider);
_modelInfos[modelName] = ModelInfo(modelName);
INFO("register provider success, modelName = {}", modelName);
return true;
}
bool LLMManager::initModel(const std::string &modelName, const std::map<std::string, std::string> &modelParam)
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _providers.find(modelName);
if (it == _providers.end()) {
ERR("model provider not found, modelName = {}", modelName);
return false;
}
bool isSuccess = it->second->initModel(modelParam);
if (!isSuccess) {
ERR("init model failed, modelName = {}", modelName);
} else {
INFO("init model success, modelName = {}", modelName);
_modelInfos[modelName]._modelDesc = it->second->getModelDesc();
_modelInfos[modelName]._isAvailable = true;
}
return isSuccess;
}
std::vector<ModelInfo> LLMManager::getAvailableModels() const
{
std::lock_guard<std::mutex> lock(_mutex);
std::vector<ModelInfo> models;
for (const auto &pair : _modelInfos) {
if (pair.second._isAvailable) {
models.push_back(pair.second);
}
}
return models;
}
bool LLMManager::isModelAvailable(const std::string &modelName) const
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _modelInfos.find(modelName);
if (it == _modelInfos.end()) return false;
return it->second._isAvailable;
}
std::string LLMManager::sendMessage(const std::string &modelName, const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam)
{
LLMProvider* provider = nullptr;
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _providers.find(modelName);
if (it == _providers.end()) {
ERR("model provider not found, modelName = {}", modelName);
return "";
}
if (!it->second->isAvailable()) {
ERR("model not available, modelName = {}", modelName);
return "";
}
provider = it->second.get();
}
return provider->sendMessage(messages, requestParam);
}
std::string LLMManager::sendMessageStream(const std::string &modelName, const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam, std::function<void(const std::string &, bool)> &callback)
{
LLMProvider* provider = nullptr;
{
std::lock_guard<std::mutex> lock(_mutex);
auto it = _providers.find(modelName);
if (it == _providers.end()) {
ERR("model provider not found, modelName = {}", modelName);
return "";
}
if (!it->second->isAvailable()) {
ERR("model not available, modelName = {}", modelName);
return "";
}
provider = it->second.get();
}
return provider->sendMessageStream(messages, requestParam, callback);
}
}
核心作用:
作为整个 SDK 的策略模式容器与请求分发核心,负责管理所有LLMProvider子类实例,根据模型名动态路由请求到对应的 Provider 实现;同时通过双 map 设计、极致锁粒度控制,兼顾扩展性、性能与线程安全。
头文件设计:职责分离的双 map 存储
cpp
class LLMManager
{
public:
// 【Provider管理】注册、初始化模型
bool registerProvider(const std::string &modelName, std::unique_ptr<LLMProvider> provider);
bool initModel(const std::string &modelName, const std::map<std::string, std::string> &modelParam);
// 【查询接口】获取可用模型、检查状态
std::vector<ModelInfo> getAvailableModels() const;
bool isModelAvailable(const std::string &modelName) const;
// 【请求分发】发送消息(全量/流式)
std::string sendMessage(const std::string &modelName, const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam);
std::string sendMessageStream(const std::string &modelName, const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam, std::function<void(const std::string &, bool)> &callback);
private:
// 【线程安全】mutable保证const函数也能加锁
mutable std::mutex _mutex;
// 【双map职责分离】
std::map<std::string, std::unique_ptr<LLMProvider>> _providers; // 存储"干活的人"
std::map<std::string, ModelInfo> _modelInfos; // 存储"人的信息"
};
设计思路:
- 双map职责分离:_providers存Provider实例(负责API调用),_modelInfos存模型元信息(方便快速查询)。
核心实现拆解
registerProvider:智能指针 + 移动语义的典范
cpp
bool LLMManager::registerProvider(const std::string &modelName, std::unique_ptr<LLMProvider> provider)
{
// 前置校验:防止注册空指针
if (!provider) {
ERR("cannot register nullptr provider, modelName = {}", modelName);
return false;
}
// 加锁:保护双map的修改
std::lock_guard<std::mutex> lock(_mutex);
// 【核心】std::move转移unique_ptr所有权
// 把Provider的所有权从调用方转移到LLMManager的_providers map里
_providers[modelName] = std::move(provider);
// 同步初始化ModelInfo
_modelInfos[modelName] = ModelInfo(modelName);
return true;
}
设计思路:
- unique_ptr禁止拷贝,使用std::move转移所有权,不涉及拷贝,性能极高。
- 双 map 同步更新:注册 Provider 时同步初始化 ModelInfo,保证数据一致性。
initModel:多态调用的典范
cpp
bool LLMManager::initModel(const std::string &modelName, const std::map<std::string, std::string> &modelParam)
{
std::lock_guard<std::mutex> lock(_mutex);
// 步骤1:根据模型名找到对应的Provider
auto it = _providers.find(modelName);
if (it == _providers.end()) return false;
// 【核心】多态调用Provider的initModel
// it->second是unique_ptr<LLMProvider>(基类指针)
// 实际调用的是子类实现(DouBaoProvider::initModel或DeepSeekProvider::initModel)
bool isSuccess = it->second->initModel(modelParam);
// 步骤3:根据结果更新ModelInfo
if (isSuccess) {
_modelInfos[modelName]._modelDesc = it->second->getModelDesc();
_modelInfos[modelName]._isAvailable = true;
}
return isSuccess;
}
设计思路:
- 多态调用:通过基类指针调用虚函数,运行时动态选择子类实现,完美落地策略模式;
sendMessage/sendMessageStream:极致锁粒度控制的典范
两个函数的逻辑是一致的,这里我们以sendMessage为例;
cpp
std::string LLMManager::sendMessage(const std::string &modelName,
const std::vector<Message> &messages,
const std::map<std::string, std::string> &requestParam)
{
LLMProvider* provider = nullptr;
// 【锁的范围:极小!只保护查找Provider和检查可用性】
{
std::lock_guard<std::mutex> lock(_mutex);
// 步骤1:找到Provider
auto it = _providers.find(modelName);
if (it == _providers.end()) return "";
// 步骤2:检查可用性
if (!it->second->isAvailable()) return "";
// 步骤3:拿到Provider的裸指针(所有权不转移)
provider = it->second.get();
} // 锁在这里释放了!
// 【关键:实际的API请求调用,完全不加锁!】
// 大模型API是慢IO(可能几秒),绝对不能在锁里做
// Provider本身无共享状态(或已自己加锁),可以安全调用
return provider->sendMessage(messages, requestParam);
}
设计思路:
- 极致锁粒度:只保护查找Provider、检查可用性的内存操作,API请求属于慢IO,不加锁,这样就提高的并发性能。
2.2.5 Provider 实现
代码总览:
DeepSeekProvider.cpp
cpp
#include "../include/DeepSeekProvider.h"
#include "../include/util/myLog.h"
#include <string>
#include <map>
#include <vector>
#include <functional>
#include <sstream>
#include <algorithm>
#include <curl/curl.h>
#include <jsoncpp/json/json.h>
#include <jsoncpp/json/reader.h>
namespace ai_chat_sdk {
// 符合 JSON 规范的字符串转义
static std::string escapeJsonString(const std::string &str) {
std::string result;
for (char c : str) {
switch (c) {
case '"': result += "\\\""; break;
case '\\': result += "\\\\"; break;
case '/': result += "\\/"; break;
case '\b': result += "\\b"; break;
case '\f': result += "\\f"; break;
case '\n': result += "\\n"; break;
case '\r': result += "\\r"; break;
case '\t': result += "\\t"; break;
default:
if (c >= 0 && c < 32) {
char hex[7];
snprintf(hex, sizeof(hex), "\\u%04x", (unsigned char)c);
result += hex;
} else {
result += c;
}
break;
}
}
return result;
}
bool DeepSeekProvider::initModel(const std::map<std::string, std::string> &modelParam) {
auto it = modelParam.find("api_key");
if (it == modelParam.end() || it->second.empty()) {
ERR("DeepSeekProvider initModel: api_key为空");
return false;
}
_apiKey = it->second;
it = modelParam.find("endpoint");
_endpoint = (it == modelParam.end() || it->second.empty()) ? "https://api.deepseek.com" : it->second;
if (!_endpoint.empty() && _endpoint.back() == '/')
_endpoint.pop_back();
_isAvailable = true;
INFO("DeepSeekProvider initModel成功,endpoint: {}", _endpoint);
return true;
}
bool DeepSeekProvider::isAvailable() const { return _isAvailable; }
std::string DeepSeekProvider::getModelName() const { return "deepseek-chat"; }
std::string DeepSeekProvider::getModelDesc() const { return "DeepSeek大模型(官方API接入)"; }
// 全量发送(请保留你原有的实现)
std::string DeepSeekProvider::sendMessage(const std::vector<Message> &messages, const std::map<std::string, std::string> &requestParam) {
// ... 你的实现 ...
return "";
}
// libcurl 写回调:将解析出的纯文本内容传递给上层回调,并累积纯文本到 full_response
struct WriteContext {
std::function<void(const std::string&, bool)>* callback;
std::string* full_text; // 累积纯文本内容,用于返回
std::string buffer; // 行缓冲
bool has_sent_content;
};
static size_t WriteCallback(void *contents, size_t size, size_t nmemb, void *userp) {
size_t total = size * nmemb;
WriteContext* ctx = static_cast<WriteContext*>(userp);
std::string chunk(static_cast<char*>(contents), total);
// 按行处理
ctx->buffer += chunk;
size_t pos;
while ((pos = ctx->buffer.find('\n')) != std::string::npos) {
std::string line = ctx->buffer.substr(0, pos);
ctx->buffer.erase(0, pos + 1);
// 去除行尾可能有的 \r
if (!line.empty() && line.back() == '\r') line.pop_back();
if (line.substr(0, 6) == "data: ") {
std::string data = line.substr(6);
if (data == "[DONE]") {
// 流结束,但由主函数发送结束标记,这里不重复发送
continue;
}
if (!data.empty()) {
Json::Value root;
Json::CharReaderBuilder reader;
std::string errs;
std::istringstream s(data);
if (Json::parseFromStream(reader, s, &root, &errs)) {
if (root.isMember("error")) {
std::string errorMsg = root["error"].toStyledString();
ERR("DeepSeek API 返回错误: {}", errorMsg);
(*(ctx->callback))("{\"error\":\"API 错误:" + root["error"]["message"].asString() + "\"}", true);
return 0; // 终止传输
}
if (root.isMember("choices") && root["choices"].isArray() && !root["choices"].empty()) {
auto choice = root["choices"][0];
if (choice.isMember("delta")) {
std::string delta_content;
if (choice["delta"].isObject() && choice["delta"].isMember("content")) {
delta_content = choice["delta"]["content"].asString();
} else if (choice["delta"].isString()) {
delta_content = choice["delta"].asString();
} else if (choice["delta"].isInt()) {
delta_content = std::to_string(choice["delta"].asInt());
WARN("DeepSeekProvider sendMessageStream: delta 为整数, 转换为字符串: {}", delta_content);
}
if (!delta_content.empty()) {
ctx->has_sent_content = true;
// 追加到 full_text
*(ctx->full_text) += delta_content;
// 回调给前端
(*(ctx->callback))(delta_content, false);
INFO("回调内容: {}", delta_content);
}
}
}
} else {
WARN("DeepSeekProvider sendMessageStream: JSON解析失败, 错误: {}, 原始数据: {}", errs, data);
if (data.find("error") != std::string::npos || data.find("failed") != std::string::npos) {
(*(ctx->callback))("{\"error\":\"API 响应异常,请稍后重试\"}", true);
return 0;
}
}
}
} else {
// 非 data: 行,可能是错误响应,记录并忽略
WARN("DeepSeekProvider sendMessageStream: 收到非标准行: {}", line);
}
}
return total;
}
// 流式发送(libcurl 版本)
std::string DeepSeekProvider::sendMessageStream(
const std::vector<Message> &messages,
const std::map<std::string, std::string> &requestParam,
std::function<void(const std::string &, bool)> callback)
{
if (!_isAvailable) {
callback("{\"error\":\"模型未初始化\"}", true);
return "";
}
if (messages.empty()) {
callback("{\"error\":\"消息列表为空\"}", true);
return "";
}
if (!callback) {
ERR("callback为空");
return "";
}
double temperature = 0.7;
int maxTokens = 2048;
auto it = requestParam.find("temperature");
if (it != requestParam.end()) temperature = std::stod(it->second);
it = requestParam.find("max_tokens");
if (it != requestParam.end()) maxTokens = std::stoi(it->second);
// 构造 JSON 请求体
Json::Value body;
body["model"] = getModelName();
Json::Value msgs(Json::arrayValue);
for (const auto& msg : messages) {
Json::Value m;
m["role"] = msg._role;
m["content"] = msg._content;
msgs.append(m);
}
body["messages"] = msgs;
body["stream"] = true;
body["temperature"] = temperature;
body["max_tokens"] = maxTokens;
Json::StreamWriterBuilder writer;
std::string body_str = Json::writeString(writer, body);
// 初始化 libcurl
CURL* curl = curl_easy_init();
if (!curl) {
callback("{\"error\":\"libcurl init failed\"}", true);
return "";
}
struct curl_slist* headers = nullptr;
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers, ("Authorization: Bearer " + _apiKey).c_str());
headers = curl_slist_append(headers, "Accept: text/event-stream");
std::string url = _endpoint + "/v1/chat/completions";
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body_str.c_str());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, 300L);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
WriteContext ctx;
ctx.callback = &callback;
std::string full_text; // 存储纯文本
ctx.full_text = &full_text;
ctx.has_sent_content = false;
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ctx);
CURLcode res = curl_easy_perform(curl);
long http_code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
if (res != CURLE_OK) {
ERR("curl_easy_perform failed: {}", curl_easy_strerror(res));
callback("{\"error\":\"网络请求失败\"}", true);
return "";
}
if (http_code != 200) {
ERR("HTTP error: {}", http_code);
callback("{\"error\":\"HTTP " + std::to_string(http_code) + "\"}", true);
return "";
}
// 无论是否收到 [DONE],都发送结束标记
callback("", true);
INFO("DeepSeekProvider sendMessageStream: 流式请求成功,收到内容长度: {}", full_text.length());
return full_text;
}
} // namespace ai_chat_sdk
DouBaoProvider.cpp
cpp
#include "../include/DouBaoProvider.h"
#include "../include/util/myLog.h"
#include <string>
#include <map>
#include <vector>
#include <functional>
#include <sstream>
#include <algorithm>
#include <curl/curl.h>
#include <jsoncpp/json/json.h>
#include <jsoncpp/json/reader.h>
namespace ai_chat_sdk {
static std::string escapeJsonString(const std::string &str) {
std::string result;
for (char c : str) {
switch (c) {
case '"': result += "\\\""; break;
case '\\': result += "\\\\"; break;
case '/': result += "\\/"; break;
case '\b': result += "\\b"; break;
case '\f': result += "\\f"; break;
case '\n': result += "\\n"; break;
case '\r': result += "\\r"; break;
case '\t': result += "\\t"; break;
default:
if (c >= 0 && c < 32) {
char hex[7];
snprintf(hex, sizeof(hex), "\\u%04x", (unsigned char)c);
result += hex;
} else {
result += c;
}
break;
}
}
return result;
}
bool DouBaoProvider::initModel(const std::map<std::string, std::string>& modelParam) {
auto it = modelParam.find("api_key");
if (it == modelParam.end() || it->second.empty()) {
ERR("DouBaoProvider initModel: api_key为空");
return false;
}
_apiKey = it->second;
it = modelParam.find("endpoint");
if (it == modelParam.end() || it->second.empty()) {
_endpoint = "https://ark.cn-beijing.volces.com";
} else {
_endpoint = it->second;
}
if (!_endpoint.empty() && _endpoint.back() == '/')
_endpoint.pop_back();
_isAvailable = true;
INFO("DouBaoProvider initModel成功,endpoint: {}", _endpoint);
return true;
}
bool DouBaoProvider::isAvailable() const { return _isAvailable; }
std::string DouBaoProvider::getModelName() const { return "doubao-seed-2-0-lite-260215"; }
std::string DouBaoProvider::getModelDesc() const { return "字节跳动豆包大模型(火山方舟API接入)"; }
std::string DouBaoProvider::sendMessage(const std::vector<Message>& messages, const std::map<std::string, std::string>& requestParam) {
// ... 你的实现 ...
return "";
}
struct WriteContext {
std::function<void(const std::string&, bool)>* callback;
std::string* full_text;
std::string buffer;
bool has_sent_content;
};
static size_t WriteCallback(void *contents, size_t size, size_t nmemb, void *userp) {
size_t total = size * nmemb;
WriteContext* ctx = static_cast<WriteContext*>(userp);
std::string chunk(static_cast<char*>(contents), total);
ctx->buffer += chunk;
size_t pos;
while ((pos = ctx->buffer.find('\n')) != std::string::npos) {
std::string line = ctx->buffer.substr(0, pos);
ctx->buffer.erase(0, pos + 1);
if (!line.empty() && line.back() == '\r') line.pop_back();
if (line.substr(0, 6) == "data: ") {
std::string data = line.substr(6);
if (data == "[DONE]") {
continue;
}
if (!data.empty()) {
Json::Value root;
Json::CharReaderBuilder reader;
std::string errs;
std::istringstream s(data);
if (Json::parseFromStream(reader, s, &root, &errs)) {
if (root.isMember("error")) {
std::string errorMsg = root["error"].toStyledString();
ERR("DouBao API 返回错误: {}", errorMsg);
(*(ctx->callback))("{\"error\":\"API 错误:" + root["error"]["message"].asString() + "\"}", true);
return 0;
}
if (root.isMember("choices") && root["choices"].isArray() && !root["choices"].empty()) {
auto choice = root["choices"][0];
if (choice.isMember("delta") && choice["delta"].isMember("content")) {
std::string delta_content = choice["delta"]["content"].asString();
if (!delta_content.empty()) {
ctx->has_sent_content = true;
*(ctx->full_text) += delta_content;
(*(ctx->callback))(delta_content, false);
}
}
}
} else {
WARN("DouBaoProvider sendMessageStream: JSON解析失败, 错误: {}, 原始数据: {}", errs, data);
if (data.find("error") != std::string::npos || data.find("failed") != std::string::npos) {
(*(ctx->callback))("{\"error\":\"API 响应异常,请稍后重试\"}", true);
return 0;
}
}
}
} else {
WARN("DouBaoProvider sendMessageStream: 收到非标准行: {}", line);
}
}
return total;
}
std::string DouBaoProvider::sendMessageStream(
const std::vector<Message>& messages,
const std::map<std::string, std::string>& requestParam,
std::function<void(const std::string&, bool)> callback)
{
if (!_isAvailable) {
callback("{\"error\":\"模型未初始化\"}", true);
return "";
}
if (messages.empty()) {
callback("{\"error\":\"消息列表为空\"}", true);
return "";
}
if (!callback) {
ERR("callback为空");
return "";
}
double temperature = 0.7;
int maxTokens = 2048;
auto it = requestParam.find("temperature");
if (it != requestParam.end()) temperature = std::stod(it->second);
it = requestParam.find("max_tokens");
if (it != requestParam.end()) maxTokens = std::stoi(it->second);
Json::Value body;
// 注意:豆包模型名称可能需要从配置中获取,这里使用硬编码
body["model"] = "ep-m-20260218004248-cf9qd";
Json::Value msgs(Json::arrayValue);
for (const auto& msg : messages) {
Json::Value m;
m["role"] = msg._role;
m["content"] = msg._content;
msgs.append(m);
}
body["messages"] = msgs;
body["stream"] = true;
body["temperature"] = temperature;
body["max_tokens"] = maxTokens;
Json::StreamWriterBuilder writer;
std::string body_str = Json::writeString(writer, body);
CURL* curl = curl_easy_init();
if (!curl) {
callback("{\"error\":\"libcurl init failed\"}", true);
return "";
}
struct curl_slist* headers = nullptr;
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers, ("Authorization: Bearer " + _apiKey).c_str());
headers = curl_slist_append(headers, "Accept: text/event-stream");
std::string url = _endpoint + "/api/v3/chat/completions";
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body_str.c_str());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, 300L);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
WriteContext ctx;
ctx.callback = &callback;
std::string full_text;
ctx.full_text = &full_text;
ctx.has_sent_content = false;
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ctx);
CURLcode res = curl_easy_perform(curl);
long http_code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
if (res != CURLE_OK) {
ERR("curl_easy_perform failed: {}", curl_easy_strerror(res));
callback("{\"error\":\"网络请求失败\"}", true);
return "";
}
if (http_code != 200) {
ERR("HTTP error: {}", http_code);
callback("{\"error\":\"HTTP " + std::to_string(http_code) + "\"}", true);
return "";
}
callback("", true);
INFO("DouBaoProvider sendMessageStream: 流式请求成功,收到内容长度: {}", full_text.length());
return full_text;
}
} // namespace ai_chat_sdk
这里两个Provider的结构基本类似,只有3个核心差异:
1. 默认 endpoint 不同:
2. API路径不同:
- DeepSeek:/v1/chat/completions
- 豆包:/api/v3/chat/completions
3. 模型名称不同:
- DeepSeek:deepseek-chat
- 豆包:ep-m-20260218004248-cf9qd(火山方舟的endpoint ID)
下面我们以DeepSeek为例来拆解一下核心实现:
initModel:配置解析与初始化
cpp
bool DeepSeekProvider::initModel(const std::map<std::string, std::string> &modelParam) {
// 步骤1:提取api_key(必填)
auto it = modelParam.find("api_key");
if (it == modelParam.end() || it->second.empty()) return false;
_apiKey = it->second;
// 步骤2:提取endpoint(选填,有默认值)
it = modelParam.find("endpoint");
_endpoint = (it == modelParam.end() || it->second.empty()) ? "https://api.deepseek.com" : it->second;
if (!_endpoint.empty() && _endpoint.back() == '/') _endpoint.pop_back(); // 去除末尾斜杠
// 步骤3:标记为可用
_isAvailable = true;
return true;
}
设计思路:配置驱动,支持自定义 endpoint,方便测试与私有化部署。
sendMessageStream:核心函数
这个函数是整个 SDK 最复杂的部分,分为 3 个阶段:
- **构造请求:**JSON 序列化请求体,设置 libcurl 参数;
- **执行请求:**通过 libcurl 发送 POST 请求,注册写回调函数;
- **解析流式响应:**在写回调中按行解析 SSE 协议,提取增量内容并回调给上层。
阶段 1:构造请求(JSON 序列化 + libcurl 初始化)
cpp
std::string DeepSeekProvider::sendMessageStream(...) {
// ... 前置校验(略)...
// 步骤1:JSON序列化请求体
Json::Value body;
body["model"] = getModelName();
body["stream"] = true; // 开启流式
// ... 填充messages、temperature、max_tokens(略)...
Json::StreamWriterBuilder writer;
std::string body_str = Json::writeString(writer, body);
// 步骤2:初始化libcurl
CURL* curl = curl_easy_init();
struct curl_slist* headers = nullptr;
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers, ("Authorization: Bearer " + _apiKey).c_str());
headers = curl_slist_append(headers, "Accept: text/event-stream"); // SSE协议头
// 步骤3:设置libcurl参数
std::string url = _endpoint + "/v1/chat/completions";
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body_str.c_str());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); // 注册写回调
// ... 执行请求(略)...
}
阶段 2&3:写回调函数(SSE 协议解析 + JSON 反序列化 + 增量回调)
这是最核心的部分,WriteCallback会被 libcurl 多次调用,每次传入一部分响应数据:
cpp
struct WriteContext {
std::function<void(const std::string&, bool)>* callback; // 上层回调
std::string* full_text; // 累积完整响应
std::string buffer; // 行缓冲(处理数据分片)
bool has_sent_content;
};
static size_t WriteCallback(void *contents, size_t size, size_t nmemb, void *userp) {
size_t total = size * nmemb;
WriteContext* ctx = static_cast<WriteContext*>(userp);
std::string chunk(static_cast<char*>(contents), total);
// 【关键1】行缓冲:处理数据分片(一次回调可能只收到半行)
ctx->buffer += chunk;
size_t pos;
while ((pos = ctx->buffer.find('\n')) != std::string::npos) {
std::string line = ctx->buffer.substr(0, pos);
ctx->buffer.erase(0, pos + 1);
if (!line.empty() && line.back() == '\r') line.pop_back(); // 去除\r
// 【关键2】SSE协议解析:只处理"data: "开头的行
if (line.substr(0, 6) == "data: ") {
std::string data = line.substr(6);
if (data == "[DONE]") continue; // 流结束标记
// 【关键3】JSON反序列化:提取delta增量内容
Json::Value root;
Json::CharReaderBuilder reader;
std::istringstream s(data);
if (Json::parseFromStream(reader, s, &root, &errs)) {
if (root.isMember("choices") && !root["choices"].empty()) {
auto choice = root["choices"][0];
if (choice.isMember("delta") && choice["delta"].isMember("content")) {
std::string delta_content = choice["delta"]["content"].asString();
if (!delta_content.empty()) {
// 【关键4】增量回调:把内容传给上层
*(ctx->full_text) += delta_content;
(*(ctx->callback))(delta_content, false);
}
}
}
}
}
}
return total;
}
设计思路:
- **行缓冲处理分片:**网络数据是分片到达的,一次回调只能接受到半行,用buffer累积直到收到完整的\n;
- **SSE协议解析:**严格遵循 SSE 规范,只处理
data:开头的行,忽略其他行; - **增量回调:**解析出delta.content后立即通过callback传给上层,实现"打字机"的效果,用户体验更好。
- **完整响应累积:**同时把所有增量内容累积到full_text,最后返回给调用方,方便需要完整响应的场景。
3. 测试层
代码总览
testSQLite3.cpp
cpp
#include <sqlite3.h>
#include <string>
#include <iostream>
struct StudentInfo{
std::string name;
std::string gender;
int age;
double gap;
StudentInfo(const std::string& name, const std::string& gender, int age, double gap)
:name(name), gender(gender), age(age), gap(gap)
{}
};
class StudentDB{
public:
StudentDB(const std::string& dbName){
// 创建并打开数据库
int rc = sqlite3_open(dbName.c_str(), &_db);
if(rc != SQLITE_OK){
std::cerr<<"打开数据库失败:"<<sqlite3_errmsg(_db)<<std::endl;
sqlite3_close(_db);
}
// 初始化数据库表 - 创建学生信息表
if(!initDataBase()){
sqlite3_close(_db);
}
}
~StudentDB(){
if(_db != nullptr){
sqlite3_close(_db);
}
}
bool insertStudentInfo(const StudentInfo& studentInfo){
// 插入学生信息
std::string insertSQL = R"(
INSERT INTO Student (name, gender, age, gap)
VALUES (?, ?, ?, ?);
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, insertSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
std::cerr<<"准备语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, studentInfo.name.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 2, studentInfo.gender.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_int(stmt, 3, studentInfo.age);
sqlite3_bind_double(stmt, 4, studentInfo.gap);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE){
std::cerr<<"执行语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
sqlite3_finalize(stmt);
return false;
}
// 清理
sqlite3_finalize(stmt);
return true;
}
bool queryStudentInfo(const std::string& name){
// 查询学生信息
std::string querySQL = R"(
SELECT stuid, name, gender, age, gap FROM Student WHERE name = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, querySQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
std::cerr<<"准备语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, name.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_ROW){
std::cerr<<"执行语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
sqlite3_finalize(stmt);
return false;
}
// 提取结果
int stuid = sqlite3_column_int(stmt, 0);
std::string queryName = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
std::string queryGender = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
int queryAge = sqlite3_column_int(stmt, 3);
double queryGap = sqlite3_column_double(stmt, 4);
// 打印结果
std::cout<<"查询到学生信息:"<<std::endl;
std::cout<<"stuid: "<<stuid<<std::endl;
std::cout<<"name: "<<queryName<<std::endl;
std::cout<<"gender: "<<queryGender<<std::endl;
std::cout<<"age: "<<queryAge<<std::endl;
std::cout<<"gap: "<<queryGap<<std::endl;
// 清理
sqlite3_finalize(stmt);
return true;
}
bool queryAllStudentInfo(){
// 查询所有学生信息
std::string querySQL = R"(
SELECT * FROM Student;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, querySQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
std::cerr<<"准备语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
return false;
}
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_ROW && rc != SQLITE_DONE){
std::cerr<<"执行语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
sqlite3_finalize(stmt);
return false;
}
// 提取结果
std::cout<<"-------------所有学生信息-----------------"<<std::endl;
while(rc == SQLITE_ROW){
int stuid = sqlite3_column_int(stmt, 0);
std::string queryName = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
std::string queryGender = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
int queryAge = sqlite3_column_int(stmt, 3);
double queryGap = sqlite3_column_double(stmt, 4);
// 打印结果
std::cout<<"查询到学生信息:"<<std::endl;
std::cout<<"stuid: "<<stuid<<" "
<<"name: "<<queryName<<" "
<<"gender: "<<queryGender<<" "
<<"age: "<<queryAge<<" "
<<"gap: "<<queryGap<<std::endl;
// 继续提取下一行
rc = sqlite3_step(stmt);
}
// 检查是否还有更多行
if(rc != SQLITE_DONE){
std::cerr<<"提取结果失败:"<<sqlite3_errmsg(_db)<<std::endl;
sqlite3_finalize(stmt);
return false;
}
// 清理
sqlite3_finalize(stmt);
return true;
}
// 修改学生信息
bool updateStudentInfo(const std::string& name, const StudentInfo& info){
// 更新学生信息
std::string updateSQL = R"(
UPDATE Student SET gender = ?, age = ?, gap = ? WHERE name = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, updateSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
std::cerr<<"准备语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, info.gender.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_int(stmt, 2, info.age);
sqlite3_bind_double(stmt, 3, info.gap);
sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE && rc != SQLITE_ROW){
std::cerr<<"执行语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
sqlite3_finalize(stmt);
return false;
}
// 清理
sqlite3_finalize(stmt);
return true;
}
// 删除学生信息
bool deleteStudentInfo(const std::string& name){
// 删除学生信息
std::string deleteSQL = R"(
DELETE FROM Student WHERE name = ?;
)";
// 准备SQL语句
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(_db, deleteSQL.c_str(), -1, &stmt, nullptr);
if(rc != SQLITE_OK){
std::cerr<<"准备语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
return false;
}
// 绑定参数
sqlite3_bind_text(stmt, 1, name.c_str(), -1, SQLITE_TRANSIENT);
// 执行SQL语句
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE && rc != SQLITE_ROW){
std::cerr<<"执行语句失败:"<<sqlite3_errmsg(_db)<<std::endl;
sqlite3_finalize(stmt);
return false;
}
// 清理
sqlite3_finalize(stmt);
return true;
}
private:
bool initDataBase(){
const std::string createTableSQL = R"(
CREATE TABLE IF NOT EXISTS Student (
stuid INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
gender TEXT,
age INTEGER,
gap REAL
);
)";
int rc = sqlite3_exec(_db, createTableSQL.c_str(), nullptr, nullptr, nullptr);
if(rc != SQLITE_OK){
std::cerr<<"创建表失败:"<<sqlite3_errmsg(_db)<<std::endl;
return false;
}
return true;
}
private:
sqlite3* _db = nullptr;
};
int main()
{
StudentInfo info1 = {"张三", "男", 18, 3.5};
StudentInfo info2 = {"李四", "女", 19, 3.8};
StudentInfo info3 = {"王五", "男", 20, 4.0};
StudentInfo info4 = {"赵六", "女", 21, 4.2};
StudentDB db("studentDB.db");
db.insertStudentInfo(info1);
db.insertStudentInfo(info2);
db.insertStudentInfo(info3);
db.insertStudentInfo(info4);
// 查询所有学生信息
db.queryAllStudentInfo();
info3.gap = 4.5;
db.updateStudentInfo(info3.name, info3);
db.queryStudentInfo(info3.name);
// 删除学生信息
db.deleteStudentInfo(info4.name);
db.queryAllStudentInfo();
return 0;
}
testLLM.cpp
cpp
#include <gtest/gtest.h>
#include <istream>
#include <memory>
#include <spdlog/common.h>
#include "../sdk/include/DeepSeekProvider.h"
#include "../sdk/include/DouBaoProvider.h"
#include "../sdk/include/util/myLog.h"
#include "../sdk/include/ChatSDK.h"
#include <iostream>
#include <string>
#include <vector>
// 全局变量:用于流式测试的完整回复收集
std::string g_full_stream_response;
// 流式输出回调函数(通用,适配DeepSeek和豆包)
void streamCallback(const std::string& content, bool is_finished) {
if (is_finished) {
std::cout << "\n>>> 流式输出结束" << std::endl;
} else if (!content.empty()) {
std::cout << content; // 实时打印流式内容
std::cout.flush(); // 刷新输出缓冲区,确保实时显示
g_full_stream_response += content; // 收集完整回复用于验证
}
}
// 测试ChatSDK(非流式,保留原有逻辑)
TEST(ChatSDKTest, sendMessage_NonStream) {
auto sdk = std::make_shared<ai_chat_sdk::ChatSDK>();
ASSERT_TRUE(sdk != nullptr);
// 1. 配置DeepSeek模型
auto deepseekConfig = std::make_shared<ai_chat_sdk::APIConfig>();
ASSERT_TRUE(deepseekConfig != nullptr);
deepseekConfig->_modelName = "deepseek-chat";
deepseekConfig->_apiKey = std::getenv("deepseek_apikey");
ASSERT_FALSE(deepseekConfig->_apiKey.empty()) << "请设置环境变量 deepseek_apikey";
deepseekConfig->_temperature = 0.7;
deepseekConfig->_maxTokens = 2048;
// 2. 配置豆包DouBao模型
auto doubaoConfig = std::make_shared<ai_chat_sdk::APIConfig>();
ASSERT_TRUE(doubaoConfig != nullptr);
doubaoConfig->_modelName = "doubao-seed-2-0-lite-260215";
doubaoConfig->_apiKey = std::getenv("doubao_apikey");
ASSERT_FALSE(doubaoConfig->_apiKey.empty()) << "请设置环境变量 doubao_apikey";
doubaoConfig->_temperature = 0.7;
doubaoConfig->_maxTokens = 2048;
// 3. 初始化SDK
std::vector<std::shared_ptr<ai_chat_sdk::Config>> modelConfigs = {
deepseekConfig, doubaoConfig
};
sdk->initModels(modelConfigs);
// 4. 测试DeepSeek非流式
std::cout << "=== 测试DeepSeek模型(非流式) ===" << std::endl;
auto deepseekSessionId = sdk->createSession(deepseekConfig->_modelName);
ASSERT_FALSE(deepseekSessionId.empty()) << "DeepSeek创建会话失败";
std::string message = "你好,请用一句话介绍自己";
auto response = sdk->sendMessage(deepseekSessionId, message);
ASSERT_FALSE(response.empty()) << "DeepSeek返回空响应";
std::cout << ">>> DeepSeek回复:" << response << std::endl;
// 5. 测试豆包非流式
std::cout << "\n=== 测试豆包模型(非流式) ===" << std::endl;
auto doubaoSessionId = sdk->createSession("doubao-seed-2-0-lite-260215");
ASSERT_FALSE(doubaoSessionId.empty()) << "豆包创建会话失败";
message = "你好,请用一句话介绍自己";
response = sdk->sendMessage(doubaoSessionId, message);
ASSERT_FALSE(response.empty()) << "豆包返回空响应";
std::cout << ">>> 豆包回复:" << response << std::endl;
// 6. 验证会话历史
auto deepseekSession = sdk->getSession(deepseekSessionId);
ASSERT_TRUE(deepseekSession != nullptr) << "DeepSeek会话不存在";
auto deepseekMessages = deepseekSession->_messages;
ASSERT_FALSE(deepseekMessages.empty()) << "DeepSeek会话历史为空";
auto doubaoSession = sdk->getSession(doubaoSessionId);
ASSERT_TRUE(doubaoSession != nullptr) << "豆包会话不存在";
auto doubaoMessages = doubaoSession->_messages;
ASSERT_FALSE(doubaoMessages.empty()) << "豆包会话历史为空";
}
// 测试ChatSDK(流式输出,新增)
TEST(ChatSDKTest, sendMessage_Stream) {
auto sdk = std::make_shared<ai_chat_sdk::ChatSDK>();
ASSERT_TRUE(sdk != nullptr);
// 1. 配置模型(复用非流式的配置)
auto deepseekConfig = std::make_shared<ai_chat_sdk::APIConfig>();
deepseekConfig->_modelName = "deepseek-chat";
deepseekConfig->_apiKey = std::getenv("deepseek_apikey");
ASSERT_FALSE(deepseekConfig->_apiKey.empty()) << "请设置环境变量 deepseek_apikey";
deepseekConfig->_temperature = 0.7;
deepseekConfig->_maxTokens = 2048;
auto doubaoConfig = std::make_shared<ai_chat_sdk::APIConfig>();
doubaoConfig->_modelName = "doubao-seed-2-0-lite-260215";
doubaoConfig->_apiKey = std::getenv("doubao_apikey");
ASSERT_FALSE(doubaoConfig->_apiKey.empty()) << "请设置环境变量 doubao_apikey";
doubaoConfig->_temperature = 0.7;
doubaoConfig->_maxTokens = 2048;
// 2. 初始化SDK
std::vector<std::shared_ptr<ai_chat_sdk::Config>> modelConfigs = {
deepseekConfig, doubaoConfig
};
sdk->initModels(modelConfigs);
// 3. 测试DeepSeek流式输出
std::cout << "\n=== 测试DeepSeek模型(流式输出) ===" << std::endl;
auto deepseekSessionId = sdk->createSession(deepseekConfig->_modelName);
ASSERT_FALSE(deepseekSessionId.empty()) << "DeepSeek创建会话失败";
g_full_stream_response.clear(); // 清空全局变量
std::string message = "请用3句话介绍一下人工智能";
std::cout << ">>> 模型正在回复(流式):" << std::endl;
auto stream_response = sdk->sendMessageStream(deepseekSessionId, message, streamCallback);
ASSERT_FALSE(g_full_stream_response.empty()) << "DeepSeek流式返回空响应";
std::cout << ">>> 完整流式回复:" << g_full_stream_response << std::endl;
// 4. 测试豆包流式输出
std::cout << "\n=== 测试豆包模型(流式输出) ===" << std::endl;
auto doubaoSessionId = sdk->createSession("doubao-seed-2-0-lite-260215");
ASSERT_FALSE(doubaoSessionId.empty()) << "豆包创建会话失败";
g_full_stream_response.clear(); // 清空全局变量
message = "请用3句话介绍一下字节跳动";
std::cout << ">>> 模型正在回复(流式):" << std::endl;
stream_response = sdk->sendMessageStream(doubaoSessionId, message, streamCallback);
ASSERT_FALSE(g_full_stream_response.empty()) << "豆包流式返回空响应";
std::cout << ">>> 完整流式回复:" << g_full_stream_response << std::endl;
// 5. 验证会话历史
auto deepseekSession = sdk->getSession(deepseekSessionId);
ASSERT_TRUE(deepseekSession != nullptr) << "DeepSeek会话不存在";
auto deepseekMessages = deepseekSession->_messages;
ASSERT_FALSE(deepseekMessages.empty()) << "DeepSeek会话历史为空";
auto doubaoSession = sdk->getSession(doubaoSessionId);
ASSERT_TRUE(doubaoSession != nullptr) << "豆包会话不存在";
auto doubaoMessages = doubaoSession->_messages;
ASSERT_FALSE(doubaoMessages.empty()) << "豆包会话历史为空";
}
int main(int argc, char **argv) {
// 初始化spdlog日志库
bite::Logger::initLogger("testLLM", "stdout", spdlog::level::debug);
// 初始化gtest库
testing::InitGoogleTest(&argc, argv);
// 执行所有的测试用例
return RUN_ALL_TESTS();
}
核心作用
本层主要涉及两个测试文件:
- 前置测试阶段(testSQLite3.cpp):这个文件主要就是测试数据库能否被正常使用,在开发DataManager之前,先通过一个独立的 StudentDB demo 验证 SQLite 的 CRUD 操作是否正确,降低后续开发风险;
- 集成测试阶段(testLLM.cpp):基于GTest框架,对整个SDK进行端到端的集成测试,覆盖非流式对话、流式对话两大核心场景,同时验证双厂商的适配能力,保证SDK的稳定性和可用性。
下面我们主要来关注集成测试阶段,也就是testLLM.cpp模块;
这部分是整个测试层的核心,我们首先来分析几个前置问题:
- GTest/Google Test: 业界主流的C++单元测试框架,提供
TEST宏、断言(ASSERT_TRUE、ASSERT_FALSE等)、测试用例组织等能力; - **环境变量管理API_KEY:**通过std::getenv从环境变量读取deepseek_apikey和doubao_apikey,避免把 API Key 硬编码在代码里,安全性极高;
- **spdlog 日志:**复用 SDK 的日志库,方便调试测试用例。
下面我们来分析这部分的两个核心测试用例:
测试用例 1:sendMessage_NonStream(非流式对话测试)
cpp
TEST(ChatSDKTest, sendMessage_NonStream) {
// 步骤1:创建ChatSDK实例
auto sdk = std::make_shared<ai_chat_sdk::ChatSDK>();
ASSERT_TRUE(sdk != nullptr);
// 步骤2:配置双厂商模型(DeepSeek + 豆包)
auto deepseekConfig = std::make_shared<ai_chat_sdk::APIConfig>();
deepseekConfig->_modelName = "deepseek-chat";
deepseekConfig->_apiKey = std::getenv("deepseek_apikey");
ASSERT_FALSE(deepseekConfig->_apiKey.empty()) << "请设置环境变量 deepseek_apikey";
auto doubaoConfig = std::make_shared<ai_chat_sdk::APIConfig>();
doubaoConfig->_modelName = "doubao-seed-2-0-lite-260215";
doubaoConfig->_apiKey = std::getenv("doubao_apikey");
ASSERT_FALSE(doubaoConfig->_apiKey.empty()) << "请设置环境变量 doubao_apikey";
// 步骤3:初始化SDK(注册Provider + 初始化模型)
std::vector<std::shared_ptr<ai_chat_sdk::Config>> modelConfigs = {deepseekConfig, doubaoConfig};
sdk->initModels(modelConfigs);
// 步骤4:测试DeepSeek非流式对话
auto deepseekSessionId = sdk->createSession(deepseekConfig->_modelName);
ASSERT_FALSE(deepseekSessionId.empty()) << "DeepSeek创建会话失败";
std::string response = sdk->sendMessage(deepseekSessionId, "你好,请用一句话介绍自己");
ASSERT_FALSE(response.empty()) << "DeepSeek返回空响应";
// 步骤5:测试豆包非流式对话(同理,略)
// 步骤6:验证会话历史(DataManager的持久化能力)
auto deepseekSession = sdk->getSession(deepseekSessionId);
ASSERT_TRUE(deepseekSession != nullptr) << "DeepSeek会话不存在";
auto deepseekMessages = deepseekSession->_messages;
ASSERT_FALSE(deepseekMessages.empty()) << "DeepSeek会话历史为空";
}
测试覆盖点:
- ChatSDK初始化流程;
- LLMProvider注册与模型初始化;
- SessionManager的会话创建与管理;
- DataManager的会话/消息持久化;
- DeepSeekProvider 与 DouBaoProvider 的非流式 API 调用。
测试用例 2:sendMessage_Stream(流式对话测试)
这是测试的亮点,专门验证 SDK 的流式返回能力:同时流式返回也是我们最关心的部分,直接影响用户的使用体验感。
cpp
// 全局变量:收集完整的流式响应,用于验证
std::string g_full_stream_response;
// 流式回调函数:实时打印 + 收集完整响应
void streamCallback(const std::string& content, bool is_finished) {
if (is_finished) {
std::cout << "\n>>> 流式输出结束" << std::endl;
} else if (!content.empty()) {
std::cout << content;
std::cout.flush(); // 刷新缓冲区,保证实时显示
g_full_stream_response += content;
}
}
TEST(ChatSDKTest, sendMessage_Stream) {
// ... 配置模型、初始化SDK(复用非流式的逻辑,略)...
// 测试DeepSeek流式对话
auto deepseekSessionId = sdk->createSession(deepseekConfig->_modelName);
ASSERT_FALSE(deepseekSessionId.empty());
g_full_stream_response.clear(); // 清空全局变量
std::cout << ">>> 模型正在回复(流式):" << std::endl;
auto stream_response = sdk->sendMessageStream(deepseekSessionId, "请用3句话介绍一下人工智能", streamCallback);
// 验证:完整响应不为空
ASSERT_FALSE(g_full_stream_response.empty()) << "DeepSeek流式返回空响应";
std::cout << ">>> 完整流式回复:" << g_full_stream_response << std::endl;
// ... 测试豆包流式对话(同理,略)...
}
测试覆盖点:
- SDK 的流式回调机制;
- SSE 协议的解析能力;
- 增量内容的实时传递;
- 完整响应的累积与返回。
4. 应用层
核心作用
这部分作为整个项目的应用层入口,其职责如下:
- **协议适配:**将底层 C++ SDK 的能力,封装成标准的 RESTful HTTP API,供前端页面调用;
- **流程串联:**把 SDK 的初始化、会话管理、消息发送、数据持久化等能力,串联成完整的端到端对话流程;
- **服务化落地:**通过httplib启动HTTP服务,挂载前端静态资源,实现"后端服务+前端页面"的完整产品闭环。
- **流式响应适配:**将 SDK 的 SSE 流式回调,适配成 HTTP 标准的 SSE 协议响应,实现前端 "打字机" 效果。
代码总览
ChatServer.h
cpp
#pragma once
#include <httplib.h>
#include <memory>
#include <atomic>
#include <string>
#include "ChatSDK.h"
namespace ai_chat_server{
// 服务器配置信息
struct ServerConfig {
std::string host; // 监听地址
int port; // 监听端口
std::string logLevel; // 日志级别
std::string logFile; // 日志文件路径(如果有)
double temperature; // 生成温度
int maxTokens; // 最大Token数
std::string deepseekAPIKey;// DeepSeek API密钥
std::string doubaoAPIKey; // 豆包 API密钥
// 【核心新增】两个模型名称成员(和main.cpp对应)
std::string deepseekModelName;
std::string doubaoModelName;
};
class ChatServer{
public:
ChatServer(const ServerConfig& config);
bool start(); // 启动服务器
void stop(); // 停止服务器
bool isRunning()const; // 是否正在运行
private:
// 构造统一响应
std::string buildResponse(const std::string& message, bool success = false);
// 接口处理函数
void handleCreateSessionRequest(const httplib::Request& request, httplib::Response& response);
void handleGetSessionListsRequest(const httplib::Request& request, httplib::Response& response);
void handleGetModelListsRequest(const httplib::Request& request, httplib::Response& response);
void handleDeleteSessionRequest(const httplib::Request& request, httplib::Response& response);
void handleGetHistoryMessagesRequest(const httplib::Request& request, httplib::Response& response);
void handleSendMessageRequest(const httplib::Request& request, httplib::Response& response);
void handleSendMessageStreamRequest(const httplib::Request& request, httplib::Response& response);
// 设置HTTP路由规则
void setHttpRoutes();
private:
ServerConfig _config; // 服务器配置
std::unique_ptr<httplib::Server> _chatServer = nullptr; // HTTP服务器
std::shared_ptr<ai_chat_sdk::ChatSDK> _chatSDK = nullptr; // 聊天SDK
std::atomic<bool> _isRunning = {false}; // 服务运行状态
};
} // end ai_chat_server
ChatServer.cpp
cpp
#include "ChatServer.h"
#include "util/myLog.h"
#include <jsoncpp/json/value.h>
#include <jsoncpp/json/reader.h>
#include <jsoncpp/json/writer.h>
#include <thread>
namespace ai_chat_server
{
ChatServer::ChatServer(const ServerConfig &config) : _config(config)
{
// 初始化日志系统(适配你的bite命名空间日志)
spdlog::level::level_enum logLevel = spdlog::level::info;
if (_config.logLevel == "DEBUG")
logLevel = spdlog::level::debug;
if (_config.logLevel == "WARN")
logLevel = spdlog::level::warn;
if (_config.logLevel == "ERROR")
logLevel = spdlog::level::err;
bite::Logger::initLogger("ChatServer", _config.logFile, logLevel);
_chatSDK = std::make_shared<ai_chat_sdk::ChatSDK>();
// ==================== 1. DeepSeek 模型配置 ====================
auto deepseekConfig = std::make_shared<ai_chat_sdk::APIConfig>();
deepseekConfig->_modelName = "deepseek-chat";
deepseekConfig->_apiKey = _config.deepseekAPIKey;
deepseekConfig->_temperature = _config.temperature;
deepseekConfig->_maxTokens = _config.maxTokens;
// ==================== 2. 豆包(火山方舟)模型配置 ====================
auto doubaoConfig = std::make_shared<ai_chat_sdk::APIConfig>();
doubaoConfig->_modelName = _config.doubaoModelName;
doubaoConfig->_apiKey = _config.doubaoAPIKey;
doubaoConfig->_temperature = _config.temperature;
doubaoConfig->_maxTokens = _config.maxTokens;
// 组装模型配置列表(仅保留两个国内稳定模型)
std::vector<std::shared_ptr<ai_chat_sdk::Config>> modelConfigs = {
deepseekConfig,
doubaoConfig};
INFO("start init ChatSDK models...");
if (!_chatSDK->initModels(modelConfigs))
{
ERR("ChatSDK init Failed!!!");
return;
}
INFO("ChatSDK models init success!!!");
// 创建HTTP服务器实例
_chatServer = std::make_unique<httplib::Server>();
if (!_chatServer)
{
ERR("ChatServer init Failed!!!");
return;
}
}
bool ChatServer::start()
{
if (_isRunning.load())
{
ERR("ChatServer is already running!!!");
return false;
}
// 注册HTTP路由
setHttpRoutes();
// 【关键】设置前端静态资源目录(www目录和可执行文件同级)
_chatServer->set_mount_point("/", "../www");
// 服务器在独立线程运行,不阻塞主线程
std::thread serverThread([this]()
{
_chatServer->listen(_config.host, _config.port);
INFO("ChatServer start on {}:{}", _config.host, _config.port); });
serverThread.detach();
_isRunning.store(true);
INFO("ChatServer start success!!!");
return true;
}
void ChatServer::stop()
{
if (!_isRunning.load())
{
ERR("ChatServer is not running!!!");
return;
}
if (_chatServer)
{
_chatServer->stop();
}
_isRunning.store(false);
INFO("ChatServer stop success!!!");
}
bool ChatServer::isRunning() const
{
return _isRunning.load();
}
// 构造统一的JSON响应
std::string ChatServer::buildResponse(const std::string &message, bool success)
{
Json::Value responseJson;
responseJson["success"] = success;
responseJson["message"] = message;
Json::StreamWriterBuilder writerBuilder;
return Json::writeString(writerBuilder, responseJson);
}
// 处理创建会话请求
void ChatServer::handleCreateSessionRequest(const httplib::Request &request, httplib::Response &response)
{
Json::Value requestJson;
Json::Reader reader;
if (!reader.parse(request.body, requestJson))
{
std::string errorStr = buildResponse("parse request body failed, json format error");
response.status = 400;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
// 解析模型名称,默认用deepseek-chat
std::string modelName = requestJson.get("model", "deepseek-chat").asString();
// 调用SDK创建会话
std::string sessionID = _chatSDK->createSession(modelName);
if (sessionID.empty())
{
std::string errorStr = buildResponse("create session failed");
response.status = 500;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
// 构造成功响应
Json::Value dataJson;
dataJson["session_id"] = sessionID;
dataJson["model"] = modelName;
Json::Value responseJson;
responseJson["success"] = true;
responseJson["message"] = "create session success";
responseJson["data"] = dataJson;
Json::StreamWriterBuilder writerBuilder;
std::string responseStr = Json::writeString(writerBuilder, responseJson);
response.status = 200;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(responseStr, "application/json");
}
// 处理获取会话列表请求
void ChatServer::handleGetSessionListsRequest(const httplib::Request &request, httplib::Response &response)
{
std::vector<std::string> sessionIDs = _chatSDK->getSessionLists();
Json::Value dataArray(Json::arrayValue);
for (const auto &sessionID : sessionIDs)
{
auto session = _chatSDK->getSession(sessionID);
if (session)
{
Json::Value sessionJson;
sessionJson["id"] = session->_sessionId;
sessionJson["model"] = session->_modelName;
sessionJson["created_at"] = static_cast<int64_t>(session->_createdAt);
sessionJson["updated_at"] = static_cast<int64_t>(session->_updatedAt);
sessionJson["message_count"] = session->_messages.size();
if (!session->_messages.empty())
{
sessionJson["first_user_message"] = session->_messages.front()._content;
}
dataArray.append(sessionJson);
}
}
Json::Value responseJson;
responseJson["success"] = true;
responseJson["message"] = "get session lists success";
responseJson["data"] = dataArray;
Json::StreamWriterBuilder writerBuilder;
std::string responseStr = Json::writeString(writerBuilder, responseJson);
response.status = 200;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(responseStr, "application/json");
}
// 处理获取模型列表请求
void ChatServer::handleGetModelListsRequest(const httplib::Request &request, httplib::Response &response)
{
auto modelLists = _chatSDK->getAvailableModels();
Json::Value dataArray(Json::arrayValue);
for (const auto &modelInfo : modelLists)
{
Json::Value modelJson;
modelJson["name"] = modelInfo._modelName;
modelJson["desc"] = modelInfo._modelDesc;
dataArray.append(modelJson);
}
Json::Value responseJson;
responseJson["success"] = true;
responseJson["message"] = "get model lists success";
responseJson["data"] = dataArray;
Json::StreamWriterBuilder writerBuilder;
std::string responseStr = Json::writeString(writerBuilder, responseJson);
response.status = 200;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(responseStr, "application/json");
}
// 处理删除会话请求
void ChatServer::handleDeleteSessionRequest(const httplib::Request &request, httplib::Response &response)
{
// 从路径中获取会话ID
std::string sessionId = request.matches[1];
bool ret = _chatSDK->deleteSession(sessionId);
if (ret)
{
std::string successStr = buildResponse("delete session success", true);
response.status = 200;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(successStr, "application/json");
}
else
{
std::string errorStr = buildResponse("delete session failed, session not found");
response.status = 404;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
}
}
// 处理获取历史消息请求
void ChatServer::handleGetHistoryMessagesRequest(const httplib::Request &request, httplib::Response &response)
{
std::string sessionId = request.matches[1];
auto session = _chatSDK->getSession(sessionId);
if (!session)
{
std::string errorStr = buildResponse("session not found");
response.status = 404;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
Json::Value dataArray(Json::arrayValue);
for (const auto &message : session->_messages)
{
Json::Value messageJson;
messageJson["id"] = message._messageId;
messageJson["role"] = message._role;
messageJson["content"] = message._content;
messageJson["timestamp"] = static_cast<int64_t>(message._timestamp);
dataArray.append(messageJson);
}
Json::Value responseJson;
responseJson["success"] = true;
responseJson["message"] = "get history messages success";
responseJson["data"] = dataArray;
Json::StreamWriterBuilder writerBuilder;
std::string responseStr = Json::writeString(writerBuilder, responseJson);
response.status = 200;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(responseStr, "application/json");
}
// 处理发送消息请求-全量返回
void ChatServer::handleSendMessageRequest(const httplib::Request &request, httplib::Response &response)
{
Json::Value requestJson;
Json::Reader reader;
if (!reader.parse(request.body, requestJson))
{
std::string errorStr = buildResponse("parse request body failed, json format error");
response.status = 400;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
std::string sessionId = requestJson["session_id"].asString();
std::string message = requestJson["message"].asString();
if (sessionId.empty() || message.empty())
{
std::string errorStr = buildResponse("session_id or message is empty");
response.status = 400;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
// 调用SDK发送消息,全量获取回复
std::string assistantMessage = _chatSDK->sendMessage(sessionId, message);
if (assistantMessage.empty())
{
std::string errorStr = buildResponse("Failed to get AI response");
response.status = 500;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
Json::Value dataJson;
dataJson["session_id"] = sessionId;
dataJson["response"] = assistantMessage;
dataJson["assistant_message"] = assistantMessage;
Json::Value responseJson;
responseJson["success"] = true;
responseJson["message"] = "send message success";
responseJson["data"] = dataJson;
Json::StreamWriterBuilder writerBuilder;
std::string responseStr = Json::writeString(writerBuilder, responseJson);
response.status = 200;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(responseStr, "application/json");
}
// 处理发送消息请求-流式增量返回(SSE协议)
void ChatServer::handleSendMessageStreamRequest(const httplib::Request &request, httplib::Response &response)
{
Json::Value requestJson;
Json::Reader reader;
if (!reader.parse(request.body, requestJson))
{
std::string errorStr = buildResponse("parse request body failed, json format error");
response.status = 400;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
std::string sessionId = requestJson["session_id"].asString();
std::string message = requestJson["message"].asString();
if (sessionId.empty() || message.empty())
{
std::string errorStr = buildResponse("session_id or message is empty");
response.status = 400;
// ========== 新增:跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
response.set_content(errorStr, "application/json");
return;
}
// 配置SSE流式响应头
response.status = 200;
response.set_header("Cache-Control", "no-cache");
response.set_header("Connection", "keep-alive");
// ========== 补充完整跨域头 ==========
response.set_header("Access-Control-Allow-Origin", "*");
response.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
response.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
// 分块流式响应
response.set_chunked_content_provider("text/event-stream", [this, sessionId, message](size_t offset, httplib::DataSink &dataSink) -> bool
{
// 流式写入回调
auto writeChunk = [&](const std::string& chunk, bool last){
std::string sseData = "data: " + Json::valueToQuotedString(chunk.c_str()) + "\n\n";
dataSink.write(sseData.c_str(), sseData.size());
if(last){
// 流式结束标记
std::string doneData = "data: [DONE]\n\n";
dataSink.write(doneData.c_str(), doneData.size());
dataSink.done();
return false;
}
return true;
};
// 先发送空包,避免客户端超时
if (!writeChunk("", false)) {
return false;
}
// 调用SDK流式发送消息
_chatSDK->sendMessageStream(sessionId, message, writeChunk);
return false; });
}
// 注册所有HTTP路由
void ChatServer::setHttpRoutes()
{
// ========== 删掉不兼容的Use和Options中间件 ==========
// 处理OPTIONS预检请求(浏览器跨域必备)
_chatServer->Options(".*", [](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
res.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
res.status = 200; });
// 会话管理接口
_chatServer->Post("/api/session", [this](const httplib::Request &req, httplib::Response &res)
{ handleCreateSessionRequest(req, res); });
_chatServer->Get("/api/sessions", [this](const httplib::Request &req, httplib::Response &res)
{ handleGetSessionListsRequest(req, res); });
_chatServer->Delete("/api/session/(.*)", [this](const httplib::Request &req, httplib::Response &res)
{ handleDeleteSessionRequest(req, res); });
_chatServer->Get("/api/session/(.*)/history", [this](const httplib::Request &req, httplib::Response &res)
{ handleGetHistoryMessagesRequest(req, res); });
// 模型接口
_chatServer->Get("/api/models", [this](const httplib::Request &req, httplib::Response &res)
{ handleGetModelListsRequest(req, res); });
// 聊天消息接口
_chatServer->Post("/api/message", [this](const httplib::Request &req, httplib::Response &res)
{ handleSendMessageRequest(req, res); });
_chatServer->Post("/api/message/async", [this](const httplib::Request &req, httplib::Response &res)
{ handleSendMessageStreamRequest(req, res); });
}
} // end ai_chat_server
main.cpp
cpp
#include "ChatServer.h"
#include <gflags/gflags.h>
#include <iostream>
#include <fstream>
#include <cstdlib>
#include <string>
#include <stdexcept>
#include <chrono>
#include <thread>
#include <spdlog/common.h>
// 匹配你项目的日志头文件
#include <util/myLog.h>
// 核心配置参数 - 补充豆包模型名称配置
DEFINE_string(host, "0.0.0.0", "服务器绑定地址,腾讯云部署请保持0.0.0.0确保外网可访问");
DEFINE_int32(port, 8080, "服务器绑定端口,需与前端API_BASE_URL的端口保持一致");
DEFINE_string(log_level, "INFO", "日志级别:TRACE/DEBUG/INFO/WARN/ERROR/CRITICAL");
DEFINE_double(temperature, 0.7, "AI生成温度,范围0.0-2.0");
DEFINE_int32(max_tokens, 2048, "AI生成最大Token数");
DEFINE_string(config_file, "./ChatServer.conf", "配置文件路径");
// 【核心新增】豆包模型名称配置(前端展示/选择的核心标识)
DEFINE_string(doubao_model_name, "doubao-seed-2-0-lite-260215", "豆包模型名称,前端展示的模型标识");
// 【核心新增】DeepSeek模型名称(统一标识,避免硬编码)
DEFINE_string(deepseek_model_name, "deepseek-chat", "DeepSeek模型名称,前端展示的模型标识");
// 版本号
const std::string VERSION = "1.0.0";
// 从环境变量获取API密钥 - 补充豆包
std::string getEnvVar(const std::string& key) {
char* value = std::getenv(key.c_str());
return value ? std::string(value) : "";
}
// 配置参数校验 - 支持 DeepSeek 或 豆包 二选一
bool validateConfig(ai_chat_server::ServerConfig& config) {
// 校验温度值范围
if (config.temperature < 0.0 || config.temperature > 2.0) {
ERR("配置错误:温度值必须在0.0-2.0之间,当前值:{}", config.temperature);
return false;
}
// 校验最大Token数
if (config.maxTokens <= 0) {
ERR("配置错误:最大Token数必须为正数,当前值:{}", config.maxTokens);
return false;
}
// 校验:DeepSeek 或 豆包 至少有一个API密钥
if (config.deepseekAPIKey.empty() && config.doubaoAPIKey.empty()) {
ERR("配置错误:请至少配置一个API密钥(DeepSeek 或 豆包)");
return false;
}
return true;
}
// 帮助与接口说明
void showAPIInfo() {
std::cout << "\n========================================\n";
std::cout << " ChatServer API 接口说明\n";
std::cout << "========================================\n";
std::cout << " POST /api/session - 创建新会话(传model参数指定模型)\n";
std::cout << " GET /api/sessions - 获取所有会话列表\n";
std::cout << " GET /api/models - 获取可用模型列表(DeepSeek+豆包)\n";
std::cout << " DELETE /api/session/{session_id} - 删除指定会话\n";
std::cout << " GET /api/session/{session_id}/history - 获取会话历史\n";
std::cout << " POST /api/message - 发送消息(全量返回)\n";
std::cout << " POST /api/message/async - 发送消息(流式返回)\n";
std::cout << "\n【模型使用说明】\n";
std::cout << " 创建会话时传model参数:\n";
std::cout << " - DeepSeek:model=deepseek-chat\n";
std::cout << " - 豆包:model=doubao-chat\n";
std::cout << "\n【腾讯云部署必看】\n";
std::cout << " 1. 已默认绑定0.0.0.0,无需修改host参数\n";
std::cout << " 2. 请在腾讯云控制台安全组放开8080端口入站权限\n";
std::cout << " 3. 前端API_BASE_URL需修改为:http://你的腾讯云公网IP:8080\n";
std::cout << " 4. 启动前需配置至少一个API密钥:\n";
std::cout << " - DeepSeek:export deepseek_apikey=\"你的真实密钥\"\n";
std::cout << " - 豆包:export doubao_apikey=\"你的真实密钥\"\n";
std::cout << "========================================\n\n";
}
// 主函数入口
int main(int argc, char** argv) {
try {
// 帮助信息处理
if (argc == 2 && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help")) {
showAPIInfo();
return 0;
}
// 解析gflags命令行参数
gflags::SetUsageMessage("AIChatServer - DeepSeek & 豆包 AI聊天服务器\n用法: ./AIChatServer [options]");
gflags::ParseCommandLineFlags(&argc, &argv, true);
gflags::SetVersionString(VERSION);
// 加载配置文件(如果存在)
std::ifstream configFile(FLAGS_config_file);
if (configFile) {
gflags::SetCommandLineOption("flagfile", FLAGS_config_file.c_str());
INFO("已加载配置文件:{}", FLAGS_config_file);
}
// 构建服务配置对象 - 补充豆包模型名称
ai_chat_server::ServerConfig config;
config.host = FLAGS_host;
config.port = FLAGS_port;
config.logLevel = FLAGS_log_level;
config.temperature = FLAGS_temperature;
config.maxTokens = FLAGS_max_tokens;
// 获取 DeepSeek 和 豆包 的API密钥
config.deepseekAPIKey = getEnvVar("deepseek_apikey");
config.doubaoAPIKey = getEnvVar("doubao_apikey");
// 【核心新增】赋值模型名称(ChatServer需要这个值来注册模型)
config.deepseekModelName = FLAGS_deepseek_model_name;
config.doubaoModelName = FLAGS_doubao_model_name;
// 配置校验
if (!validateConfig(config)) {
ERR("配置校验失败,服务终止启动");
return 1;
}
// 日志级别初始化
spdlog::level::level_enum logLevel = spdlog::level::info;
if (config.logLevel == "TRACE") logLevel = spdlog::level::trace;
else if (config.logLevel == "DEBUG") logLevel = spdlog::level::debug;
else if (config.logLevel == "INFO") logLevel = spdlog::level::info;
else if (config.logLevel == "WARN" || config.logLevel == "WARNING") logLevel = spdlog::level::warn;
else if (config.logLevel == "ERROR") logLevel = spdlog::level::err;
else if (config.logLevel == "CRITICAL") logLevel = spdlog::level::critical;
// 初始化日志组件,匹配你项目原有逻辑
bite::Logger::initLogger("ChatServer", "stdout", logLevel);
// 打印启动配置,方便排查问题
INFO("========================================");
INFO(" ChatServer 启动配置");
INFO("========================================");
INFO(" 服务版本:{}", VERSION);
INFO(" 监听地址:{}", config.host);
INFO(" 监听端口:{}", config.port);
INFO(" 日志级别:{}", config.logLevel);
INFO(" 生成温度:{}", config.temperature);
INFO(" 最大Token:{}", config.maxTokens);
INFO(" DeepSeek 模型名称:{}", config.deepseekModelName);
INFO(" DeepSeek API密钥:{}", config.deepseekAPIKey.empty() ? "未设置" : "已配置");
INFO(" 豆包 模型名称:{}", config.doubaoModelName);
INFO(" 豆包 API密钥:{}", config.doubaoAPIKey.empty() ? "未设置" : "已配置");
INFO("========================================");
// 启动ChatServer服务
ai_chat_server::ChatServer server(config);
if (server.start()) {
INFO("ChatServer 启动成功!");
INFO("服务访问地址:http://{}:{}", config.host, config.port);
INFO("前端请配置API_BASE_URL为:http://你的腾讯云公网IP:{}", config.port);
INFO("前端可选择模型:{}(DeepSeek)、{}(豆包)", config.deepseekModelName, config.doubaoModelName);
// 主线程挂起,保持服务持续运行
while (server.isRunning()) {
std::this_thread::sleep_for(std::chrono::seconds(10));
}
} else {
ERR("ChatServer 启动失败!请检查端口是否被占用,可执行 lsof -i :8080 查看");
return 1;
}
return 0;
} catch (const std::exception& e) {
ERR("服务运行异常:{}", e.what());
return 1;
} catch (...) {
ERR("服务发生未知异常");
return 1;
}
}
下面我们来对本部分代码进行深度解析:
头文件设计:ChatServer.h(服务生命周期与接口定义)
cpp
// 服务器配置结构体:统一管理所有配置
struct ServerConfig {
std::string host; // 监听地址
int port; // 监听端口
std::string logLevel; // 日志级别
std::string logFile; // 日志文件
double temperature; // 生成温度
int maxTokens; // 最大Token数
std::string deepseekAPIKey;// DeepSeek API密钥
std::string doubaoAPIKey; // 豆包 API密钥
std::string deepseekModelName; // DeepSeek模型名称
std::string doubaoModelName; // 豆包模型名称
};
class ChatServer{
public:
ChatServer(const ServerConfig& config); // 构造:初始化SDK+HTTP服务
bool start(); // 启动服务
void stop(); // 停止服务
bool isRunning()const; // 检查状态
private:
// 【接口处理函数】对应每个HTTP API
void handleCreateSessionRequest(...); // 创建会话
void handleGetSessionListsRequest(...); // 获取会话列表
void handleGetModelListsRequest(...); // 获取模型列表
void handleDeleteSessionRequest(...); // 删除会话
void handleGetHistoryMessagesRequest(...); // 获取历史消息
void handleSendMessageRequest(...); // 非流式消息
void handleSendMessageStreamRequest(...); // 流式消息(核心)
// 【路由注册】设置HTTP路由规则
void setHttpRoutes();
private:
ServerConfig _config; // 服务配置
std::unique_ptr<httplib::Server> _chatServer; // HTTP服务器(httplib)
std::shared_ptr<ai_chat_sdk::ChatSDK> _chatSDK; // 【核心】ChatSDK智能指针
std::atomic<bool> _isRunning; // 服务运行状态(原子变量,线程安全)
};
设计思路:
- **配置与逻辑分离:**ServerConfig结构体统一管理所有配置,避免硬编码;
- 职责清晰的接口分层:每个 HTTP API 对应一个
handle函数,代码可读性极高; - **与 SDK 的核心联系:**通过std::shared_ptr<ChatSDK> _chatSDK持有 SDK 实例,所有核心能力都通过 SDK 实现,ChatServer 只做协议适配。
核心实现拆解:ChatServer.cpp(与 SDK 的深度协作)
这里我们重点来看一下与SDK紧密联系的部分:
构造函数:SDK 的初始化(ChatServer 与 SDK 的第一次握手)
cpp
ChatServer::ChatServer(const ServerConfig &config) : _config(config)
{
// 步骤1:初始化日志系统(复用SDK的日志库)
bite::Logger::initLogger("ChatServer", _config.logFile, logLevel);
// 步骤2:【核心】创建ChatSDK实例
_chatSDK = std::make_shared<ai_chat_sdk::ChatSDK>();
// 步骤3:配置双厂商模型(DeepSeek + 豆包)
auto deepseekConfig = std::make_shared<ai_chat_sdk::APIConfig>();
deepseekConfig->_modelName = "deepseek-chat";
deepseekConfig->_apiKey = _config.deepseekAPIKey;
deepseekConfig->_temperature = _config.temperature;
deepseekConfig->_maxTokens = _config.maxTokens;
auto doubaoConfig = std::make_shared<ai_chat_sdk::APIConfig>();
doubaoConfig->_modelName = _config.doubaoModelName;
doubaoConfig->_apiKey = _config.doubaoAPIKey;
doubaoConfig->_temperature = _config.temperature;
doubaoConfig->_maxTokens = _config.maxTokens;
// 步骤4:【核心】调用SDK的initModels初始化所有模型
std::vector<std::shared_ptr<ai_chat_sdk::Config>> modelConfigs = {deepseekConfig, doubaoConfig};
if (!_chatSDK->initModels(modelConfigs))
{
ERR("ChatSDK init Failed!!!");
return;
}
// 步骤5:创建httplib HTTP服务器实例
_chatServer = std::make_unique<httplib::Server>();
}
与SDK的联系:
- 完全复用SDK中的ChatSDK、APIConfig类,无需重复定义;
- 通过
_chatSDK->initModels(modelConfigs)完成 SDK 的核心初始化,ChatServer 不直接接触 Provider、DataManager 等底层细节。
流式消息接口:handleSendMessageStreamRequest(核心中的核心)
这是 ChatServer 最复杂的部分,完美实现了 "SDK 流式回调 → HTTP SSE 协议" 的适配:
cpp
void ChatServer::handleSendMessageStreamRequest(...)
{
// ... 前置校验(解析session_id、message,略)...
// 【关键1】设置SSE流式响应的HTTP头
response.status = 200;
response.set_header("Cache-Control", "no-cache");
response.set_header("Connection", "keep-alive");
response.set_header("Access-Control-Allow-Origin", "*"); // 跨域
// 【关键2】httplib的分块内容提供者:实现流式写入
response.set_chunked_content_provider("text/event-stream",
[this, sessionId, message](size_t offset, httplib::DataSink &dataSink) -> bool {
// 定义流式写入回调:把SDK的回调内容,封装成SSE协议格式
auto writeChunk = [&](const std::string& chunk, bool last){
// SSE协议格式:data: 内容\n\n
std::string sseData = "data: " + Json::valueToQuotedString(chunk.c_str()) + "\n\n";
dataSink.write(sseData.c_str(), sseData.size());
if(last){
// SSE结束标记:data: [DONE]\n\n
std::string doneData = "data: [DONE]\n\n";
dataSink.write(doneData.c_str(), doneData.size());
dataSink.done();
return false;
}
return true;
};
// 先发送空包,避免客户端超时
if (!writeChunk("", false)) return false;
// 【关键3】调用SDK的sendMessageStream,传入writeChunk作为回调
_chatSDK->sendMessageStream(sessionId, message, writeChunk);
return false;
});
}
与 SDK 的联系:
- SDK 的
sendMessageStream需要一个std::function<void(const std::string&, bool)>回调; - ChatServer 在这里定义了writeChunk作为回调,把 SDK 返回的增量内容,封装成 SSE 协议格式(data: 内容\n\n),通过dataSink.write发送给前端;
- 完美实现了 "SDK 内部流式解析 → ChatServer 协议适配 → 前端实时展示" 的全链路流式响应。
路由注册:setHttpRoutes(RESTful API 设计)
cpp
void ChatServer::setHttpRoutes()
{
// 【跨域处理】OPTIONS预检请求(浏览器跨域必备)
_chatServer->Options(".*", [](const httplib::Request &req, httplib::Response &res) {
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
res.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS");
res.status = 200;
});
// 【会话管理接口】
_chatServer->Post("/api/session", ...); // 创建会话
_chatServer->Get("/api/sessions", ...); // 获取会话列表
_chatServer->Delete("/api/session/(.*)", ...); // 删除会话
_chatServer->Get("/api/session/(.*)/history", ...); // 获取历史消息
// 【模型接口】
_chatServer->Get("/api/models", ...); // 获取可用模型列表
// 【聊天消息接口】
_chatServer->Post("/api/message", ...); // 非流式
_chatServer->Post("/api/message/async", ...); // 流式
}
设计思路:
- 标准的 RESTful API 设计:GET 查询、POST 创建、DELETE 删除;
- 完善的跨域处理:OPTIONS 请求单独处理,所有响应都加跨域头,确保前端可以正常访问;
- 静态资源挂载:在
start方法里通过_chatServer->set_mount_point("/", "../www")挂载前端页面,访问根路径即可加载前端。
main.cpp:服务启动入口(配置管理与生命周期控制)
cpp
int main(int argc, char** argv) {
// 步骤1:解析gflags命令行参数(支持--help、--config_file等)
gflags::ParseCommandLineFlags(&argc, &argv, true);
// 步骤2:从环境变量获取API密钥(安全,不硬编码)
config.deepseekAPIKey = getEnvVar("deepseek_apikey");
config.doubaoAPIKey = getEnvVar("doubao_apikey");
// 步骤3:配置校验(温度范围、Token数、至少一个API密钥)
if (!validateConfig(config)) return 1;
// 步骤4:初始化日志
bite::Logger::initLogger("ChatServer", "stdout", logLevel);
// 步骤5:创建ChatServer实例并启动
ai_chat_server::ChatServer server(config);
if (server.start()) {
// 主线程挂起,保持服务持续运行
while (server.isRunning()) {
std::this_thread::sleep_for(std::chrono::seconds(10));
}
}
}
设计思路:
- gflags 配置管理:支持命令行参数、配置文件,灵活度极高;
- 环境变量管理敏感信息:API Key 从环境变量读取,不硬编码在代码里,符合企业级安全规范;
- 完善的配置校验:启动前校验参数合法性,避免服务启动后出错;
- 主线程挂起 :HTTP 服务在独立线程运行,主线程通过
while循环挂起,保持服务持续运行。
5. 项目亮点与总结
架构设计:分层解耦,设计模式完美落地
项目采用 "对外接口层 - 核心控制层 - 厂商适配层 - 基础支撑层" 四层模块化结构,层间职责单一,完全解耦,是面向对象的经典实践:
- 门面模式:通过ChatSDK统一门面类对外提供极简接口,3 行核心代码即可完成大模型接入,完全屏蔽内部实现复杂度;
- 策略模式:通过LLMProvider抽象基类定义统一接口,不同厂商大模型做独立子类实现,新增厂商仅需新增子类、无需修改上层任何代码,严格遵循开闭原则;
- 职责分离:会话管理、数据持久化、模型调度、厂商适配完全独立,单个模块的修改不会影响其他模块,可维护性与可扩展性拉满。
工程化:工业级 C++ 开发规范,细节拉满
项目全程遵循企业级 C++ 开发规范,规避了各类常见的内存、线程安全问题,核心工程化亮点如下:
- RAII 资源管理:通过构造 / 析构函数管理数据库连接、锁、HTTP 句柄等资源,自动申请与释放,完全杜绝资源泄漏;
- **智能指针的正确使用:**用std::unique_ptr管理 Provider、HTTP 服务等独占资源,通过移动语义转移所有权,无额外性能开销;用std::shared_ptr管理 SDK、会话等共享生命周期的对象,保证内存安全;
- const 正确性与线程安全:用
mutable修饰互斥锁,保证 const 只读函数也能加锁实现线程安全;采用极致精细化的锁粒度控制,仅在内存操作时加锁,大模型 API 等慢 IO 操作完全不加锁,最大化并发性能; - 配置驱动与安全设计:通过 gflags 支持命令行参数、配置文件双模式配置,API 密钥等敏感信息通过环境变量读取,完全杜绝硬编码带来的安全风险;
- 完善的错误处理:全链路前置参数校验、异常场景兜底、分级日志体系,问题可追溯、可排查,鲁棒性极强。
核心能力:全场景覆盖,性能与体验兼顾
项目实现了大模型接入的全场景核心能力,兼顾了功能完整性与用户体验:
- 跨厂商统一接入:无缝支持 DeepSeek、字节豆包等主流大模型,一套接口适配所有厂商,彻底解决业务层多模型重复适配的痛点;
- 全量 + 流式双模式对话:既支持同步全量返回,也基于 SSE 协议实现流式增量返回,完美实现前端 "打字机" 效果,大幅提升用户体验;
- 会话全生命周期管理:支持多会话创建、查询、删除、历史消息持久化,通过外键级联删除保证数据一致性,通过参数化查询完全杜绝 SQL 注入风险;
- 数据持久化:基于 SQLite 实现本地数据存储,程序重启后会话、历史消息不丢失,同时通过内存 + 数据库双层存储优化高频访问性能。
质量保障:完整的测试体系
项目配套完整的测试体系,从开发前置验证到全流程集成测试,形成了完整的质量闭环。
- 前置验证:开发数据持久化层前,通过独立的 SQLite Demo 完成 CRUD 操作验证,降低核心代码开发风险;
- 全流程集成测试:基于 GTest 框架设计测试用例,覆盖非流式对话、流式对话两大核心场景,同时验证双厂商模型的适配能力、会话管理、数据持久化全链路;
- 回归测试能力:测试用例可重复执行,每次代码修改后均可快速验证,杜绝新代码引入线上 bug;
- 边界场景全覆盖:测试用例覆盖空参数、空响应、API 异常、模型不可用等各类边界场景,保证 SDK 在异常场景下的稳定性。
完整落地:从 SDK 到可部署产品的全链路闭环
项目不止于 SDK 组件开发,通过ChatServer应用层完成了服务化落地,形成了可直接部署使用的完整产品:
- RESTful API 标准化封装:将 SDK 能力封装为标准的 HTTP 接口,支持会话管理、模型查询、消息发送全流程接口,适配前端与第三方业务系统接入;
- **HTTP 流式响应适配:**通过分块内容提供者,将 SDK 内部的流式回调适配为标准 SSE 协议,完美支持前端流式展示;
- 完善的跨域处理:兼容浏览器 OPTIONS 预检请求,全接口支持跨域访问,前端接入无阻碍;
- 一键部署能力:支持静态前端页面挂载,配置好 API 密钥后即可一键启动服务,开箱即用。
各模块极简回顾
整个项目的核心模块职责清晰、环环相扣,形成了完整的技术链路:
- 门面入口层 :
ChatSDK,对外提供极简接入接口,对内串联所有核心模块; - 会话管理层 :
SessionManager,负责多会话全生命周期管理、对话上下文维护、无效消息过滤; - 数据持久化层 :
DataManager,基于 SQLite 实现会话与消息的增删改查,保证数据安全与一致性; - 调度分发层 :
LLMManager,策略模式容器,负责 Provider 的注册管理,根据模型名动态路由请求; - 厂商适配层 :
LLMProvider子类,屏蔽不同厂商 API 的差异,实现 HTTP 请求、SSE 解析、JSON 序列化 / 反序列化; - 质量保障层:基于 GTest 的集成测试体系,全流程验证 SDK 能力;
- 服务应用层 :
ChatServer,将 SDK 能力 HTTP 服务化,挂载前端页面,形成完整的对话系统。
后续项目可扩展方向
- 新增更多主流大模型厂商适配(通义千问、文心一言、Claude 等);
- 支持函数调用、RAG 检索增强生成等高级能力;
- 新增 Prometheus 监控指标、健康检查接口,适配线上容器化部署;
- 支持多节点负载均衡,满足高并发业务场景需求。
最后附上本项目的源码链接:gitee源码链接