一、Why MindSpore Serving
大模型时代,作为一个开发人员更多的是关注一个大模型如何训练好、如何调整模型参数、如何才能得到一个更高的模型精度。而作为一个整体项目,只有项目落地才能有其真正的价值。那么如何才能够使得大模型实现落地?如何才能使大模型项目中的文件以app的形式呈现给用户?
解决这个问题的一个组件就是Serving(服务),它主要解决的问题有:
- 模型如何提交给服务;
- 服务如何部署;
- 服务如何呈现给用户;
- 如何应用各种复杂场景等待
MindSpore Serving就是为了实现将大模型部署到生产环境而产生的。
MindSpore Serving是一个轻量级、高性能的服务模块,旨在帮助MindSpore开发者在生产环境中高效部署在线推理服务。当用户使用MindSpore完成模型训练后,导出MindIR,即可使用MindSpore Serving创建该大模型的推理服务。
MindSpore Serving实现的是一个模型服务化的部署,也就是说模型以线上的形式部署在服务器和云上,客户通过浏览器或者客户端去访问这个服务,将需要进行推理的输入内容发送给服务器,然后服务器将推理的结果返回给用户。
二、Component
MindSpore Serving由三部分组成,分别是客户端(Client)、Master和Worker。
-
客户端是用户节点,提供了gRPC和RESTful的访问。
-
Master是一个管理节点,管理所有Worker的信息,包括Worker有哪些模型的信息;Master也是一个分化节点,接收到了客户端的请求之后,会根据请求的内容,结合当前管理的Worker节点的信息进行分发,将请求分发给不同的Worker执行。
-
Worker是一个执行节点,会执行加载、模型的更新,在接收到Master转发的请求之后,会将请求进行组装和拆分,然后做前处理、推理和后处理,执行完之后将结果返回给Master,Master再将结果返回给客户端。
三、Features
1.简单易用:
对客户端提供了gRPC和RESTful的服务,同时又提供了服务的拉起、服务的部署和客户端的访问,提供了简单的python接口,通过python接口,用户可以很方便的定制和访问部署服务,只需要一行命令就能够完成一件事。
2.提供定制化的服务:
对于模型来说输入和输出一般是固定的,而对于用户来说输入和输出可能是多变的,这就需要一个预处理模块,将模型的输入转为一个模型可以识别的输入。同时还需要一个后处理模块,给用户提供定制化的服务,针对模型可以定制方法classifly_top,用户根据需要去写前处理和后处理的操作。对于客户端来说只要指定模型名和方法名就能实现推理的结果。
3.支持批处理:
主要是针对具有batchsize维度的文本来说。batchsize实现了文本的并行,在硬件资源足够的情况下,batchsize可以很大地提高性能。对于MindSpore Serving来说,用户一次性发送的请求是不确定的,因此Serving分割和组合一个或者多个请求以匹配用户模型的batchsize。例如batchsize=2,但是有三个请求发过来,这时候就会将两个请求合并处理,到后面再拆分,这样就实现了三个请求的并行,提高了效率。
- 高性能扩展:
MindSpore Serving所使用的算子引擎框架是MindSpore框架,具有自动融合和自动并行的高性能,再加上MindSpore Serving本身具有一个高性能的底层通信能力,客户端可以进行多实例组装,模型支持批处理,多模型之间支持并发,预处理和后处理支持多线程的处理。客户端和Worker可以实现扩展的,因此它也实现了一个高扩展性。
四、Demo
基于昇腾910B3
start_agent.py
python
from agent.agent_multi_post_method import *
from multiprocessing import Queue
from config.serving_config import AgentConfig, ModelName
if __name__ == "__main__":
startup_queue = Queue(1024)
startup_agents(AgentConfig.ctx_setting,
AgentConfig.inc_setting,
AgentConfig.post_model_setting,
len(AgentConfig.AgentPorts),
AgentConfig.prefill_model,
AgentConfig.decode_model,
AgentConfig.argmax_model,
AgentConfig.topk_model,
startup_queue)
started_agents = 0
while True:
value = startup_queue.get()
print("agent : %f started" % value)
started_agents = started_agents + 1
if started_agents >= len(AgentConfig.AgentPorts):
print("all agents started")
break
# server_app_post.init_server_app()
# server_app_post.warmup_model(ModelName)
# server_app_post.run_server_app()
client/server_app_post.py
python
import asyncio
import json
import logging
import signal
import sys
import uuid
from multiprocessing import Process
import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from client.client_utils import ClientRequest, Parameters
from config.serving_config import SERVER_APP_HOST, SERVER_APP_PORT
from server.llm_server_post import LLMServer
logging.basicConfig(level=logging.DEBUG,
filename='./output/server_app.log',
filemode='w',
format=
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
app = FastAPI()
llm_server = None
async def get_full_res(request, results):
all_texts = ''
async for result in results:
prompt_ = result.prompt
answer_texts = [output.text for output in result.outputs]
text = answer_texts[0]
if text is None:
text = ""
all_texts += text
ret = {
"generated_text": all_texts,
}
yield (json.dumps(ret, ensure_ascii=False) + '\n').encode("utf-8")
async def get_full_res_sse(request, results):
all_texts = ''
async for result in results:
answer_texts = [output.text for output in result.outputs]
text = answer_texts[0]
if text is None:
text = ""
all_texts += text
ret = {"event": "message", "retry": 30000, "generated_text": all_texts}
yield json.dumps(ret, ensure_ascii=False)
async def get_stream_res(request, results):
all_texts = ''
index = 0
async for result in results:
prompt_ = result.prompt
answer_texts = [output.text for output in result.outputs]
text = answer_texts[0]
if text is None:
text = ""
else:
index += 1
all_texts += text
ret = {
"token": {
"text": text,
"index": index
},
}
print(ret, index)
yield ("data:" + json.dumps(ret, ensure_ascii=False) + '\n').encode("utf-8")
print(all_texts)
return_full_text = request.parameters.return_full_text
if return_full_text:
ret = {
"generated_text": all_texts,
}
yield ("data:" + json.dumps(ret, ensure_ascii=False) + '\n').encode("utf-8")
async def get_stream_res_sse(request, results):
all_texts = ""
index = 0
async for result in results:
answer_texts = [output.text for output in result.outputs]
text = answer_texts[0]
if text is None:
text = ""
else:
index += 1
all_texts += text
ret = {"event": "message", "retry": 30000, "data": text}
yield json.dumps(ret, ensure_ascii=False)
print(all_texts)
if request.parameters.return_full_text:
ret = {"event": "message", "retry": 30000, "data": all_texts}
yield json.dumps(ret, ensure_ascii=False)
def send_request(request: ClientRequest):
print('request: ', request)
request_id = str(uuid.uuid1())
if request.parameters is None:
request.parameters = Parameters()
if request.parameters.do_sample is None:
request.parameters.do_sample = False
if request.parameters.top_k is None:
request.parameters.top_k = 3
if request.parameters.top_p is None:
request.parameters.top_p = 1.0
if request.parameters.temperature is None:
request.parameters.temperature = 1.0
if request.parameters.repetition_penalty is None:
request.parameters.repetition_penalty = 1.0
if request.parameters.max_new_tokens is None:
request.parameters.max_new_tokens = 300
if request.parameters.return_protocol is None:
request.parameters.return_protocol = "sse"
if request.parameters.top_k < 0:
request.parameters.top_k = 0
if request.parameters.top_p < 0.01:
request.parameters.top_p = 0.01
if request.parameters.top_p > 1.0:
request.parameters.top_p = 1.0
params = {
"prompt": request.inputs,
"do_sample": request.parameters.do_sample,
"top_k": request.parameters.top_k,
"top_p": request.parameters.top_p,
"temperature": request.parameters.temperature,
"repetition_penalty": request.parameters.repetition_penalty,
"max_token_len": request.parameters.max_new_tokens
}
print('generate_answer...')
global llm_server
results = llm_server.generate_answer(request_id, **params)
return results
@app.post("/models/llama2")
async def async_generator(request: ClientRequest):
results = send_request(request)
if request.stream:
if request.parameters.return_protocol == "sse":
print('get_stream_res_sse...')
return EventSourceResponse(get_stream_res_sse(request, results),
media_type="text/event-stream",
ping_message_factory=lambda: ServerSentEvent(
**{"comment": "You can't see this ping"}),
ping=600)
else:
print('get_stream_res...')
return StreamingResponse(get_stream_res(request, results))
else:
print('get_full_res...')
return StreamingResponse(get_full_res(request, results))
@app.post("/models/llama2/generate")
async def async_full_generator(request: ClientRequest):
results = send_request(request)
print('get_full_res...')
return StreamingResponse(get_full_res(request, results))
@app.post("/models/llama2/generate_stream")
async def async_stream_generator(request: ClientRequest):
results = send_request(request)
if request.parameters.return_protocol == "sse":
print('get_stream_res_sse...')
return EventSourceResponse(get_stream_res_sse(request, results),
media_type="text/event-stream",
ping_message_factory=lambda: ServerSentEvent(
**{"comment": "You can't see this ping"}),
ping=600)
else:
print('get_stream_res...')
return StreamingResponse(get_stream_res(request, results))
def update_internlm_request(request: ClientRequest):
if request.inputs:
request.inputs = "<s><|User|>:{}<eoh>\n<|Bot|>:".format(request.inputs)
@app.post("/models/internlm")
async def async_internlm_generator(request: ClientRequest):
# update_internlm_request(request)
return await async_generator(request)
@app.post("/models/internlm/generate")
async def async_internlm_full_generator(request: ClientRequest):
# update_internlm_request(request)
return await async_full_generator(request)
@app.post("/models/internlm/generate_stream")
async def async_internlm_stream_generator(request: ClientRequest):
# update_internlm_request(request)
return await async_stream_generator(request)
def init_server_app():
global llm_server
llm_server = LLMServer()
print('init server app finish')
async def warmup(request: ClientRequest):
request.parameters = Parameters(max_new_tokens=3)
results = send_request(request)
print('warmup get_stream_res...')
async for item in get_stream_res(request, results):
print(item)
def warmup_llama2():
request = ClientRequest(inputs="test")
asyncio.run(warmup(request))
print('warmup llama2 finish')
def warmup_internlm():
request = ClientRequest(inputs="test")
update_internlm_request(request)
asyncio.run(warmup(request))
print('warmup internlm finish')
def run_server_app():
print('server port is: ', SERVER_APP_PORT)
uvicorn.run(app, host=SERVER_APP_HOST, port=SERVER_APP_PORT)
WARMUP_MODEL_MAP = {
"llama": warmup_llama2,
"internlm": warmup_internlm,
}
def warmup_model(model_name):
model_prefix = model_name.split('_')[0]
if model_prefix in WARMUP_MODEL_MAP.keys():
func = WARMUP_MODEL_MAP[model_prefix]
warmup_process = Process(target=func)
warmup_process.start()
warmup_process.join()
print("mindspore serving is started.")
else:
print("model not support warmup : ", model_name)
async def _get_batch_size():
global llm_server
batch_size = llm_server.get_bs_current()
ret = {'event': "message", "retry": 30000, "data": batch_size}
yield json.dumps(ret, ensure_ascii=False)
async def _get_request_numbers():
global llm_server
queue_size = llm_server.get_queue_current()
ret = {'event': "message", "retry": 30000, "data": queue_size}
yield json.dumps(ret, ensure_ascii=False)
@app.get("/serving/get_bs")
async def get_batch_size():
return EventSourceResponse(_get_batch_size(),
media_type="text/event-stream",
ping_message_factory=lambda: ServerSentEvent(**{"comment": "You can't see this ping"}),
ping=600)
@app.get("/serving/get_request_numbers")
async def get_request_numbers():
return EventSourceResponse(_get_request_numbers(),
media_type="text/event-stream",
ping_message_factory=lambda: ServerSentEvent(**{"comment": "You can't see this ping"}),
ping=600)
def sig_term_handler(signal, frame):
print("catch SIGTERM")
global llm_server
llm_server.stop()
print("----serving exit----")
sys.exit(0)
if __name__ == "__main__":
signal.signal(signal.SIGTERM, sig_term_handler)
init_server_app()
# warmup_model(ModelName)
run_server_app()