自定义useChat管理AI会话

当前有许多ai对话hook 如@ai-sdk/react的useChat @ant-design/x-sdk的useXChat 但在我的使用场景中都有所不足 遂自定义hook useChat

功能如下:

  • 收集用户提供的的流式数据 不做任何更改
  • 维护用户消息/流式数据的状态 允许停止传输
  • 恢复整个session的会话 允许立刻进入流式传输状态
  • 更流畅的session创建 用户发送消息立刻切换到会话页面 而不需要在创建页面等待

useChat/useChatId.ts

useChat通过chatId判断是否重置会话 仅仅使用sessionId作为chatId会导致创建session时会话也重置

useChatId可以让用户在创建session时维持chatId不变

ts 复制代码
/* eslint-disable react-hooks/refs */
import { useRef } from 'react'
import { useMemoizedFn } from 'ahooks'
import { match, P } from 'ts-pattern'

export function useChatId(options: {
  sessionId?: string | null
  setSessionId?: (sid: string) => void
}) {
  const { sessionId, setSessionId } = options
  // 上次调用时传入的sessionId
  const prevSessionIdRef = useRef(sessionId)
  // 上次调用时使用的chatId
  const prevChatIdRef = useRef<string>(undefined)
  // sessionId由无到有 可能是创建session 也可能进入已有session
  // 保存新sessionId 控制chatId不变
  const newSessionIdRef = useRef<string>(undefined)
  // 用于引发会话刷新
  const chatId = match({ psid: prevSessionIdRef.current, sid: sessionId })
    .with({ psid: P.string, sid: P.string }, ({ psid, sid }) => {
      if (psid !== sid) {
        // session->另一session 切换id
        return sid
      } else {
        if (typeof prevChatIdRef.current !== 'string') {
          // 首次进入session prevChatIdRef尚未赋值
          return sid
        }
        // session一致 id不变
        return prevChatIdRef.current
      }
    })
    .with({ psid: P.string, sid: P.nullish }, () => {
      // session->会话创建 新id
      return crypto.randomUUID()
    })
    .with({ psid: P.nullish, sid: P.nullish }, () => {
      if (typeof prevChatIdRef.current !== 'string') {
        // 首次进入会话创建 prevChatIdRef尚未赋值
        return crypto.randomUUID()
      }
      // 始终处于新会话创建阶段 id不变
      return prevChatIdRef.current
    })
    .with({ psid: P.nullish, sid: P.string }, ({ sid }) => {
      const newSessionId = newSessionIdRef.current
      newSessionIdRef.current = undefined
      if (typeof newSessionId === 'string' && sid === newSessionId) {
        // 新sessionId与预设值相同
        // 新会话创建 id不变
        return prevChatIdRef.current
      }
      // 进入已有session
      return sid
    })
    .exhaustive()!
  prevSessionIdRef.current = sessionId
  prevChatIdRef.current = chatId
  /** 在新建session后 改为调用此函数设置sessionId 维持chatId不变 */
  const setNewSessionId = useMemoizedFn((sid: string) => {
    newSessionIdRef.current = sid
    setSessionId?.(sid)
  })
  return { chatId, setNewSessionId }
}

useChat/index.ts

要点如下

  • 使用useCallback配合闭包 确保不会让过期的setState执行
  • 允许用户同时resume已完成的历史消息和ai正在执行的消息
  • 为resuming/streaming/submitting单独设置state 避免单一status表意不清
  • 使用AbortController提供了停止能力 无论用户消息还是ai消息
  • 通过id唯一标识消息 让流式数据和常规消息正确合并
ts 复制代码
/* eslint-disable react-hooks/refs */
import type { Dispatch, SetStateAction } from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useLatest, useMemoizedFn } from 'ahooks'

export { useChatId } from './useChatId'

export type AIMessage<AIMessagePart> = {
  id: string
  role: 'ai'
  parts: AIMessagePart[]
  status: 'streaming' | 'done' | 'aborted' | 'error'
}

export type UserMessage<UserMessagePart> = {
  id: string
  role: 'user'
  parts: UserMessagePart[]
  status: 'submitting' | 'done' | 'aborted' | 'error'
}

export type ChatMessage<AIMessagePart, UserMessagePart> =
  | AIMessage<AIMessagePart>
  | UserMessage<UserMessagePart>

