【LangChain】RunnableWithMessageHistory 完全指南(下):流式、截断与自定义

RunnableWithMessageHistory 完全指南(下):流式、截断与自定义

本系列文章:

在上篇中,我们掌握了 RunnableWithMessageHistory 的基础用法------如何包装链、管理 Session、以及底层的工作原理。但生产环境中的挑战远不止于此:

  • 用户要求实时看到 AI 的回复,流式输出怎么搞?
  • 对话历史越来越长,Token 费用爆炸怎么办?
  • 公司有审计合规要求,需要自定义存储逻辑?

本文将深入这三个核心场景,并给出一套可直接投入生产的完整代码模板。


一、流式输出(Streaming)深度解析

1.1 调用链时序图

流式输出是提升用户体验的关键,但 RunnableWithMessageHistory 的流式行为有一些"反直觉"的设计,理解它至关重要。

复制代码
用户请求 ──► [RunnableWithMessageHistory]
                │
                ├──► 1. 调用 get_session_history(session_id)
                │       └── 从 Redis/DB 加载历史消息
                │
                ├──► 2. 将历史注入 Prompt
                │
                ├──► 3. 调用底层链的 astream()
                │       └── Token 1 ──► Token 2 ──► ... ──► Token N
                │           ▲
                │           └── 实时推送给用户
                │
                └──► 4. 流结束后,将完整 AI 回复写入历史
                        └── 保存到 Redis/DB

关键洞察:历史写入的时机

python 复制代码
# 这是 RunnableWithMessageHistory 的核心逻辑(简化版)
async def astream(self, input, config):
    history = await self.get_session_history(config["configurable"]["session_id"])
    messages = history.messages + [HumanMessage(input)]

    # 流式输出------此时历史还未写入!
    full_response = []
    async for chunk in self.runnable.astream({"messages": messages}):
        full_response.append(chunk)
        yield chunk  # 实时推送给用户

    # ⚠️ 只有在流完全结束后,才写入历史
    history.add_messages([
        HumanMessage(input),
        AIMessage("".join(full_response))
    ])

为什么必须在流结束后写入?

因为流式输出的特点是逐字返回,在流完成前我们无法确定 AI 的完整回复内容。如果中途写入,可能出现:

  1. 消息不完整:用户只收到了"我认为这个方案...",但 AI 实际说完了"我认为这个方案不可行,因为..."
  2. 历史与用户体验不一致:用户看到的和保存的不匹配

1.2 中断场景处理

场景:用户中途关闭浏览器连接

python 复制代码
from fastapi import FastAPI, WebSocket
from langchain_core.runnables.history import RunnableWithMessageHistory

app = FastAPI()

@app.websocket("/chat/{session_id}")
async def websocket_chat(websocket: WebSocket, session_id: str):
    await websocket.accept()

    chain_with_history = RunnableWithMessageHistory(
        chain,
        get_session_history,
        input_messages_key="question",
        history_messages_key="history",
    )

    try:
        async for chunk in chain_with_history.astream(
            {"question": await websocket.receive_text()},
            config={"configurable": {"session_id": session_id}}
        ):
            await websocket.send_text(chunk.content)
    except WebSocketDisconnect:
        # ⚠️ 注意:此时历史不会被保存!
        # 因为 astream() 的 generator 被中断了,第4步(历史写入)永远不会执行
        logger.info(f"用户断开连接,session={session_id},历史未保存")

为什么这是对的?

假设用户问了一个问题,AI 回答到一半用户就离开了。如果保存这段不完整的历史

复制代码
用户:帮我写个 Python 函数计算斐波那契数列
AI:好的,这是一个递归实现(用户此时断开)

下次用户回来,历史里会有一条不完整的 AI 回复,导致后续对话质量下降。不保存是更安全的默认行为。

如果需要"断点续传",如何自定义?

如果你确实需要保存(比如客服场景,即使中断也要记录),可以包装流式输出:

