A Roundup of Local API Startup Methods for Open-Source Large Models

Table of Contents

CodeGeex2

py
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch
import argparse
try:
    import chatglm_cpp
    enable_chatglm_cpp = True
except ImportError:
    print("[WARN] chatglm-cpp not found. Install it by `pip install chatglm-cpp` for better performance. "
          "Check out https://github.com/li-plus/chatglm.cpp for more details.")
    enable_chatglm_cpp = False

LANGUAGE_TAG = {
    "Abap"         : "* language: Abap",
    "ActionScript" : "// language: ActionScript",
    "Ada"          : "-- language: Ada",
    "Agda"         : "-- language: Agda",
    "ANTLR"        : "// language: ANTLR",
    "AppleScript"  : "-- language: AppleScript",
    "Assembly"     : "; language: Assembly",
    "Augeas"       : "// language: Augeas",
    "AWK"          : "// language: AWK",
    "Basic"        : "' language: Basic",
    "C"            : "// language: C",
    "C#"           : "// language: C#",
    "C++"          : "// language: C++",
    "CMake"        : "# language: CMake",
    "Cobol"        : "// language: Cobol",
    "CSS"          : "/* language: CSS */",
    "CUDA"         : "// language: Cuda",
    "Dart"         : "// language: Dart",
    "Delphi"       : "{language: Delphi}",
    "Dockerfile"   : "# language: Dockerfile",
    "Elixir"       : "# language: Elixir",
    "Erlang"       : f"% language: Erlang",
    "Excel"        : "' language: Excel",
    "F#"           : "// language: F#",
    "Fortran"      : "!language: Fortran",
    "GDScript"     : "# language: GDScript",
    "GLSL"         : "// language: GLSL",
    "Go"           : "// language: Go",
    "Groovy"       : "// language: Groovy",
    "Haskell"      : "-- language: Haskell",
    "HTML"         : "<!--language: HTML-->",
    "Isabelle"     : "(*language: Isabelle*)",
    "Java"         : "// language: Java",
    "JavaScript"   : "// language: JavaScript",
    "Julia"        : "# language: Julia",
    "Kotlin"       : "// language: Kotlin",
    "Lean"         : "-- language: Lean",
    "Lisp"         : "; language: Lisp",
    "Lua"          : "// language: Lua",
    "Markdown"     : "<!--language: Markdown-->",
    "Matlab"       : f"% language: Matlab",
    "Objective-C"  : "// language: Objective-C",
    "Objective-C++": "// language: Objective-C++",
    "Pascal"       : "// language: Pascal",
    "Perl"         : "# language: Perl",
    "PHP"          : "// language: PHP",
    "PowerShell"   : "# language: PowerShell",
    "Prolog"       : f"% language: Prolog",
    "Python"       : "# language: Python",
    "R"            : "# language: R",
    "Racket"       : "; language: Racket",
    "RMarkdown"    : "# language: RMarkdown",
    "Ruby"         : "# language: Ruby",
    "Rust"         : "// language: Rust",
    "Scala"        : "// language: Scala",
    "Scheme"       : "; language: Scheme",
    "Shell"        : "# language: Shell",
    "Solidity"     : "// language: Solidity",
    "SPARQL"       : "# language: SPARQL",
    "SQL"          : "-- language: SQL",
    "Swift"        : "// language: swift",
    "TeX"          : f"% language: TeX",
    "Thrift"       : "/* language: Thrift */",
    "TypeScript"   : "// language: TypeScript",
    "Vue"          : "<!--language: Vue-->",
    "Verilog"      : "// language: Verilog",
    "Visual Basic" : "' language: Visual Basic",
}

app = FastAPI()
def device(config, model_path):
    if enable_chatglm_cpp and config.use_chatglm_cpp:
        print("Using chatglm-cpp to improve performance")
        dtype = "f16" if config.half else "f32"
        if config.quantize in [4, 5, 8]:
            dtype = f"q{config.quantize}_0"
        model = chatglm_cpp.Pipeline(model_path, dtype=dtype)
        return model

    print("chatglm-cpp not enabled, falling back to transformers")
    if config.device != "cpu":
        if not config.half:
            model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda(int(config.device))
        else:
            model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda(int(config.device)).half()
        if config.quantize in [4, 8]:
            print(f"Model is quantized to INT{config.quantize} format.")
            model = model.half().quantize(config.quantize)
    else:
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    return model.eval()

@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_list = await request.json()
    lang = json_post_list.get('lang')
    prompt = json_post_list.get('prompt')
    max_length = json_post_list.get('max_length', 128)
    top_p = json_post_list.get('top_p', 0.95)
    temperature = json_post_list.get('temperature', 0.2)
    top_k = json_post_list.get('top_k', 0)
    if lang != "None":
        prompt = LANGUAGE_TAG[lang] + "\n" + prompt
    if enable_chatglm_cpp and use_chatglm_cpp:
        response = model.generate(prompt,
                                  max_length=max_length,
                                  do_sample=temperature > 0,
                                  top_p=top_p,
                                  top_k=top_k,
                                  temperature=temperature)
    else:
        # ChatGLM-style chat() returns a (response, history) tuple; keep only the response
        response, _ = model.chat(tokenizer,
                                 prompt,
                                 max_length=max_length,
                                 top_p=top_p,
                                 top_k=top_k,
                                 temperature=temperature)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "lang": lang,
        "status": 200,
        "time": time
    }

    return answer


def api_start(config):
    global use_chatglm_cpp
    use_chatglm_cpp = config.use_chatglm_cpp
    model_path = "CodeModels/CodeGeex2"
    global tokenizer
    global model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = device(config, model_path)
    uvicorn.run(app, host="0.0.0.0", port=7861, workers=1)
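
The script above only defines api_start(config) and never builds the config object itself (note the so-far-unused argparse import). A minimal launcher sketch follows; the flag names are my assumption and only need to match the attributes the script reads: device, half, quantize, and use_chatglm_cpp.

py
if __name__ == "__main__":
    # Hypothetical entry point: wire up the attributes that device() and api_start() read.
    # argparse is already imported at the top of the script.
    parser = argparse.ArgumentParser(description="Start the local CodeGeex2 API")
    parser.add_argument("--device", default="0", help='CUDA device index, or "cpu"')
    parser.add_argument("--half", action="store_true", help="load the weights in fp16")
    parser.add_argument("--quantize", type=int, default=0,
                        help="quantization bits (4 or 8; 0 disables quantization)")
    parser.add_argument("--use-chatglm-cpp", dest="use_chatglm_cpp", action="store_true",
                        help="use the chatglm-cpp backend when it is installed")
    api_start(parser.parse_args())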

ChatGLM2_6B

py
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch

def torch_gc(mydevice):
    if torch.cuda.is_available():
        with torch.cuda.device(mydevice):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()
def device(config, model_path):
    if config.device != "cpu":
        if not config.half:
            model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda(int(config.device))
        else:
            model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda(int(config.device)).half()
        if config.quantize in [4, 8]:
            print(f"Model is quantized to INT{config.quantize} format.")
            model = model.half().quantize(config.quantize)
    else:
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    return model.eval()

@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_list = await request.json()
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history', [])
    max_length = json_post_list.get('max_length', 2048)
    top_p = json_post_list.get('top_p', 0.7)
    temperature = json_post_list.get('temperature', 0.95)
    top_k = json_post_list.get('top_k', 0)
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length,
                                   top_p=top_p,
                                   top_k=top_k,
                                   temperature=temperature)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    torch_gc(model.device)
    return answer


def api_start(config):
    model_path = "LanguageModels/ChatGLM2_6B/"
    global tokenizer
    global model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = device(config, model_path)
    uvicorn.run(app, host="0.0.0.0", port=7862, workers=1)
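
Once this server is listening on port 7862, a minimal client round-trip looks like the sketch below (using the third-party requests package; the prompts are arbitrary examples). The second turn feeds the returned history back in, which is what makes the conversation stateful:

py
import requests

url = "http://127.0.0.1:7862"
history = []
for prompt in ["Hello!", "Summarize what you just said in one sentence."]:
    reply = requests.post(url, json={"prompt": prompt, "history": history}, timeout=300).json()
    history = reply["history"]  # carry the conversation state into the next turn
    print(reply["time"], reply["response"])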

Baichuan2_13B

py
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
import uvicorn, json, datetime
import torch

def torch_gc(mydevice):
    if torch.cuda.is_available():
        with torch.cuda.device(mydevice):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()
def device(config, model_path):
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    return model.eval()

@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_list = await request.json()
    prompt = json_post_list.get('prompt')
    messages = []
    messages.append({"role": "user", "content": prompt})
    response = model.chat(tokenizer, messages)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "status": 200,
        "time": time
    }
    torch_gc(model.device)
    return answer


def api_start(config):
    model_path = "LanguageModels/Baichuan2_13B_Chat/"
    global tokenizer
    global model
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
    model = device(config, model_path)
    uvicorn.run(app, host="0.0.0.0", port=7863, workers=1)
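
A quick smoke test against the Baichuan2 endpoint on port 7863. Note that this handler wraps the prompt in a single user message, so there is no history to pass; the generous timeout is because generation on a 13B model can be slow:

py
import requests

resp = requests.post("http://127.0.0.1:7863",
                     json={"prompt": "Explain what a REST API is in one sentence."},
                     timeout=300)
resp.raise_for_status()  # fail loudly on HTTP errors
print(resp.json()["response"])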

sqlcoder

py
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
import uvicorn, json, datetime
import torch

def torch_gc(mydevice):
    if torch.cuda.is_available():
        with torch.cuda.device(mydevice):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()
def device(config, model_path):
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", load_in_8bit=True, use_cache=True, trust_remote_code=True)
    return model.eval()

@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_list = await request.json()
    prompt = json_post_list.get('prompt')
    # sqlcoder emits its SQL inside a ```sql fence; use the closing fence token as EOS so generation stops there
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=5
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    response = outputs[0].split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";"
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "status": 200,
        "time": time
    }
    torch_gc(model.device)
    return answer


def api_start(config):
    model_path = "CodeModels/sqlcoder/"
    global tokenizer
    global model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = device(config, model_path)
    uvicorn.run(app, host="0.0.0.0", port=7864, workers=1)

Testing after startup

sh
curl -X POST "http://127.0.0.1:7864" -H 'Content-Type: application/json' -d '{"prompt": "What is your name?"}'
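
The same request from Python, this time with a text-to-SQL style prompt. The schema and prompt layout here are illustrative; check the sqlcoder model card for the exact template it was trained on. Ending the prompt with an open ```sql fence matches how the handler above extracts the query:

py
import requests

# Illustrative prompt: sqlcoder expects the question together with the table DDL.
prompt = """### Task
Generate a SQL query to answer this question: How many users signed up in 2023?

### Database Schema
CREATE TABLE users (id INT, signup_date DATE);

### SQL
```sql
"""
resp = requests.post("http://127.0.0.1:7864", json={"prompt": prompt}, timeout=300)
print(resp.json()["response"])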