Quickly Building an AI Chat App with Chainlit and Persisting Chat Data to a PostgreSQL Relational Database

Overview

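By default, a Chainlit app does not keep the chats and elements it generates: as soon as the page is refreshed, the whole chat history on the page disappears. Yet the ability to store and reuse this data can be an important part of your project or organization.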

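I previously wrote the article "Quickly Building an AI Chat App with Chainlit and Persisting Chat Data to a Local SQLite Database". That approach has the advantage that you do not need to install a database or create tables yourself; its drawback is that it only suits a small number of users. A PostgreSQL database can serve chat history for a medium-sized user base.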

Tutorial

1. Install the Chainlit dependencies

pip install chainlit openai asyncpg psycopg2-binary aiohttp aiofiles sqlalchemy
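Before moving on, you can check that everything the later code imports is actually installed. The sketch below is a minimal helper of my own (not part of Chainlit), assuming the module list matches the imports used in this article. Note that import names differ from pip names (psycopg2-binary installs the module psycopg2), and that the postgresql+asyncpg:// connection string used in app.py needs the asyncpg driver.

```python
import importlib.util
from typing import Iterable, List

# Import names (not pip package names) used by the code in this article.
REQUIRED_MODULES = [
    "chainlit", "openai", "psycopg2", "asyncpg",
    "aiohttp", "aiofiles", "sqlalchemy",
]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All dependencies are installed.")
```

Run it with the same Python interpreter that will run `chainlit`, otherwise it inspects the wrong environment.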

2. Configure environment variables

In the project root directory, create a .env file with the following content:

OPENAI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
OPENAI_API_KEY="your api_key"
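  • Because OpenAI's ChatGPT API cannot be reached directly from mainland China, OPENAI_BASE_URL must point to a proxy address. If you use a domestic LLM instead, set it to that provider's OpenAI-compatible endpoint, as in the DashScope example above.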
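"OpenAI-compatible" means the provider accepts the same HTTP request the openai Python client sends: a POST to `{base_url}/chat/completions` with a bearer token and a JSON body. The client in app.py reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment. The sketch below builds (but does not send) such a request with the standard library; `build_chat_request` is a hypothetical helper for illustration, not part of the openai package.

```python
import json
import os
import urllib.request


def build_chat_request(base_url: str, api_key: str, model: str, messages: list,
                       stream: bool = False) -> urllib.request.Request:
    """Build (without sending) an OpenAI-compatible chat completions request."""
    url = base_url.rstrip("/") + "/chat/completions"
    body = json.dumps({"model": model, "messages": messages, "stream": stream}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        method="POST",
        headers={
            "Authorization": f"Bearer {api_key}",  # same auth scheme as api.openai.com
            "Content-Type": "application/json",
        },
    )


if __name__ == "__main__":
    req = build_chat_request(
        os.environ.get("OPENAI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
        os.environ.get("OPENAI_API_KEY", "your api_key"),
        model="qwen-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(req.get_method(), req.full_url)
    # Sending it would be: urllib.request.urlopen(req)
```

Only the base URL and key change between providers, which is why switching to a domestic model needs no code change in app.py beyond the model name.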

Install the PostgreSQL database

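You can follow the article "A Simple Tutorial on Installing PostgreSQL on Windows". After installing PostgreSQL, use a database management tool such as Navicat to create a database, for example one named chain_lit, and then run the following SQL statements to create the tables: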

CREATE TABLE IF NOT EXISTS users (
    "id" UUID PRIMARY KEY,
    "identifier" TEXT NOT NULL UNIQUE,
    "metadata" JSONB NOT NULL,
    "createdAt" TEXT
);

CREATE TABLE IF NOT EXISTS threads (
    "id" UUID PRIMARY KEY,
    "createdAt" TEXT,
    "name" TEXT,
    "userId" UUID,
    "userIdentifier" TEXT,
    "tags" TEXT[],
    "metadata" JSONB,
    FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE
);

CREATE TABLE IF NOT EXISTS steps (
    "id" UUID PRIMARY KEY,
    "name" TEXT NOT NULL,
    "type" TEXT NOT NULL,
    "threadId" UUID NOT NULL,
    "parentId" UUID,
    "disableFeedback" BOOLEAN NOT NULL,
    "streaming" BOOLEAN NOT NULL,
    "waitForAnswer" BOOLEAN,
    "isError" BOOLEAN,
    "metadata" JSONB,
    "tags" TEXT[],
    "input" TEXT,
    "output" TEXT,
    "createdAt" TEXT,
    "start" TEXT,
    "end" TEXT,
    "generation" JSONB,
    "showInput" TEXT,
    "language" TEXT,
    "indent" INT
);

CREATE TABLE IF NOT EXISTS elements (
    "id" UUID PRIMARY KEY,
    "threadId" UUID,
    "type" TEXT,
    "url" TEXT,
    "chainlitKey" TEXT,
    "name" TEXT NOT NULL,
    "display" TEXT,
    "objectKey" TEXT,
    "size" TEXT,
    "page" INT,
    "language" TEXT,
    "forId" UUID,
    "mime" TEXT
);

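CREATE TABLE IF NOT EXISTS feedbacks (
    "id" UUID PRIMARY KEY,
    "forId" UUID NOT NULL,
    "threadId" UUID NOT NULL,
    "value" INT NOT NULL,
    "comment" TEXT
);

-- Used by PostgresStorageClient.upload_file() in postgres_client.py to store element files.
-- ON CONFLICT (object_key) in that code requires object_key to be UNIQUE.
CREATE TABLE IF NOT EXISTS files (
    id SERIAL PRIMARY KEY,
    object_key TEXT NOT NULL UNIQUE,
    data BYTEA,
    mime TEXT
);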
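In this schema only threads.userId carries a foreign key (to users.id, with ON DELETE CASCADE); steps, elements and feedbacks are linked to threads by plain UUID columns with no constraint. That is why delete_thread() in postgres_data.py deletes feedbacks, elements and steps explicitly. The sketch below demonstrates the behavior with Python's built-in sqlite3 as a stand-in for PostgreSQL (an assumption for the sake of a self-contained example; the cascade semantics are the same for this case), using trimmed versions of the tables.

```python
import sqlite3

# SQLite stand-in for the PostgreSQL schema above, trimmed to the relevant columns.
# SQLite only enforces foreign keys after PRAGMA foreign_keys = ON.
conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript(
    """
    CREATE TABLE users (id TEXT PRIMARY KEY, identifier TEXT NOT NULL UNIQUE);
    CREATE TABLE threads (
        id TEXT PRIMARY KEY,
        userId TEXT,
        FOREIGN KEY (userId) REFERENCES users(id) ON DELETE CASCADE
    );
    -- No foreign key here, mirroring the steps table above.
    CREATE TABLE steps (id TEXT PRIMARY KEY, threadId TEXT NOT NULL);
    """
)
conn.execute("INSERT INTO users VALUES ('u1', 'admin')")
conn.execute("INSERT INTO threads VALUES ('t1', 'u1')")
conn.execute("INSERT INTO steps VALUES ('s1', 't1')")

# Deleting the user cascades to its threads...
conn.execute("DELETE FROM users WHERE id = 'u1'")
thread_count = conn.execute("SELECT COUNT(*) FROM threads").fetchone()[0]
# ...but the step survives as an orphan, because no constraint ties it to threads.
step_count = conn.execute("SELECT COUNT(*) FROM steps").fetchone()[0]
print(thread_count, step_count)  # prints: 0 1
```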

3. Write the code

In the project root directory, create a postgres_client.py file with the following code:

from typing import TYPE_CHECKING, Dict, Union, Any

import psycopg2  # type: ignore
from chainlit.data import BaseStorageClient
from chainlit.logger import logger
from psycopg2.extras import RealDictCursor

if TYPE_CHECKING:
    from psycopg2.extensions import connection, cursor


class PostgresStorageClient(BaseStorageClient):
    """
    Class to enable storage in a PostgreSQL database.

    Args:
        host: Hostname or IP address of the PostgreSQL server.
        dbname: Name of the database to connect to.
        user: User name used to authenticate.
        password: Password used to authenticate.
        port: Port number to connect to (default: 5432).
    """

    def __init__(self, host: str, dbname: str, user: str, password: str, port: int = 5432):
        try:
            self.conn: connection = psycopg2.connect(
                host=host,
                dbname=dbname,
                user=user,
                password=password,
                port=port
            )
            self.cursor: cursor = self.conn.cursor(cursor_factory=RealDictCursor)
            logger.info("PostgresStorageClient initialized")
        except Exception as e:
            logger.error(f"PostgresStorageClient initialization error: {e}")
            raise  # fail fast instead of leaving self.conn undefined

    async def upload_file(self, object_key: str, data: Union[bytes, str], mime: str = 'application/octet-stream',
                          overwrite: bool = True) -> Dict[str, Any]:
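        try:
            # Assumes a files table with columns id, object_key (UNIQUE, required by ON CONFLICT), data (BYTEA) and mime
            query = """
                INSERT INTO files (object_key, data, mime)
                VALUES (%s, %s, %s)
                ON CONFLICT (object_key)
                DO UPDATE SET data = EXCLUDED.data, mime = EXCLUDED.mime;
            """
            # Store text content as UTF-8 bytes so it fits the BYTEA column
            payload = data.encode("utf-8") if isinstance(data, str) else data
            self.cursor.execute(query, (object_key, psycopg2.Binary(payload), mime))
            self.conn.commit()
            # Placeholder URL: nothing serves files here; point it at a real download endpoint
            url = f"http://example.com/download/{object_key}"
            return {"object_key": object_key, "url": url}
        except Exception as e:
            # Roll back the failed transaction so the connection stays usable
            self.conn.rollback()
            logger.warn(f"PostgresStorageClient, upload_file error: {e}")
            return {}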
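The data layer only ever calls `upload_file()` on its storage provider and expects a dict with `object_key` and `url` back (an empty dict means failure). Assuming the Chainlit 1.x interface used in this article, the sketch below is a hypothetical in-memory test double with the same method signature, plus a helper that mirrors how `create_element()` in postgres_data.py builds the object key. It can be handy for testing element persistence before the files table exists; it is not a production storage backend, since its URLs are not served by anything.

```python
import asyncio
from typing import Any, Dict, Optional, Tuple, Union


def build_object_key(user_folder: str, element_id: str, name: Optional[str]) -> str:
    """Mirror create_element(): "<user id>/<element id>" plus "/<name>" when a name exists."""
    return f"{user_folder}/{element_id}" + (f"/{name}" if name else "")


class InMemoryStorageClient:
    """Hypothetical test double with the same upload_file() signature as PostgresStorageClient.
    Files live in a dict instead of the files table."""

    def __init__(self, base_url: str = "http://localhost:8000/files"):
        self.base_url = base_url  # hypothetical download prefix, not served by anything
        self.files: Dict[str, Tuple[bytes, str]] = {}

    async def upload_file(self, object_key: str, data: Union[bytes, str],
                          mime: str = "application/octet-stream",
                          overwrite: bool = True) -> Dict[str, Any]:
        if not overwrite and object_key in self.files:
            return {}  # empty dict signals failure, as in PostgresStorageClient
        payload = data.encode("utf-8") if isinstance(data, str) else data
        self.files[object_key] = (payload, mime)
        return {"object_key": object_key, "url": f"{self.base_url}/{object_key}"}


async def demo() -> Dict[str, Any]:
    client = InMemoryStorageClient()
    key = build_object_key("user-1", "elem-1", "notes.txt")
    return await client.upload_file(key, "hello", mime="text/plain")


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Unlike PostgresStorageClient, which always upserts, this double honors `overwrite=False`; that choice is only for illustration.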

In the project root directory, create a postgres_data.py file with the following code:

import json
import ssl
import uuid
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from literalai.helper import utc_now

import aiofiles
import aiohttp
from chainlit.context import context
from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
from chainlit.logger import logger
from chainlit.step import StepDict
from chainlit.types import (
    Feedback,
    FeedbackDict,
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

# ElementDict is instantiated at runtime in get_all_user_threads(), so it needs a real import
from chainlit.element import ElementDict

if TYPE_CHECKING:
    from chainlit.element import Element


class PostgresDataLayer(BaseDataLayer):
    def __init__(
        self,
        conninfo: str,
        ssl_require: bool = False,
        storage_provider: Optional[BaseStorageClient] = None,
        user_thread_limit: Optional[int] = 1000,
        show_logger: Optional[bool] = False,
    ):
        self._conninfo = conninfo
        self.user_thread_limit = user_thread_limit
        self.show_logger = show_logger
        ssl_args = {}
        if ssl_require:
            # Create an SSL context to require an SSL connection
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            ssl_args["ssl"] = ssl_context
        self.engine: AsyncEngine = create_async_engine(
            self._conninfo, connect_args=ssl_args
        )
        self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession)  # type: ignore
        if storage_provider:
            self.storage_provider: Optional[BaseStorageClient] = storage_provider
            if self.show_logger:
                logger.info("PostgresDataLayer storage client initialized")
        else:
            self.storage_provider = None
            logger.warn(
                "PostgresDataLayer storage client is not initialized and elements will not be persisted!"
            )

    async def build_debug_url(self) -> str:
        return ""

    ###### SQL Helpers ######
    async def execute_sql(
        self, query: str, parameters: dict
    ) -> Union[List[Dict[str, Any]], int, None]:
        parameterized_query = text(query)
        async with self.async_session() as session:
            try:
                await session.begin()
                result = await session.execute(parameterized_query, parameters)
                await session.commit()
                if result.returns_rows:
                    json_result = [dict(row._mapping) for row in result.fetchall()]
                    clean_json_result = self.clean_result(json_result)
                    return clean_json_result
                else:
                    return result.rowcount
            except SQLAlchemyError as e:
                await session.rollback()
                logger.warn(f"An error occurred: {e}")
                return None
            except Exception as e:
                await session.rollback()
                logger.warn(f"An unexpected error occurred: {e}")
                return None

    async def get_current_timestamp(self) -> str:
        return utc_now()

    def clean_result(self, obj):
        """Recursively change UUID -> str and serialize dictionaries"""
        if isinstance(obj, dict):
            return {k: self.clean_result(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self.clean_result(item) for item in obj]
        elif isinstance(obj, uuid.UUID):
            return str(obj)
        return obj

    ###### User ######
    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
        query = "SELECT * FROM users WHERE identifier = :identifier"
        parameters = {"identifier": identifier}
        result = await self.execute_sql(query=query, parameters=parameters)
        if result and isinstance(result, list):
            user_data = result[0]
            return PersistedUser(**user_data)
        return None

    async def create_user(self, user: User) -> Optional[PersistedUser]:
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
        existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
        user_dict: Dict[str, Any] = {
            "identifier": str(user.identifier),
            "metadata": json.dumps(user.metadata or {}),
        }
        if not existing_user:  # create the user
            if self.show_logger:
                logger.info("SQLAlchemy: create_user, creating the user")
            user_dict["id"] = str(uuid.uuid4())
            user_dict["createdAt"] = await self.get_current_timestamp()
            query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
            await self.execute_sql(query=query, parameters=user_dict)
        else:  # update the user
            if self.show_logger:
                logger.info("SQLAlchemy: update user metadata")
            query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
            await self.execute_sql(
                query=query, parameters=user_dict
            )  # We want to update the metadata
        return await self.get_user(user.identifier)

    ###### Threads ######
    async def get_thread_author(self, thread_id: str) -> str:
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
        query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
        parameters = {"id": thread_id}
        result = await self.execute_sql(query=query, parameters=parameters)
        if isinstance(result, list) and result:
            author_identifier = result[0].get("userIdentifier")
            if author_identifier is not None:
                return author_identifier
        raise ValueError(f"Author not found for thread_id {thread_id}")

    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
            thread_id=thread_id
        )
        if user_threads:
            return user_threads[0]
        else:
            return None

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        if self.show_logger:
            logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
        if context.session.user is not None:
            user_identifier = context.session.user.identifier
        else:
            raise ValueError("User not found in session context")
        data = {
            "id": thread_id,
            "createdAt": (
                await self.get_current_timestamp() if metadata is None else None
            ),
            "name": (
                name
                if name is not None
                else (metadata.get("name") if metadata and "name" in metadata else None)
            ),
            "userId": user_id,
            "userIdentifier": user_identifier,
            "tags": tags,
            "metadata": json.dumps(metadata) if metadata else None,
        }
        parameters = {
            key: value for key, value in data.items() if value is not None
        }  # Remove keys with None values
        columns = ", ".join(f'"{key}"' for key in parameters.keys())
        values = ", ".join(f":{key}" for key in parameters.keys())
        updates = ", ".join(
            f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != "id"
        )
        query = f"""
            INSERT INTO threads ({columns})
            VALUES ({values})
            ON CONFLICT ("id") DO UPDATE
            SET {updates};
        """
        await self.execute_sql(query=query, parameters=parameters)

    async def delete_thread(self, thread_id: str):
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
        # Delete feedbacks/elements/steps/thread
        feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
        elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
        steps_query = """DELETE FROM steps WHERE "threadId" = :id"""
        thread_query = """DELETE FROM threads WHERE "id" = :id"""
        parameters = {"id": thread_id}
        await self.execute_sql(query=feedbacks_query, parameters=parameters)
        await self.execute_sql(query=elements_query, parameters=parameters)
        await self.execute_sql(query=steps_query, parameters=parameters)
        await self.execute_sql(query=thread_query, parameters=parameters)

    async def list_threads(
        self, pagination: Pagination, filters: ThreadFilter
    ) -> PaginatedResponse:
        if self.show_logger:
            logger.info(
                f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
            )
        if not filters.userId:
            raise ValueError("userId is required")
        all_user_threads: List[ThreadDict] = (
            await self.get_all_user_threads(user_id=filters.userId) or []
        )

        search_keyword = filters.search.lower() if filters.search else None
        feedback_value = int(filters.feedback) if filters.feedback is not None else None

        filtered_threads = []
        for thread in all_user_threads:
            keyword_match = True
            feedback_match = True
            if search_keyword or feedback_value is not None:
                if search_keyword:
                    keyword_match = any(
                        # step output can be NULL in the database, so fall back to ""
                        search_keyword in (step.get("output") or "").lower()
                        for step in thread["steps"]
                    )
                if feedback_value is not None:
                    feedback_match = False  # Assume no match until found
                    for step in thread["steps"]:
                        feedback = step.get("feedback")
                        if feedback and feedback.get("value") == feedback_value:
                            feedback_match = True
                            break
            if keyword_match and feedback_match:
                filtered_threads.append(thread)

        start = 0
        if pagination.cursor:
            for i, thread in enumerate(filtered_threads):
                if (
                    thread["id"] == pagination.cursor
                ):  # Find the start index using pagination.cursor
                    start = i + 1
                    break
        end = start + pagination.first
        paginated_threads = filtered_threads[start:end] or []
        has_next_page = len(filtered_threads) > end
        start_cursor = paginated_threads[0]["id"] if paginated_threads else None
        end_cursor = paginated_threads[-1]["id"] if paginated_threads else None
        return PaginatedResponse(
            pageInfo=PageInfo(
                hasNextPage=has_next_page,
                startCursor=start_cursor,
                endCursor=end_cursor,
            ),
            data=paginated_threads,
        )

    ###### Steps ######
    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
        if not getattr(context.session.user, "id", None):
            raise ValueError("No authenticated user in context")
        step_dict["showInput"] = (
            str(step_dict.get("showInput", "")).lower()
            if "showInput" in step_dict
            else None
        )
        parameters = {
            key: value
            for key, value in step_dict.items()
            if value is not None and not (isinstance(value, dict) and not value)
        }
        parameters["metadata"] = json.dumps(step_dict.get("metadata", {}))
        parameters["generation"] = json.dumps(step_dict.get("generation", {}))
        columns = ", ".join(f'"{key}"' for key in parameters.keys())
        values = ", ".join(f":{key}" for key in parameters.keys())
        updates = ", ".join(
            f'"{key}" = :{key}' for key in parameters.keys() if key != "id"
        )
        query = f"""
            INSERT INTO steps ({columns})
            VALUES ({values})
            ON CONFLICT (id) DO UPDATE
            SET {updates};
        """
        await self.execute_sql(query=query, parameters=parameters)

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        if self.show_logger:
            logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
        await self.create_step(step_dict)

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
        # Delete feedbacks/elements/steps
        feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id"""
        elements_query = """DELETE FROM elements WHERE "forId" = :id"""
        steps_query = """DELETE FROM steps WHERE "id" = :id"""
        parameters = {"id": step_id}
        await self.execute_sql(query=feedbacks_query, parameters=parameters)
        await self.execute_sql(query=elements_query, parameters=parameters)
        await self.execute_sql(query=steps_query, parameters=parameters)

    ###### Feedback ######
    async def upsert_feedback(self, feedback: Feedback) -> str:
        if self.show_logger:
            logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
        feedback.id = feedback.id or str(uuid.uuid4())
        feedback_dict = asdict(feedback)
        parameters = {
            key: value for key, value in feedback_dict.items() if value is not None
        }

        columns = ", ".join(f'"{key}"' for key in parameters.keys())
        values = ", ".join(f":{key}" for key in parameters.keys())
        updates = ", ".join(
            f'"{key}" = :{key}' for key in parameters.keys() if key != "id"
        )
        query = f"""
            INSERT INTO feedbacks ({columns})
            VALUES ({values})
            ON CONFLICT (id) DO UPDATE
            SET {updates};
        """
        await self.execute_sql(query=query, parameters=parameters)
        return feedback.id

    async def delete_feedback(self, feedback_id: str) -> bool:
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
        query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
        parameters = {"feedback_id": feedback_id}
        await self.execute_sql(query=query, parameters=parameters)
        return True

    ###### Elements ######
    @queue_until_user_message()
    async def create_element(self, element: "Element"):
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
        if not getattr(context.session.user, "id", None):
            raise ValueError("No authenticated user in context")
        if not self.storage_provider:
            logger.warn(
                f"SQLAlchemy: create_element error. No blob_storage_client is configured!"
            )
            return
        if not element.for_id:
            return

        content: Optional[Union[bytes, str]] = None

        if element.path:
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()
        elif element.url:
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
                    else:
                        content = None
        elif element.content:
            content = element.content
        else:
            raise ValueError("Element url, path or content must be provided")
        if content is None:
            raise ValueError("Content is None, cannot upload file")

        context_user = context.session.user

        user_folder = getattr(context_user, "id", "unknown")
        file_object_key = f"{user_folder}/{element.id}" + (
            f"/{element.name}" if element.name else ""
        )

        if not element.mime:
            element.mime = "application/octet-stream"

        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key, data=content, mime=element.mime, overwrite=True
        )
        if not uploaded_file:
            raise ValueError(
                "SQLAlchemy Error: create_element, Failed to persist data in storage_provider"
            )

        element_dict: ElementDict = element.to_dict()

        element_dict["url"] = uploaded_file.get("url")
        element_dict["objectKey"] = uploaded_file.get("object_key")
        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}

        columns = ", ".join(f'"{column}"' for column in element_dict_cleaned.keys())
        placeholders = ", ".join(f":{column}" for column in element_dict_cleaned.keys())
        query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
        await self.execute_sql(query=query, parameters=element_dict_cleaned)

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
        query = """DELETE FROM elements WHERE "id" = :id"""
        parameters = {"id": element_id}
        await self.execute_sql(query=query, parameters=parameters)

    async def delete_user_session(self, id: str) -> bool:
        return False  # Not sure why documentation wants this

    async def get_all_user_threads(
        self, user_id: Optional[str] = None, thread_id: Optional[str] = None
    ) -> Optional[List[ThreadDict]]:
        """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_all_user_threads")
        user_threads_query = """
            SELECT
                "id" AS thread_id,
                "createdAt" AS thread_createdat,
                "name" AS thread_name,
                "userId" AS user_id,
                "userIdentifier" AS user_identifier,
                "tags" AS thread_tags,
                "metadata" AS thread_metadata
            FROM threads
            WHERE "userId" = :user_id OR "id" = :thread_id
            ORDER BY "createdAt" DESC
            LIMIT :limit
        """
        user_threads = await self.execute_sql(
            query=user_threads_query,
            parameters={
                "user_id": user_id,
                "limit": self.user_thread_limit,
                "thread_id": thread_id,
            },
        )
        if not isinstance(user_threads, list):
            return None
        if not user_threads:
            return []
        else:
            thread_ids = (
                "('"
                + "','".join(map(str, [thread["thread_id"] for thread in user_threads]))
                + "')"
            )

        steps_feedbacks_query = f"""
            SELECT
                s."id" AS step_id,
                s."name" AS step_name,
                s."type" AS step_type,
                s."threadId" AS step_threadid,
                s."parentId" AS step_parentid,
                s."streaming" AS step_streaming,
                s."waitForAnswer" AS step_waitforanswer,
                s."isError" AS step_iserror,
                s."metadata" AS step_metadata,
                s."tags" AS step_tags,
                s."input" AS step_input,
                s."output" AS step_output,
                s."createdAt" AS step_createdat,
                s."start" AS step_start,
                s."end" AS step_end,
                s."generation" AS step_generation,
                s."showInput" AS step_showinput,
                s."language" AS step_language,
                s."indent" AS step_indent,
                f."id" AS feedback_id,
                f."value" AS feedback_value,
                f."comment" AS feedback_comment
            FROM steps s LEFT JOIN feedbacks f ON s."id" = f."forId"
            WHERE s."threadId" IN {thread_ids}
            ORDER BY s."createdAt" ASC
        """
        steps_feedbacks = await self.execute_sql(
            query=steps_feedbacks_query, parameters={}
        )

        elements_query = f"""
            SELECT
                e."id" AS element_id,
                e."threadId" as element_threadid,
                e."type" AS element_type,
                e."chainlitKey" AS element_chainlitkey,
                e."url" AS element_url,
                e."objectKey" as element_objectkey,
                e."name" AS element_name,
                e."display" AS element_display,
                e."size" AS element_size,
                e."language" AS element_language,
                e."page" AS element_page,
                e."forId" AS element_forid,
                e."mime" AS element_mime
            FROM elements e
            WHERE e."threadId" IN {thread_ids}
        """
        elements = await self.execute_sql(query=elements_query, parameters={})

        thread_dicts = {}
        for thread in user_threads:
            thread_id = thread["thread_id"]
            if thread_id is not None:
                thread_dicts[thread_id] = ThreadDict(
                    id=thread_id,
                    createdAt=thread["thread_createdat"],
                    name=thread["thread_name"],
                    userId=thread["user_id"],
                    userIdentifier=thread["user_identifier"],
                    tags=thread["thread_tags"],
                    metadata=thread["thread_metadata"],
                    steps=[],
                    elements=[],
                )
        # Process steps_feedbacks to populate the steps in the corresponding ThreadDict
        if isinstance(steps_feedbacks, list):
            for step_feedback in steps_feedbacks:
                thread_id = step_feedback["step_threadid"]
                if thread_id is not None:
                    feedback = None
                    if step_feedback["feedback_value"] is not None:
                        feedback = FeedbackDict(
                            forId=step_feedback["step_id"],
                            id=step_feedback.get("feedback_id"),
                            value=step_feedback["feedback_value"],
                            comment=step_feedback.get("feedback_comment"),
                        )
                    step_dict = StepDict(
                        id=step_feedback["step_id"],
                        name=step_feedback["step_name"],
                        type=step_feedback["step_type"],
                        threadId=thread_id,
                        parentId=step_feedback.get("step_parentid"),
                        streaming=step_feedback.get("step_streaming", False),
                        waitForAnswer=step_feedback.get("step_waitforanswer"),
                        isError=step_feedback.get("step_iserror"),
                        metadata=(
                            step_feedback["step_metadata"]
                            if step_feedback.get("step_metadata") is not None
                            else {}
                        ),
                        tags=step_feedback.get("step_tags"),
                        input=(
                            step_feedback.get("step_input", "")
                            if step_feedback["step_showinput"] == "true"
                            else None
                        ),
                        output=step_feedback.get("step_output", ""),
                        createdAt=step_feedback.get("step_createdat"),
                        start=step_feedback.get("step_start"),
                        end=step_feedback.get("step_end"),
                        generation=step_feedback.get("step_generation"),
                        showInput=step_feedback.get("step_showinput"),
                        language=step_feedback.get("step_language"),
                        indent=step_feedback.get("step_indent"),
                        feedback=feedback,
                    )
                    # Append the step to the steps list of the corresponding ThreadDict
                    thread_dicts[thread_id]["steps"].append(step_dict)

        if isinstance(elements, list):
            for element in elements:
                thread_id = element["element_threadid"]
                if thread_id is not None:
                    element_dict = ElementDict(
                        id=element["element_id"],
                        threadId=thread_id,
                        type=element["element_type"],
                        chainlitKey=element.get("element_chainlitkey"),
                        url=element.get("element_url"),
                        objectKey=element.get("element_objectkey"),
                        name=element["element_name"],
                        display=element["element_display"],
                        size=element.get("element_size"),
                        language=element.get("element_language"),
                        autoPlay=element.get("element_autoPlay"),
                        playerConfig=element.get("element_playerconfig"),
                        page=element.get("element_page"),
                        forId=element.get("element_forid"),
                        mime=element.get("element_mime"),
                    )
                    thread_dicts[thread_id]["elements"].append(element_dict)  # type: ignore

        return list(thread_dicts.values())
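update_thread(), create_step() and upsert_feedback() above all build their SQL the same way: drop keys whose value is None, quote the camelCase column names, bind every value as a `:name` parameter, and turn the insert into an update when the id already exists. In PostgreSQL, `EXCLUDED` refers to the row that was proposed for insertion, so `"name" = EXCLUDED."name"` copies the new value over the old one, and dropping None keys means a partial update leaves the other columns untouched. The sketch below isolates that pattern in a hypothetical standalone helper; unlike the code above, it falls back to `DO NOTHING` when only the key is present, because `DO UPDATE SET` with an empty list is invalid SQL.

```python
from typing import Any, Dict, Tuple


def build_upsert(table: str, data: Dict[str, Any], key: str = "id") -> Tuple[str, Dict[str, Any]]:
    """Build an INSERT ... ON CONFLICT statement plus its bind parameters.
    Column names are interpolated into the SQL, so keys must come from trusted code,
    never from user input; values are passed separately as bind parameters."""
    params = {k: v for k, v in data.items() if v is not None}  # skip unset fields
    columns = ", ".join(f'"{k}"' for k in params)
    values = ", ".join(f":{k}" for k in params)
    updates = ", ".join(f'"{k}" = EXCLUDED."{k}"' for k in params if k != key)
    action = f"DO UPDATE SET {updates}" if updates else "DO NOTHING"
    sql = f"INSERT INTO {table} ({columns}) VALUES ({values}) " + f'ON CONFLICT ("{key}") {action}'
    return sql, params


if __name__ == "__main__":
    sql, params = build_upsert("threads", {"id": "t1", "name": "Hello", "tags": None})
    print(sql)
    print(params)
```

The result can be passed straight to the execute_sql() helper above, e.g. `await self.execute_sql(query=sql, parameters=params)`.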

In the project root directory, create an app.py file with the following code:

from typing import List, Optional

import chainlit as cl
import chainlit.data as cl_data
from openai import AsyncOpenAI

from postgres_client import PostgresStorageClient
# Use the data layer defined in postgres_data.py above
from postgres_data import PostgresDataLayer

client = AsyncOpenAI()  # reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment


thread_history = []  # type: List[cl_data.ThreadDict]
deleted_thread_ids = []  # type: List[str]

storage_client = PostgresStorageClient(host="your_postgres_host", dbname="your_database_name", port=5432,
                                       user="your_postgres_user", password="your_postgres_password")

cl_data._data_layer = PostgresDataLayer(
    conninfo="postgresql+asyncpg://username:password@ip:port/dbname",
    storage_provider=storage_client)


@cl.on_chat_start
async def main():
    content = "Hello, I'm the Taishan AI customer service assistant. How can I help you?"
    await cl.Message(content).send()


@cl.on_message
async def handle_message():
    # Wait for queue to be flushed
    await cl.sleep(1)
    msg = cl.Message(content="")
    await msg.send()

    stream = await client.chat.completions.create(
        model="qwen-turbo", messages=cl.chat_context.to_openai(), stream=True
    )

    async for part in stream:
        if token := part.choices[0].delta.content or "":
            await msg.stream_token(token)
    await msg.update()


@cl.password_auth_callback
def auth_callback(username: str, password: str) -> Optional[cl.User]:
    if (username, password) == ("admin", "admin"):
        return cl.User(identifier="admin")
    else:
        return None


@cl.on_chat_resume
async def on_chat_resume():
    pass
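  • Replace the PostgreSQL connection settings in the code (both the PostgresStorageClient arguments and the conninfo URL) with your own.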
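The conninfo value is a SQLAlchemy URL, so a password containing characters such as `@`, `:` or `/` would be misparsed as URL delimiters. SQLAlchemy's documentation recommends escaping the password with `urllib.parse.quote_plus`. The sketch below is a small hypothetical helper that builds the URL for the asyncpg driver; `sqlalchemy.engine.URL.create()` is an alternative that avoids string formatting altogether.

```python
from urllib.parse import quote_plus


def build_asyncpg_url(user: str, password: str, host: str, port: int, dbname: str) -> str:
    """Build the postgresql+asyncpg:// URL passed as conninfo.
    quote_plus() escapes characters that would otherwise be read as URL delimiters."""
    return (
        f"postgresql+asyncpg://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{dbname}"
    )


if __name__ == "__main__":
    # Use the same host, port, database, user and password as PostgresStorageClient,
    # since the files table lives in the same database as the chat tables.
    print(build_asyncpg_url("postgres", "your_password", "127.0.0.1", 5432, "chain_lit"))
```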

4. Generate the AUTH_SECRET used for authentication

chainlit create-secret 

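Copy the last line of the command's output into the .env file. The secret shown below is only an example; use the one generated on your own machine: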

CHAINLIT_AUTH_SECRET="$b?/v0NeJlAU~I5As1WSCa,j8wJ3w%agTyIFlUt4408?mfC*,/wovlfA%3O/751U"
OPENAI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
OPENAI_API_KEY=""
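A missing or empty value in .env usually only shows up as a confusing error at startup or on the first message. The sketch below is a hypothetical pre-flight check: a deliberately minimal KEY="value" reader (no multi-line values, no variable expansion, so it is a stand-in for python-dotenv, not a replacement) that reports which of the three keys above are absent or empty.

```python
from pathlib import Path
from typing import Dict, Iterable, List

REQUIRED_KEYS = ["CHAINLIT_AUTH_SECRET", "OPENAI_BASE_URL", "OPENAI_API_KEY"]


def read_env_file(path: Path) -> Dict[str, str]:
    """Parse simple KEY=value / KEY="value" lines, skipping blanks and comments."""
    values: Dict[str, str] = {}
    for raw in path.read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")  # split on the first "=" only
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]  # strip one pair of surrounding quotes
        values[key.strip()] = value
    return values


def missing_keys(env: Dict[str, str], required: Iterable[str] = REQUIRED_KEYS) -> List[str]:
    """Return the required keys that are absent or empty."""
    return [k for k in required if not env.get(k)]


if __name__ == "__main__":
    env = read_env_file(Path(".env"))
    print("Missing or empty:", missing_keys(env) or "none")
```

With the sample .env above as written, it would report OPENAI_API_KEY, which is left empty there on purpose.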

5. Start the service

chainlit run app.py -w

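6. Result after startup

  • Log in with admin / admin (as defined in auth_callback). The chat history is now stored in the PostgreSQL database, so it is no longer lost when the page is refreshed or the service is restarted.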