python 复制代码
class InterruptSafeHistory(RunnableWithMessageHistory):
    """支持中断保存的自定义实现"""

    async def astream(self, input, config, **kwargs):
        session_id = config["configurable"]["session_id"]
        history = await self.get_session_history(session_id)

        # 先保存用户消息
        human_msg = HumanMessage(content=input["question"])
        history.add_message(human_msg)

        full_response = []
        try:
            async for chunk in self.runnable.astream(input, config, **kwargs):
                full_response.append(chunk)
                yield chunk
        finally:
            # 无论是否中断,都保存已生成的内容
            if full_response:
                ai_content = "".join([c.content for c in full_response])
                history.add_message(AIMessage(content=ai_content))
                logger.info(f"中断保存:已生成 {len(ai_content)} 字符")

⚠️ 权衡:这种方案会保存不完整回复,需根据业务场景决定是否使用。

1.3 流式输出的性能优化

异步工厂函数

python 复制代码
import aioredis

async def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """异步工厂:非阻塞加载历史"""
    redis = aioredis.from_url("redis://localhost")

    # 异步读取,不会阻塞事件循环
    raw_messages = await redis.lrange(f"chat:{session_id}", 0, -1)

    messages = []
    for raw in raw_messages:
        data = json.loads(raw)
        if data["type"] == "human":
            messages.append(HumanMessage(content=data["content"]))
        else:
            messages.append(AIMessage(content=data["content"]))

    return RedisChatMessageHistory(messages, redis, session_id)

非阻塞历史加载的关键点:

优化项 同步写法(❌) 异步写法(✅)
Redis 读取 redis.lrange() 阻塞 await redis.lrange() 非阻塞
数据库查询 session.query() 阻塞 await session.execute() 非阻塞
文件读取 open().read() 阻塞 aiofiles.open() 非阻塞

在高并发场景下,同步加载历史会导致事件循环阻塞,其他用户的请求被挂起。务必使用异步存储客户端。


二、历史消息截断(Trimming)实战

2.1 问题场景:Token 爆炸的代价

来看一个真实案例:

python 复制代码
# 假设每轮对话平均消耗:
# - 用户消息:100 tokens
# - AI 回复:300 tokens
# - 系统提示:200 tokens(固定)

# 10 轮对话后的 Token 消耗:
round_1  = 200 + 100 + 300 = 600   # 系统 + 用户 + AI
round_2  = 600 + 100 + 300 = 1000  # 历史 + 新消息
round_3  = 1000 + 100 + 300 = 1400
# ...
round_10 = 600 + 9*400 = 4200      # 线性增长!

# GPT-4 定价:$0.03 / 1K tokens(输入)
# 第 10 轮单次请求成本:4200 * 0.03 / 1000 = $0.126
# 如果每天 1000 次对话:$126/天,$3780/月 ------ 仅历史消息部分!

Token 消耗曲线:

复制代码
Tokens
  │
4k├                              ●───── 不截断(线性增长)
  │                         ●─────
3k├                    ●─────
  │               ●─────
2k├          ●─────              ○───── 截断到 2k(稳定)
  │     ●─────              ○─────
1k├●─────              ○─────
  │○─────         ○─────
  └────────────────────────────────────
   1   2   3   4   5   6   7   8   9   10  轮次

2.2 trim_messages 详解

LangChain 提供了 trim_messages 工具,但参数众多,容易踩坑。

python 复制代码
from langchain_core.messages import trim_messages
from langchain_openai import ChatOpenAI

# 基础用法
trimmer = trim_messages(
    max_tokens=2000,           # 保留的最大 token 数
    strategy="last",           # "last" 保留最新消息,"first" 保留最早
    include_system=True,       # 是否保留系统消息
    allow_partial=False,       # 是否允许截断单条消息
    start_on="human",          # 从哪类消息开始计数
    token_counter=ChatOpenAI(model="gpt-4").get_num_tokens_from_messages,
)

参数逐个拆解:

参数 取值 含义
strategy "last" 保留最新的消息,丢弃旧的(最常用)
"first" 保留最早的消息,丢弃新的(罕见场景)
include_system True 系统消息不计入 max_tokens,始终保留
False 系统消息参与截断,可能被丢弃(❌ 不推荐)
allow_partial False 不允许截断单条消息,保持消息完整性
True 允许截断单条消息(极少使用)
start_on "human" 从人类消息开始保留(跳过开头的 AI 消息)
"assistant" 从 AI 消息开始

include_system=True 的重要性:

python 复制代码
# ❌ 错误:系统消息被截断
trimmer_bad = trim_messages(
    max_tokens=100,
    strategy="last",
    include_system=False,  # 系统消息参与竞争
)

messages = [
    SystemMessage("你是一个专业的客服助手..."),  # 30 tokens
    HumanMessage("你好"),                           # 2 tokens
    AIMessage("您好!有什么可以帮您?"),              # 10 tokens
    # ... 很多轮后
]

# 当历史很长时,系统消息可能被挤掉!
# 结果:AI 忘记了自己的角色设定

# ✅ 正确:系统消息始终保留
trimmer_good = trim_messages(
    max_tokens=100,
    strategy="last",
    include_system=True,  # 系统消息"免死金牌"
)

token_counter 的坑:不同模型计数差异

python 复制代码
# 同一个文本,不同模型的 token 数可能不同!
text = "你好,世界"

# OpenAI GPT-4
gpt4_tokens = tiktoken.encoding_for_model("gpt-4").encode(text)
# 结果:6 tokens(中文每个字约 2 tokens)

# Claude
claude_tokens = anthropic.count_tokens(text)
# 结果:可能不同!

# 本地模型(如 Llama)
llama_tokens = tokenizer.encode(text)
# 结果:取决于 tokenizer

# ✅ 最佳实践:使用与目标模型一致的计数器
def get_token_counter(model_name: str):
    if "gpt" in model_name:
        enc = tiktoken.encoding_for_model(model_name)
        return lambda messages: sum(len(enc.encode(m.content)) for m in messages)
    elif "claude" in model_name:
        return lambda messages: anthropic.count_tokens(
            "".join(m.content for m in messages)
        )
    # ...

2.3 三种截断策略对比

策略 适用场景 代码示例 优缺点
按 Token 截断 通用首选,精确控制成本 trim_messages ✅ 精确控制成本 ❌ 需要 token 计数器
按消息数截断 简单场景,快速实现 history.messages[-10:] ✅ 简单直观 ❌ 消息长度不均,token 波动大
按时间截断 客服、咨询场景 过滤 24 小时内的消息 ✅ 符合业务逻辑 ❌ 需要额外时间戳存储

按消息数截断(简单场景):

python 复制代码
class SimpleTrimHistory(BaseChatMessageHistory):
    """简单按消息数截断的历史管理"""

    def __init__(self, max_messages: int = 10):
        self.messages = []
        self.max_messages = max_messages

    def add_message(self, message):
        self.messages.append(message)
        # 保留最后 N 条(每轮对话 = 2 条消息:human + AI)
        if len(self.messages) > self.max_messages * 2:
            self.messages = self.messages[-self.max_messages * 2:]

    def clear(self):
        self.messages = []

按时间截断(客服场景):

python 复制代码
from datetime import datetime, timedelta

class TimeBasedHistory(BaseChatMessageHistory):
    """按时间截断:只保留最近 24 小时的消息"""

    def __init__(self, ttl_hours: int = 24):
        self.messages = []
        self.timestamps = []
        self.ttl = timedelta(hours=ttl_hours)

    def add_message(self, message):
        now = datetime.now()
        # 清理过期消息
        cutoff = now - self.ttl
        valid_indices = [i for i, t in enumerate(self.timestamps) if t > cutoff]

        self.messages = [self.messages[i] for i in valid_indices]
        self.timestamps = [self.timestamps[i] for i in valid_indices]

        # 添加新消息
        self.messages.append(message)
        self.timestamps.append(now)