export type UseChatOptions<AIMessagePart, UserMessagePart> = {
  chatId: unknown
  resumeMessages: (signal: AbortSignal) =>
    | Promise<
        | {
            messages?: ChatMessage<AIMessagePart, UserMessagePart>[]
            stream?: AsyncIterable<AIMessagePart>
          }
        | void
        | undefined
        | null
      >
    | void
    | undefined
    | null
  sendUserMessage: (
    parts: UserMessagePart[],
    signal: AbortSignal,
  ) => Promise<AsyncIterable<AIMessagePart> | void | undefined | null> | void | undefined | null
  /**
   * 异常通知回调\
   * resumeMessages、sendUserMessage、AI 流消费中任何非 abort 异常都会触发。\
   * 仅作为通知,sendMessage 仍会向调用方 throw,调用方可自行决定是否再做处理。
   */
  onError?: (error: unknown) => void
}

/** chatId的占位符 */
const DefaultChatId = Symbol('DefaultChatId')

/**
 * 聊天会话 Hook 管理消息列表、发送消息、流式消费与中止。
 *
 * `chatId` 作为会话作用域 key 变更时内部状态(messages、submitting、streaming)会重置
 * 并重新触发 `resumeMessages`,可以同时恢复历史消息和流式传输。
 * *流式消息会自动续接在最后一条AI消息上。*
 *
 * ### 新建会话时保持 chatId 稳定
 *
 * 如果用sessionId作为chatId,那么 sessionId 从无到有(新会话创建 -> 服务端返回真实 sid)会导致丢失当前正在流式中的消息与状态。
 * 使用 {@link useChatId} 可在这一过渡期保持 chatId 不变
 * 仅在真正切换到其他已有会话时才更新。
 *
 * @example
 * ```tsx
 * const [sessionId, setSessionId] = useState<string | null>(null)
 * const { chatId, setNewSessionId } = useChatId({ sessionId, setSessionId })
 *
 * const { messages, sendMessage } = useChat({
 *   chatId,
 *   resumeMessages: async (signal) => {
 *     if (!sessionId) return { messages: [] }
 *     // 允许返回stream继续流式传输
 *     const { messages, stream } = await fetchHistory(sessionId, signal)
 *     return { messages, stream }
 *   },
 *   sendUserMessage: async (parts, signal) => {
 *     // 没有 sessionId 时先自行新建 session 用 setNewSessionId 写回 chatId 保持不变
 *     let sid = sessionId
 *     if (!sid) {
 *       sid = await createSession(signal)
 *       // 调用setNewSessionId确保chatId不变
 *       setNewSessionId(sid)
 *     }
 *     const { stream } = await postMessage({ sessionId: sid, parts }, signal)
 *     return stream
 *   },
 * })
 * ```
 */
