大模型流式输出简谈

大模型流式输出是绕不开的一环,本文我将简单写一个示例,带你了解并简单上手

python代码准备

  • 本次需要用到的包

要用到StreamingResponse来处理流式输出

python 复制代码
import os
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI,Body
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel
  • 流式接口编写

    • 构建自定义请求体ChatRequest,接收用户输入的字符串

    • 构建异步方法event_generator()

    • event_generator()传入StreamingResponse中,并设置media_type参数为text/plain

python 复制代码
class ChatRequest(BaseModel):
    user_input: str

@app.post('/chat/stream_resp')
async def chat_stream_resp(request: ChatRequest):
    async def event_generator():
        async for chunk in chain.astream({'user_input': request.user_input}):
            if chunk:
                yield chunk.encode('utf-8')
    return StreamingResponse(event_generator(), media_type='text/plain')

前端代码准备

  • 主要讲解一下javascript中流式输出方法的实现
    • 构建POST请求与fastapi后端接口对接,注意fetch无法像axios那样自动解析JSON,所以要将POST请求的请求体做JSON转换
    • 为了做流式输出,要利用decoder,先构造reader指定其为response.body.getReader(),再构建decoder,新建TextDecoder('utf-8')并创建变量buffer用其来做后端响应的接收
    • 构建循环,进行解包赋值,拿到信号量和返回值,如果流式输出结束,信号量done将为true,否则则证明正在进行流式响应,利用decodervalue响应分片,指定streamtrue
    • 接收的流式响应拼接到buffer并对前端进行返回
html 复制代码
<script>
    async function startStream() {
        const output = document.getElementById('output');
        const user_input = document.getElementById('user_input').value;
        output.innerHTML = '';
        output.innerText = '';
        try {
            const response = await fetch('http://127.0.0.1:8000/chat/stream_resp', {
                    method: 'POST',
                    headers:{
                      'Content-Type':'application/json'
                    },
                    body: JSON.stringify({   //important:发送的是json格式字符串
                        'user_input': user_input
                    })
                }
            );
            if (!response.ok) throw new Error('网络响应失败!')
            const reader = response.body.getReader()
            const decoder = new TextDecoder('utf-8')
            let buffer = ''
            while (true) {
                const {done, value} = await reader.read()
                if (done) break;
                const chunk = decoder.decode(value, {stream: true})
                buffer += chunk;
                output.textContent = buffer
                output.scrollTop = output.scrollHeight;
            }
        } catch (err) {
            console.log('请求出错', err)
            output.textContent += '\n出错了!' + err.message;
        }
    }
</script>

整体代码

python

python 复制代码
import os

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI,Body
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel

load_dotenv()

#定义程序
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

#初始化大模型
llm=init_chat_model(
    model='glm-4.7',
    model_provider='openai',
    api_key=os.getenv('zhipu_key'),
    base_url=os.getenv('zhipu_base_url')
)



#定义提示词
prompt_template = ChatPromptTemplate(
    messages=[
        ('system','你现在是一个小说家,你会讲小说'),
        ('human','{user_input}')
    ]
)


#定义lcel

chain=prompt_template|llm|StrOutputParser()


#important:方案2,构建自定义请求体,更规范
class ChatRequest(BaseModel):
    user_input: str

@app.post('/chat/stream_resp')
async def chat_stream_resp(request: ChatRequest):
    async def event_generator():
        async for chunk in chain.astream({'user_input': request.user_input}):
            if chunk:
                yield chunk.encode('utf-8')
    return StreamingResponse(event_generator(), media_type='text/plain')

if __name__ == '__main__':
    uvicorn.run(app, host='127.0.0.1', port=8000)

前端

html 复制代码
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Title</title>
    <style>
        #output {
            width: 200px;
            height: 300px;
            background-color: #c3e6cb;
            margin-bottom: 30px;
        }
    </style>
</head>
<body>
<div id="output"></div>
用户输入:<input id="user_input">
<button onclick="startStream()">发送</button>
<script>
    async function startStream() {
        const output = document.getElementById('output');
        const user_input = document.getElementById('user_input').value;
        output.innerHTML = '';
        output.innerText = '';
        try {
            const response = await fetch('http://127.0.0.1:8000/chat/stream_resp', {
                    method: 'POST',
                    headers:{
                      'Content-Type':'application/json'
                    },
                    body: JSON.stringify({   //important:发送的是json格式字符串
                        'user_input': user_input
                    })
                }
            );
            if (!response.ok) throw new Error('网络响应失败!')
            const reader = response.body.getReader()
            const decoder = new TextDecoder('utf-8')
            let buffer = ''
            while (true) {
                const {done, value} = await reader.read()
                if (done) break;
                const chunk = decoder.decode(value, {stream: true})
                buffer += chunk;
                output.textContent = buffer
                output.scrollTop = output.scrollHeight;
            }
        } catch (err) {
            console.log('请求出错', err)
            output.textContent += '\n出错了!' + err.message;
        }
    }
</script>
</body>
</html>

希望对你有帮助