2.4 在 Chain 内部截断 vs 在 History 层截断

方案 A:在 History 层截断(传统方式)

python 复制代码
class TrimmingHistory(BaseChatMessageHistory):
    def __init__(self):
        self._messages = []
        self.trimmer = trim_messages(max_tokens=2000, strategy="last")

    @property
    def messages(self):
        # 每次读取时都截断
        return self.trimmer.invoke(self._messages)

    def add_message(self, message):
        self._messages.append(message)

问题 :每次读取历史都要序列化所有消息计算 token截断,性能开销大。

方案 B:LCEL 风格的优雅方案(推荐)

python 复制代码
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter

# 在 Chain 内部截断,少一次序列化
chain = (
    {
        "history": itemgetter("history") | trimmer,  # 截断在这里发生
        "question": itemgetter("question"),
    }
    | prompt
    | model
)

chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="question",
    history_messages_key="history",
)

性能对比:

方案 消息序列化次数 Token 计算次数 适用场景
History 层截断 2 次(读 + 写) 2 次 简单应用
Chain 内部截断 1 次 1 次 高性能要求

三、自定义 ChatMessageHistory

3.1 什么时候需要自定义?

标准的历史管理(ChatMessageHistoryRedisChatMessageHistory)在以下场景不够用:

  1. 按用户等级分配不同 token 上限

    • 免费用户:保留 2k tokens
    • 付费用户:保留 8k tokens
    • VIP 用户:保留 16k tokens
  2. 敏感消息审计日志

    • 记录谁、在什么时候、说了什么
    • 合规要求(金融、医疗行业)
  3. 多租户数据隔离

    • SaaS 产品:不同企业的数据物理隔离
    • 防止跨租户数据泄露

3.2 完整实现代码

python 复制代码
from typing import List, Optional
from datetime import datetime
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.messages import trim_messages
from langchain_openai import ChatOpenAI
import threading
import json
import redis