export function useChat<AIMessagePart, UserMessagePart>(
  options: UseChatOptions<AIMessagePart, UserMessagePart>,
) {
  const { chatId } = options
  const optionsRef = useLatest(options)
  const abortCtrlRef = useRef(new AbortController())

  const stop = useMemoizedFn(() => {
    abortCtrlRef.current.abort()
    abortCtrlRef.current = new AbortController()
  })

  const [messages, setMessages, messageActions] = useMessages<AIMessagePart, UserMessagePart>(
    chatId,
  )

  /** 是否正在恢复会话(不包括读流的过程) */
  const [resuming, setResuming] = useSafeState(true, chatId)

  /** 用户已提交但尚未收到响应的消息数量 */
  const [submittingCount, setSubmittingCount] = useSafeState(0, chatId)
  const invokeSendUserMessage = useCallback(
    async (id: string, parts: UserMessagePart[], signal: AbortSignal) => {
      setSubmittingCount((prev) => prev + 1)
      try {
        const stream = await optionsRef.current.sendUserMessage(parts, signal)
        // 调用方可能未监听 signal
        if (signal.aborted) throw new DOMException('Aborted', 'AbortError')
        messageActions.updateMessage({ id, role: 'user', status: 'done' })
        return stream
      } catch (err) {
        const aborted = signal.aborted
        messageActions.updateMessage({
          id,
          role: 'user',
          status: aborted ? 'aborted' : 'error',
        })
        throw err
      } finally {
        setSubmittingCount((prev) => prev - 1)
      }
    },
    [messageActions, optionsRef, setSubmittingCount],
  )

  /** 正在被读取的 AI 流数量 */
  const [streamingCount, setStreamingCount] = useSafeState(0, chatId)
  const consumeAIStream = useCallback(
    async (id: string, stream: AsyncIterable<AIMessagePart>, signal: AbortSignal) => {
      setStreamingCount((prev) => prev + 1)
      try {
        for await (const part of stream) {
          // 调用方可能未监听 signal
          if (signal.aborted) throw new DOMException('Aborted', 'AbortError')
          messageActions.updateMessage({ id, role: 'ai', parts: [part] })
        }
        messageActions.updateMessage({ id, role: 'ai', status: 'done' })
      } catch (err) {
        // 主动 stop 触发的中止标记为 aborted 真实异常才标记 error
        const aborted = signal.aborted
        messageActions.updateMessage({
          id,
          role: 'ai',
          status: aborted ? 'aborted' : 'error',
        })
        throw err
      } finally {
        setStreamingCount((prev) => prev - 1)
      }
    },
    [messageActions, setStreamingCount],
  )

  // 防止effect意外调用 例如hmr
  const latestChatId = useRef<unknown>(DefaultChatId)
  const isChatIdChange = latestChatId.current !== chatId
  latestChatId.current = chatId
  useEffect(() => {
    // 组件卸载时无效
    const cleanup = () => {
      if (latestChatId.current !== chatId) stop()
    }
    if (!isChatIdChange) return cleanup
    const signal = abortCtrlRef.current.signal
    ;(async () => {
      try {
        const res = await optionsRef.current.resumeMessages(signal)
        // 调用方可能未监听 signal
        if (signal.aborted) throw new DOMException('Aborted', 'AbortError')
        if (!res) return
        const { stream } = res
        let { messages: resumed = [] } = res
        let lastAiMessageId: string | undefined
        const lastMessage = resumed.at(-1)
        if (stream) {
          if (lastMessage?.role === 'ai') {
            lastAiMessageId = lastMessage.id
          } else {
            lastAiMessageId = crypto.randomUUID()
            resumed = [
              ...resumed,
              { id: lastAiMessageId, role: 'ai', parts: [], status: 'streaming' },
            ]
          }
        }
        if (resumed.length) {
          messageActions.addMessages(resumed, 'pre')
        }
        if (lastAiMessageId && stream) {
          setResuming(false)
          await consumeAIStream(lastAiMessageId, stream, signal)
        }
      } catch (err) {
        if (!signal.aborted) optionsRef.current.onError?.(err)
      } finally {
        setResuming(false)
      }
    })()
    return cleanup
  }, [chatId, consumeAIStream, isChatIdChange, messageActions, optionsRef, setResuming, stop])
  // 组件卸载时清理
  useEffect(() => {
    return stop
  }, [stop])

  const sendMessage = useCallback(
    async (parts: UserMessagePart[]) => {
      const signal = abortCtrlRef.current.signal
      const userMessageId = crypto.randomUUID()
      messageActions.addMessages({
        id: userMessageId,
        role: 'user',
        parts,
        status: 'submitting',
      })
      try {
        const stream = await invokeSendUserMessage(userMessageId, parts, signal)
        if (!stream) return
        const aiMessageId = crypto.randomUUID()
        messageActions.addMessages({
          id: aiMessageId,
          role: 'ai',
          parts: [],
          status: 'streaming',
        })
        await consumeAIStream(aiMessageId, stream, signal)
      } catch (err) {
        if (!signal.aborted) optionsRef.current.onError?.(err)
        throw err
      }
    },
    [consumeAIStream, invokeSendUserMessage, messageActions, optionsRef],
  )

  return {
    messages,
    sendMessage,
    /** 是否正在恢复会话(不含读流阶段) */
    resuming,
    /** 是否有 AI 流正在读取中 */
    streaming: streamingCount > 0,
    /** 是否有用户消息已发出但尚未收到响应 */
    submitting: submittingCount > 0,
    stop,
    setMessages,
    messageActions,
    consumeAIStream,
  }
}

/**
 * 与作用域 key 绑定的 state\
 * - value 用 ref 维护 key 变更当帧同步重置 避免 render 中 setValue 触发二次渲染\
 * - setter 闭包在 useCallback 中捕获 render 当时的 key 异步回来时若 key 已变更 则跳过更新
 */
