基于FastAPI实现本地大模型API封装调用

  • 关于FastAPI
    • FastAPI 是一个现代、快速(高性能)的 Python Web 框架,用于构建基于标准 Python 类型提示的 API。它以简洁、直观和高效的方式提供工具,特别适合开发现代 web 服务和后端应用程序。
  • 问题:_pad() got an unexpected keyword argument 'padding_side'
    • 解决:降级 transformers,pip install transformers==4.34.0,同时更改相关包版本以实现适配,pip install accelerate==0.25.0,pip install huggingface_hub==0.16.4
  • 问题:报错500
    • 服务器防火墙问题,只能在指定端口访问
    • post请求的参数通过request body传递,需要以 application/json 的方式 ,请求body
      • 以postman测试为例:Body中选择"raw",则对应的Headers中的"Content-Type""application/json",参数形式是{"content":"有什么推荐的咖啡吗"}
  • 代码实现
    • fastapi_demo.py(运行开启服务)
    • post.py(服务测试)
python 复制代码
# fastapi_demo.py(运行开启服务)
from fastapi import FastAPI, Request, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import uvicorn
import json
import datetime
import torch
import logging

print(f"CUDA 是否可用: {torch.cuda.is_available()}")
print(f"当前 CUDA 版本: {torch.version.cuda}")
print(f"当前可用 CUDA 设备数量: {torch.cuda.device_count()}")
 
# 设置设备参数
DEVICE = "cuda"  # 使用CUDA
DEVICE_ID = "0"  # CUDA设备ID,如果未设置则为空
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE  # 组合CUDA设备信息
 
# 清理GPU内存函数
def torch_gc():
    if torch.cuda.is_available():  # 检查是否可用CUDA
        with torch.cuda.device(CUDA_DEVICE):  # 指定CUDA设备
            torch.cuda.empty_cache()  # 清空CUDA缓存
            torch.cuda.ipc_collect()  # 收集CUDA内存碎片
 
# 构建 chat 模版
def bulid_input(prompt, history=[], system_message=None):
    system_format = 'system\n\n{content}\n'
    user_format = 'user\n\n{content}\n'
    assistant_format = 'assistant\n\n{content}\n'
 
    prompt_str = ''
 
    # 添加system消息
    if system_message:
        prompt_str += system_format.format(content=system_message)
 
    # 拼接历史对话
    for item in history:
        if item['role'] == 'user':
            prompt_str += user_format.format(content=item['content'])
        else:
            prompt_str += assistant_format.format(content=item['content'])
 
    # 添加当前用户输入
    prompt_str += user_format.format(content=prompt)
 
    return prompt_str
 
# 创建FastAPI应用
app = FastAPI()
 
# 添加GET请求处理
@app.get("/")
async def read_root():
    return {"message": "Welcome to the API. Please use POST method to interact with the model."}
 
@app.get('/favicon.ico')
async def favicon():
    return {'status': 'ok'}
 