class EnterpriseChatMessageHistory(BaseChatMessageHistory):
    """
    企业级自定义历史管理器
    特性:
    - 按用户等级动态截断
    - 审计日志
    - 线程安全
    - 敏感词过滤
    """

    # 用户等级对应的 token 上限
    TIER_LIMITS = {
        "free": 2000,
        "pro": 8000,
        "enterprise": 16000,
    }

    def __init__(
        self,
        session_id: str,
        user_id: str,
        user_tier: str = "free",
        redis_client: Optional[redis.Redis] = None,
        audit_callback: Optional[callable] = None,
        sensitive_words: Optional[List[str]] = None,
    ):
        self.session_id = session_id
        self.user_id = user_id
        self.user_tier = user_tier
        self.redis = redis_client
        self.audit_callback = audit_callback or self._default_audit
        self.sensitive_words = sensitive_words or []
        self._lock = threading.RLock()

        # 动态 trimmer
        max_tokens = self.TIER_LIMITS.get(user_tier, 2000)
        self.trimmer = trim_messages(
            max_tokens=max_tokens,
            strategy="last",
            include_system=True,
            token_counter=ChatOpenAI(model="gpt-4").get_num_tokens_from_messages,
        )

        # 内存缓存(生产环境应使用 Redis)
        self._message_cache: List[BaseMessage] = []
        self._load_from_storage()

    def _load_from_storage(self):
        """从持久化存储加载"""
        if self.redis:
            key = f"chat_history:{self.session_id}"
            raw_data = self.redis.get(key)
            if raw_data:
                self._message_cache = self._deserialize(raw_data)

    def _save_to_storage(self):
        """保存到持久化存储"""
        if self.redis:
            key = f"chat_history:{self.session_id}"
            self.redis.set(key, self._serialize(self._message_cache))

    def _serialize(self, messages: List[BaseMessage]) -> str:
        """序列化消息"""
        data = []
        for msg in messages:
            data.append({
                "type": msg.type,
                "content": msg.content,
                "timestamp": datetime.now().isoformat(),
            })
        return json.dumps(data)

    def _deserialize(self, raw: str) -> List[BaseMessage]:
        """反序列化消息"""
        data = json.loads(raw)
        messages = []
        for item in data:
            if item["type"] == "human":
                messages.append(HumanMessage(content=item["content"]))
            elif item["type"] == "ai":
                messages.append(AIMessage(content=item["content"]))
            elif item["type"] == "system":
                messages.append(SystemMessage(content=item["content"]))
        return messages

    def _default_audit(self, action: str, message: BaseMessage):
        """默认审计日志"""
        audit_entry = {
            "timestamp": datetime.now().isoformat(),
            "session_id": self.session_id,
            "user_id": self.user_id,
            "action": action,
            "message_type": message.type,
            "content_preview": message.content[:100] + "..." if len(message.content) > 100 else message.content,
        }
        print(f"[AUDIT] {json.dumps(audit_entry)}")

    def _check_sensitive(self, content: str) -> str:
        """敏感词过滤"""
        filtered = content
        for word in self.sensitive_words:
            filtered = filtered.replace(word, "*" * len(word))
        return filtered

    @property
    def messages(self) -> List[BaseMessage]:
        """读取时自动截断"""
        with self._lock:
            # 动态截断:根据用户等级
            return self.trimmer.invoke(self._message_cache)

    def add_message(self, message: BaseMessage) -> None:
        """添加消息(带审计和过滤)"""
        with self._lock:
            # 敏感词过滤
            if isinstance(message, (HumanMessage, AIMessage)):
                message.content = self._check_sensitive(message.content)

            # 写入审计日志
            self.audit_callback("add", message)

            # 添加到缓存
            self._message_cache.append(message)

            # 持久化
            self._save_to_storage()

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """批量添加"""
        for msg in messages:
            self.add_message(msg)

    def clear(self) -> None:
        """清空历史"""
        with self._lock:
            self.audit_callback("clear", SystemMessage(content="HISTORY_CLEARED"))
            self._message_cache = []
            if self.redis:
                self.redis.delete(f"chat_history:{self.session_id}")

    def get_stats(self) -> dict:
        """获取统计信息"""
        return {
            "session_id": self.session_id,
            "user_tier": self.user_tier,
            "token_limit": self.TIER_LIMITS.get(self.user_tier),
            "total_messages": len(self._message_cache),
            "storage_key": f"chat_history:{self.session_id}",
        }

使用示例:

python 复制代码
# 工厂函数
def get_enterprise_history(session_id: str, user_id: str) -> EnterpriseChatMessageHistory:
    # 实际应从数据库查询用户等级
    user_tier = get_user_tier_from_db(user_id)

    return EnterpriseChatMessageHistory(
        session_id=session_id,
        user_id=user_id,
        user_tier=user_tier,
        redis_client=redis.Redis(host="localhost", port=6379),
        sensitive_words=["密码", "身份证号", "银行卡"],
    )

# 使用
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_enterprise_history,
    input_messages_key="question",
    history_messages_key="history",
)

3.3 与外部系统集成

接入公司内部权限系统:

python 复制代码
class RBACChatMessageHistory(EnterpriseChatMessageHistory):
    """基于角色的访问控制历史管理"""

    def __init__(self, *args, **kwargs):
        self.permission_service = kwargs.pop("permission_service")
        super().__init__(*args, **kwargs)

    @property
    def messages(self) -> List[BaseMessage]:
        # 检查用户是否有权查看历史
        if not self.permission_service.can_read_history(self.user_id, self.session_id):
            raise PermissionError("无权访问该会话历史")

        return super().messages

    def add_message(self, message: BaseMessage) -> None:
        # 检查写入权限
        if not self.permission_service.can_write_history(self.user_id):
            raise PermissionError("无权写入历史")

        super().add_message(message)

消息加密存储:

python 复制代码
from cryptography.fernet import Fernet