function useSafeState<T>(initial: T, key: unknown) {
  const [, forceRender] = useState({})
  const valueRef = useRef(initial)
  const prevKeyRef = useRef(key)
  const latestKeyRef = useRef(key)
  if (!Object.is(prevKeyRef.current, key)) {
    prevKeyRef.current = key
    valueRef.current = initial
  }
  latestKeyRef.current = key
  // 不能使用useMemorizedFn 需要闭包里的旧key与最新的key对比
  const safeSetValue: Dispatch<SetStateAction<T>> = useCallback(
    (next) => {
      if (!Object.is(latestKeyRef.current, key)) return
      valueRef.current =
        typeof next === 'function' ? (next as (prev: T) => T)(valueRef.current) : next
      forceRender({})
    },
    [key],
  )
  return [valueRef.current, safeSetValue] as const
}

/** 分配式的 PartialExcept 保留联合类型各分支 使 role 能作为 discriminator narrow 到对应分支 */
type PartialExcept<T, K extends PropertyKey> = T extends unknown
  ? Pick<T, Extract<keyof T, K>> & Partial<Omit<T, K>>
  : never

function useMessages<AIMessagePart, UserMessagePart>(chatId: unknown) {
  type Message = ChatMessage<AIMessagePart, UserMessagePart>
  const [messages, setMessages] = useSafeState<Message[]>([], chatId)
  const actions = useMemo(() => {
    const addMessages = (messageOrList: Message | Message[], position: 'pre' | 'end' = 'end') => {
      const list = Array.isArray(messageOrList) ? messageOrList : [messageOrList]
      if (!list.length) return
      setMessages((prev) => (position === 'end' ? [...prev, ...list] : [...list, ...prev]))
    }

    const updateMessage = (patch: PartialExcept<Message, 'id' | 'role'>) => {
      if (!patch.parts?.length && patch.status === undefined) return
      setMessages((prev) => {
        const index = prev.findIndex((item) => item.id === patch.id && item.role === patch.role)
        if (index === -1) return prev
        const target = prev[index]
        const merged = mergeMessage(target, patch)
        if (merged === target) return prev
        return prev.map((item, i) => (i === index ? merged : item))
      })
    }
    return {
      addMessages,
      updateMessage,
    }
  }, [setMessages])

  return [messages, setMessages, actions] as const
}

function mergeMessage<AIMessagePart, UserMessagePart>(
  target: ChatMessage<AIMessagePart, UserMessagePart>,
  patch: PartialExcept<ChatMessage<AIMessagePart, UserMessagePart>, 'id' | 'role'>,
): ChatMessage<AIMessagePart, UserMessagePart> {
  // role 已在调用方通过 findIndex 匹配 下面仅通过窄化恢复 TS 联合类型收敛
  if (target.role === 'ai' && patch.role === 'ai') {
    return {
      ...target,
      ...(patch.status !== undefined && { status: patch.status }),
      ...(patch.parts?.length && { parts: [...target.parts, ...patch.parts] }),
    }
  }
  if (target.role === 'user' && patch.role === 'user') {
    return {
      ...target,
      ...(patch.status !== undefined && { status: patch.status }),
      ...(patch.parts?.length && { parts: [...target.parts, ...patch.parts] }),
    }
  }
  return target
}
相关推荐
阿钱真强道1 小时前
21 ComfyUI 实战:IP-Adapter + ControlNet 实现人物表情编辑,为什么降权重后更容易“笑出来”
aigc·stable-diffusion·controlnet·comfyui·softedge·ip-adapter·人物表情编辑
悟空和大王1 小时前
内网环境: vue3中使用 iconify 的在线图标
前端
福大大架构师每日一题1 小时前
openclaw v2026.4.21 更新:图像生成、权限安全、插件修复、Slack 线程、浏览器与 npm 安装全面优化
前端·安全·npm
小赵同学WoW1 小时前
call(), appy(),bind() 之间的区别与使用方法,自己实现这三个函数
前端
t***5442 小时前
如何在 Dev-C++ 中设置 MinGW 和 Clang 的路径
java·前端·c++
拜托啦!狮子2 小时前
安装EnsDb.Hsapiens.v86
java·服务器·前端
金玉满堂@bj2 小时前
playwright使用教程总结
前端
scheduleTTe2 小时前
Nginx
服务器·前端·nginx
techdashen2 小时前
不开端口,不配 DNS,用树莓派在家搭一个公网可访问的 Web 服务
前端·网络·智能路由器