# 处理POST请求的端点
@app.post("/")
async def create_item(request: Request):
    try:
        json_post_raw = await request.json()
        json_post = json.dumps(json_post_raw)
        json_post_list = json.loads(json_post)
        prompt = json_post_list.get('prompt')
        
        if not prompt:
            raise HTTPException(status_code=400, detail="提示词不能为空")

        history = json_post_list.get('history', [])
        system_message = json_post_list.get('system_message')

        # 添加输入验证的日志
        logging.info(f"收到请求: prompt={prompt}, history={history}, system_message={system_message}")

        input_str = bulid_input(prompt=prompt, history=history, system_message=system_message)
        try:
            input_ids = process_input(input_str).to(CUDA_DEVICE)
        except Exception as e:
            logging.error(f"Tokenizer 错误: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Tokenizer 处理失败: {str(e)}")

        try:
            generated_ids = model.generate(
                input_ids=input_ids, max_new_tokens=1024, do_sample=True,
                top_p=0.5, temperature=0.95, repetition_penalty=1.1
            )
        except Exception as e:
            logging.error(f"模型生成错误: {str(e)}")
            raise HTTPException(status_code=500, detail=f"模型生成失败: {str(e)}")

        outputs = generated_ids.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs)
        response = response.strip().replace('assistant\n\n', '').strip()  # 解析 chat 模版
 
        now = datetime.datetime.now()  # 获取当前时间
        time = now.strftime("%Y-%m-%d %H:%M:%S")  # 格式化时间为字符串
        # 构建响应JSON
        answer = {
            "response": response,
            "status": 200,
            "time": time
        }
        # 构建日志信息
        log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
        print(log)  # 打印日志
        torch_gc()  # 执行GPU内存清理
        return answer  # 返回响应

    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="无效的 JSON 格式")
    except Exception as e:
        logging.error(f"处理请求时发生错误: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
 
# 主函数入口
if __name__ == '__main__':
    # 首先检查可用的GPU数量
    gpu_count = torch.cuda.device_count()
    if int(DEVICE_ID) >= gpu_count:
        raise ValueError(f"指定的DEVICE_ID ({DEVICE_ID}) 无效。系统只有 {gpu_count} 个GPU设备(0-{gpu_count-1})")
    
    # 设置当前CUDA设备
    torch.cuda.set_device(int(DEVICE_ID))
    
    model_name_or_path = '/data/user23262833/MemoryStrategy/ChatGLM-Finetuning/chatglm3-6b(需要填写你的模型位置所在路径)'
    
    # 修改 tokenizer 初始化
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        padding_side='left'  # 直接在初始化时设置
    )
    
    # 更简单的 process_input 实现
    def process_input(text):
        inputs = tokenizer.encode(text, return_tensors='pt')
        return inputs if torch.is_tensor(inputs) else torch.tensor([inputs])
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, 
        device_map={"": int(DEVICE_ID)},  # 明确指定设备映射
        torch_dtype=torch.float16
    )
 
    # 启动FastAPI应用
    # 用6006端口可以将autodl的端口映射到本地,从而在本地使用api
    uvicorn.run(app, host='需填写你的本地或者服务器ip', port=6006, workers=1)  # 在指定端口和主机上启动应用
python 复制代码
# post.py
import requests
import json
 
def get_completion(prompt):
    try:
        headers = {'Content-Type': 'application/json'}
        data = {"prompt": prompt}
        response = requests.post(url='需填写你的本地或者服务器ip:6006', headers=headers, data=json.dumps(data))
        
        # 检查响应状态码
        response.raise_for_status()
        
        # 添加响应内容的打印,用于调试
        print("Response content:", response.text)
        
        return response.json()['response']
    except requests.exceptions.RequestException as e:
        print(f"请求错误: {e}")
        return None
    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e}")
        return None
    except KeyError as e:
        print(f"响应中缺少 'response' 键: {e}")
        return None
 
# 测试代码
response = get_completion('你好')
if response is not None:
    print(response)
相关推荐
_extraordinary_20 分钟前
选择排序+快速排序递归版(二)
算法
sp_fyf_202422 分钟前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-11-04
人工智能·神经网络·算法·机器学习·语言模型·数据挖掘
Wils0nEdwards32 分钟前
Leetcode 回文数
算法·leetcode·职场和发展
可可可可可人33 分钟前
leetCode——二进制手表
算法·leetcode
兔兔爱学习兔兔爱学习34 分钟前
leetcode219. Contains Duplicate II
javascript·数据结构·算法
Lemon_man_40 分钟前
算法——两两交换链表中的节点(leetcode24)
数据结构·算法·链表
XD7429716361 小时前
sglang 部署Qwen2VL7B,大模型部署,速度测试,深度学习
人工智能·深度学习
hai405871 小时前
ELMo模型介绍:深度理解语言模型的嵌入艺术
人工智能·语言模型·自然语言处理
快乐点吧1 小时前
【深度学习】模型参数冻结:原理、应用与实践
人工智能·深度学习
行码棋1 小时前
【机器学习】SVM原理详解
人工智能·机器学习·支持向量机