class EncryptedChatMessageHistory(EnterpriseChatMessageHistory):
    """加密存储的历史管理"""

    def __init__(self, *args, **kwargs):
        self.cipher = Fernet(kwargs.pop("encryption_key"))
        super().__init__(*args, **kwargs)

    def _serialize(self, messages: List[BaseMessage]) -> str:
        """加密序列化"""
        plain = super()._serialize(messages)
        return self.cipher.encrypt(plain.encode()).decode()

    def _deserialize(self, raw: str) -> List[BaseMessage]:
        """解密反序列化"""
        plain = self.cipher.decrypt(raw.encode()).decode()
        return super()._deserialize(plain)

四、生产环境 Checklist

在将 RunnableWithMessageHistory 部署到生产环境前,逐项检查:

Session ID 生成策略

python 复制代码
import uuid
from hashlib import sha256

# ❌ 不推荐:纯随机 UUID,无法追溯
session_id = str(uuid.uuid4())

# ✅ 推荐:业务相关 + 随机盐,可追踪且防枚举
def generate_session_id(user_id: str, conversation_type: str = "chat") -> str:
    salt = uuid.uuid4().hex[:8]
    return sha256(f"{user_id}:{conversation_type}:{salt}".encode()).hexdigest()[:16]

# 结果:"a3f7b2c9d1e4f5a8" ------ 无法反推用户 ID,但相同用户 + 类型可关联分析
策略 示例 适用场景
纯 UUID 550e8400-e29b-41d4-a716-446655440000 匿名场景
业务 ID + 随机 user_1234:a3f7b2c9 需要关联分析
哈希值 a3f7b2c9d1e4f5a8 隐私保护 + 可追踪

存储连接池配置

python 复制代码
# Redis 连接池(生产必配)
redis_pool = redis.ConnectionPool(
    host="redis.internal",
    port=6379,
    db=0,
    max_connections=100,      # 最大连接数
    socket_connect_timeout=5,  # 连接超时
    socket_timeout=5,          # 读写超时
    health_check_interval=30,  # 健康检查
)

redis_client = redis.Redis(connection_pool=redis_pool)

历史消息 GDPR 合规删除

python 复制代码
class GDPRCompliantHistory(EnterpriseChatMessageHistory):
    """GDPR 合规的历史管理"""

    def delete_user_data(self, user_id: str) -> int:
        """响应用户"删除我的数据"请求(Right to be forgotten)"""
        # 1. 删除该用户的所有会话
        pattern = f"chat_history:*"
        deleted = 0
        for key in self.redis.scan_iter(match=pattern):
            session_data = self.redis.get(key)
            if session_data and user_id in session_data.decode():
                self.redis.delete(key)
                deleted += 1

        # 2. 记录删除审计日志
        self.audit_callback("gdpr_delete", SystemMessage(
            content=f"Deleted {deleted} sessions for user {user_id}"
        ))

        return deleted

    def export_user_data(self, user_id: str) -> dict:
        """响应用户"导出我的数据"请求(Right to data portability)"""
        # 收集该用户的所有数据
        data = {"sessions": [], "export_date": datetime.now().isoformat()}
        # ... 实现导出逻辑
        return data

监控指标

python 复制代码
from prometheus_client import Histogram, Counter, Gauge

# 定义监控指标
history_load_latency = Histogram(
    "chat_history_load_seconds",
    "Time spent loading chat history",
    buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
)

history_size = Gauge(
    "chat_history_messages_total",
    "Number of messages in history",
    ["session_id"],
)

history_operations = Counter(
    "chat_history_operations_total",
    "Total history operations",
    ["operation_type"],  # add, read, clear, trim
)

# 在代码中埋点
class MonitoredHistory(EnterpriseChatMessageHistory):
    @property
    def messages(self):
        with history_load_latency.time():
            msgs = super().messages
        history_size.labels(session_id=self.session_id).set(len(msgs))
        history_operations.labels(operation_type="read").inc()
        return msgs

    def add_message(self, message):
        super().add_message(message)
        history_operations.labels(operation_type="add").inc()

