# fastapi_demo.py (run this file to start the service)
from fastapi import FastAPI, Request, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import uvicorn
import json
import datetime
import torch
import logging
print(f"CUDA 是否可用: {torch.cuda.is_available()}")
print(f"当前 CUDA 版本: {torch.version.cuda}")
print(f"当前可用 CUDA 设备数量: {torch.cuda.device_count()}")
# Device settings
DEVICE = "cuda"  # use CUDA
DEVICE_ID = "0"  # CUDA device ID; empty string means the default device
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE  # combined device string, e.g. "cuda:0"
# Free cached GPU memory
def torch_gc():
    if torch.cuda.is_available():  # only meaningful when CUDA is available
        with torch.cuda.device(CUDA_DEVICE):  # operate on the configured device
            torch.cuda.empty_cache()  # release cached memory blocks
            torch.cuda.ipc_collect()  # collect CUDA IPC memory
# Build the chat prompt from the template.
# NOTE: the special tokens below follow the Llama-3 chat format; the original
# markup appears to have been stripped as HTML, so this reconstruction is an
# assumption -- adjust the template to whatever your model expects.
def build_input(prompt, history=[], system_message=None):
    system_format = '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>'
    user_format = '<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>'
    assistant_format = '<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>\n'
    prompt_str = ''
    # Prepend the system message, if any
    if system_message:
        prompt_str += system_format.format(content=system_message)
    # Append the conversation history
    for item in history:
        if item['role'] == 'user':
            prompt_str += user_format.format(content=item['content'])
        else:
            prompt_str += assistant_format.format(content=item['content'])
    # Append the current user input
    prompt_str += user_format.format(content=prompt)
    return prompt_str
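
# For illustration (hypothetical values, assuming the Llama-3-style template above):
# build_input("Hi", history=[{"role": "user", "content": "Hello"},
#                            {"role": "assistant", "content": "Hi there!"}])
# renders to:
#   <|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\nHi there!<|eot_id|>\n
#   <|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>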
# Create the FastAPI app
app = FastAPI()

# Simple GET handler for the root path
@app.get("/")
async def read_root():
    return {"message": "Welcome to the API. Please use POST method to interact with the model."}

@app.get('/favicon.ico')
async def favicon():
    return {'status': 'ok'}
# POST endpoint that runs generation
@app.post("/")
async def create_item(request: Request):
    try:
        json_post = await request.json()  # parsed request body (a dict)
        prompt = json_post.get('prompt')
        if not prompt:
            raise HTTPException(status_code=400, detail="prompt must not be empty")
        history = json_post.get('history', [])
        system_message = json_post.get('system_message')
        # Log the validated input
        logging.info(f"Received request: prompt={prompt}, history={history}, system_message={system_message}")
        input_str = build_input(prompt=prompt, history=history, system_message=system_message)
        try:
            input_ids = process_input(input_str).to(CUDA_DEVICE)
        except Exception as e:
            logging.error(f"Tokenizer error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Tokenizer processing failed: {str(e)}")
        try:
            generated_ids = model.generate(
                input_ids=input_ids, max_new_tokens=1024, do_sample=True,
                top_p=0.5, temperature=0.95, repetition_penalty=1.1
            )
        except Exception as e:
            logging.error(f"Model generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model generation failed: {str(e)}")
        outputs = generated_ids.tolist()[0][len(input_ids[0]):]  # keep only newly generated tokens
        response = tokenizer.decode(outputs)
        # Strip the chat-template markup (the same Llama-3-style tokens as in build_input)
        response = response.strip().replace('<|start_header_id|>assistant<|end_header_id|>\n\n', '').replace('<|eot_id|>', '').strip()
        now = datetime.datetime.now()  # current time
        time = now.strftime("%Y-%m-%d %H:%M:%S")  # formatted timestamp
        # Build the response JSON
        answer = {
            "response": response,
            "status": 200,
            "time": time
        }
        # Build and print a log line
        log = f'[{time}] prompt: "{prompt}", response: {repr(response)}'
        print(log)
        torch_gc()  # free cached GPU memory
        return answer
    except HTTPException:
        raise  # pass HTTP errors through instead of remapping them to 500
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="invalid JSON format")
    except Exception as e:
        logging.error(f"Error while handling the request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
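
# Example request body for the endpoint above (illustrative values):
#   {"prompt": "Hello", "history": [], "system_message": "You are a helpful assistant."}
# Shape of a successful response:
#   {"response": "...", "status": 200, "time": "2024-01-01 12:00:00"}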
# Entrypoint
if __name__ == '__main__':
    # First check how many GPUs are available
    gpu_count = torch.cuda.device_count()
    if int(DEVICE_ID) >= gpu_count:
        raise ValueError(f"Invalid DEVICE_ID ({DEVICE_ID}): this machine has only {gpu_count} GPU(s) (0-{gpu_count-1})")
    # Select the CUDA device
    torch.cuda.set_device(int(DEVICE_ID))
    model_name_or_path = '/data/user23262833/MemoryStrategy/ChatGLM-Finetuning/chatglm3-6b'  # replace with the path to your model
    # Tokenizer initialization
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        padding_side='left'  # set directly at initialization
    )

    # Minimal process_input implementation
    def process_input(text):
        # add_special_tokens=False: the chat template already carries its own special tokens
        inputs = tokenizer.encode(text, add_special_tokens=False, return_tensors='pt')
        return inputs if torch.is_tensor(inputs) else torch.tensor([inputs])
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map={"": int(DEVICE_ID)},  # pin the whole model to the chosen GPU
        torch_dtype=torch.float16,
        trust_remote_code=True  # required for ChatGLM-style repos with custom modeling code
    )

    # Start the FastAPI app.
    # Port 6006 can be mapped from AutoDL to your local machine, so the API is usable locally.
    uvicorn.run(app, host='0.0.0.0', port=6006, workers=1)  # replace host with your local or server IP as needed
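
# A minimal client sketch (run it separately once the server is up; the
# host/port are illustrative -- match them to your uvicorn settings):
#   import requests
#   r = requests.post('http://127.0.0.1:6006',
#                     json={'prompt': 'Hello', 'history': []})
#   print(r.json()['response'])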