五、完整生产级代码模板

整合上下篇所有最佳实践,一键可用的 starter template:

python 复制代码
"""
RunnableWithMessageHistory 生产级模板
整合:流式输出、动态截断、自定义历史、监控、GDPR 合规
"""

import os
import json
import uuid
import asyncio
import redis.asyncio as aioredis
from datetime import datetime, timedelta
from typing import List, Optional, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import (
    BaseMessage, HumanMessage, AIMessage, SystemMessage, trim_messages
)
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter

# ==================== 配置 ====================
@dataclass
class AppConfig:
    """应用配置"""
    redis_url: str = "redis://localhost:6379"
    openai_api_key: str = os.getenv("OPENAI_API_KEY")
    model_name: str = "gpt-4"
    default_max_tokens: int = 4000
    session_ttl_hours: int = 24 * 7  # 7 天过期

    # 用户等级配置
    tier_limits = {
        "free": 2000,
        "pro": 8000,
        "enterprise": 16000,
    }

CONFIG = AppConfig()

# ==================== 自定义历史管理器 ====================
class ProductionChatMessageHistory(BaseChatMessageHistory):
    """生产级历史管理器"""

    def __init__(
        self,
        session_id: str,
        user_id: str,
        user_tier: str = "free",
        redis: Optional[aioredis.Redis] = None,
    ):
        self.session_id = session_id
        self.user_id = user_id
        self.user_tier = user_tier
        self.redis = redis
        self._cache: List[BaseMessage] = []

        # 动态截断器
        max_tokens = CONFIG.tier_limits.get(user_tier, CONFIG.default_max_tokens)
        self.trimmer = trim_messages(
            max_tokens=max_tokens,
            strategy="last",
            include_system=True,
            token_counter=ChatOpenAI(model=CONFIG.model_name).get_num_tokens_from_messages,
        )

    async def _load(self):
        """异步加载"""
        if self.redis:
            key = f"chat:{self.session_id}"
            data = await self.redis.get(key)
            if data:
                self._cache = self._deserialize(data)

    async def _save(self):
        """异步保存"""
        if self.redis:
            key = f"chat:{self.session_id}"
            await self.redis.setex(
                key,
                timedelta(hours=CONFIG.session_ttl_hours),
                self._serialize(self._cache),
            )

    def _serialize(self, messages: List[BaseMessage]) -> str:
        return json.dumps([
            {"type": m.type, "content": m.content}
            for m in messages
        ])

    def _deserialize(self, raw: str) -> List[BaseMessage]:
        data = json.loads(raw)
        type_map = {
            "human": HumanMessage,
            "ai": AIMessage,
            "system": SystemMessage,
        }
        return [type_map[d["type"]](content=d["content"]) for d in data]

    @property
    def messages(self) -> List[BaseMessage]:
        return self.trimmer.invoke(self._cache)

    def add_message(self, message: BaseMessage) -> None:
        self._cache.append(message)
        # 异步保存由调用方处理

    def add_messages(self, messages: List[BaseMessage]) -> None:
        for m in messages:
            self.add_message(m)

    def clear(self) -> None:
        self._cache = []

# ==================== 异步工厂函数 ====================
async def get_session_history(
    session_id: str,
    user_id: str = "anonymous",
    user_tier: str = "free",
) -> ProductionChatMessageHistory:
    """异步工厂"""
    redis = aioredis.from_url(CONFIG.redis_url)
    history = ProductionChatMessageHistory(
        session_id=session_id,
        user_id=user_id,
        user_tier=user_tier,
        redis=redis,
    )
    await history._load()
    return history

# ==================== 构建 Chain ====================
def create_chain():
    """创建带历史管理的对话链"""

    prompt = ChatPromptTemplate.from_messages([
        ("system", "你是一个专业、友好的 AI 助手。请用中文回答。"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])

    model = ChatOpenAI(
        model=CONFIG.model_name,
        api_key=CONFIG.openai_api_key,
        streaming=True,
    )

    # LCEL 风格:在 Chain 内部截断
    chain = (
        {
            "history": itemgetter("history"),
            "question": itemgetter("question"),
        }
        | prompt
        | model
        | StrOutputParser()
    )

    return RunnableWithMessageHistory(
        chain,
        lambda session_id: get_session_history(session_id, user_id="user_123", user_tier="pro"),
        input_messages_key="question",
        history_messages_key="history",
    )

# ==================== FastAPI 应用 ====================
app = FastAPI(title="Chat API with History")

@app.websocket("/chat/{session_id}")
async def websocket_chat(websocket: WebSocket, session_id: str):
    """WebSocket 流式对话"""
    await websocket.accept()

    chain = create_chain()
    history = await get_session_history(session_id)

    try:
        while True:
            question = await websocket.receive_text()

            # 流式输出
            full_response = []
            async for chunk in chain.astream(
                {"question": question},
                config={"configurable": {"session_id": session_id}},
            ):
                await websocket.send_json({
                    "type": "token",
                    "content": chunk,
                })
                full_response.append(chunk)

            # 保存完整对话到历史
            history.add_messages([
                HumanMessage(content=question),
                AIMessage(content="".join(full_response)),
            ])
            await history._save()

            await websocket.send_json({"type": "done"})

    except WebSocketDisconnect:
        print(f"Client disconnected: {session_id}")
    except Exception as e:
        await websocket.send_json({"type": "error", "message": str(e)})
    finally:
        await history.redis.close()

@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """GDPR:删除会话"""
    redis = aioredis.from_url(CONFIG.redis_url)
    await redis.delete(f"chat:{session_id}")
    return {"status": "deleted", "session_id": session_id}

@app.get("/health")
async def health_check():
    """健康检查"""
    return {"status": "ok", "timestamp": datetime.now().isoformat()}

# ==================== 启动 ====================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

六、总结与延伸阅读

核心要点回顾

主题 关键决策
流式输出 默认流结束后写历史;中断不保存是正确行为;需要断点续传可自定义
历史截断 trim_messages + include_system=True;在 Chain 内部截断性能更好
自定义历史 继承 BaseChatMessageHistory;关注线程安全、审计、加密
生产部署 Session ID 策略、连接池、GDPR、监控缺一不可

延伸阅读

  1. LangChain 官方文档

  2. 性能优化

  3. 合规与安全


📌 下篇预告 :我们将深入 LCEL 的高级组合模式 ------如何用 RunnableParallelRunnableBranch 构建复杂的多步骤 AI 工作流。

相关推荐
小碗羊肉3 小时前
【Agent笔记 | 第五篇】LangChain&LangGraph
笔记·langchain
abigale033 小时前
LangChain 多轮对话记忆:基于 session_id 实现多会话隔离
typescript·langchain·uuid·session_id
草莓熊Lotso4 小时前
【LangChain】聊天模型实战:结构化输出完全指南(从原理到落地)
数据库·python·langchain·软件工程
遇见火星4 小时前
LangChain 系列(二):LangChain vs DeepAgent
langchain
swipe14 小时前
Neo4j + Graph RAG 医疗知识图谱工程实践:患者教育问答真正需要的是“关系可追溯”
后端·langchain·llm
CC大煊15 小时前
一个Javaer的AI转型笔记(1):入坑LangChain,我的第一个hello world
笔记·langchain
Mr.Daozhi17 小时前
RAG 进阶实战:跑通 Demo 后我连续翻了 6 次车,逐一修复才真正可用(含 Gradio Web 版)
前端·数据库·langchain·大模型·gradio·rag·科研工具
swipe20 小时前
混合检索 RAG 的工程化实践:不是多查几路,而是把召回、重排和上下文预算管好
后端·langchain·llm
啊哈哈哈哈哈啊哈哈21 小时前
LangChain 与 LlamaIndex 实现 RAG:代码知识点总结
langchain