import torch
from transformers import DynamicCache

class OptimizedInference:
    """Inference with KV-cache reuse."""
    def __init__(self, model):
        self.model = model

    def generate_with_cache(self, input_ids, max_new_tokens=100):
        """Greedy decoding that reuses the KV cache between steps."""
        past_key_values = DynamicCache()
        next_input = input_ids  # first step: feed the full prompt
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    next_input,
                    past_key_values=past_key_values,
                    use_cache=True
                )
            # Keep the updated cache for the next step
            past_key_values = outputs.past_key_values
            # Pick the next token greedily
            next_token = outputs.logits[:, -1, :].argmax(dim=-1)
            next_input = next_token.unsqueeze(-1)  # later steps: feed only the new token
            input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
        return input_ids
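
A minimal usage sketch, assuming a Hugging Face causal LM and its tokenizer (the "gpt2" checkpoint is only an example):

# Hypothetical usage sketch: any Hugging Face causal LM works the same way.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

engine = OptimizedInference(model)
input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
output_ids = engine.generate_with_cache(input_ids, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))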
2. Result Caching
import hashlib
import json

import redis

class LLMCache:
    """Two-level (in-process + Redis) cache for LLM results."""
    def __init__(self, llm, redis_host="localhost"):
        self.llm = llm
        self.redis_client = redis.Redis(host=redis_host)
        self.local_cache = {}

    def get_cache_key(self, prompt: str, params: dict) -> str:
        """Build a deterministic cache key from the prompt and parameters."""
        content = f"{prompt}{json.dumps(params, sort_keys=True)}"
        return hashlib.md5(content.encode()).hexdigest()

    def cached_generate(self, prompt: str, **params):
        """Generate with caching."""
        cache_key = self.get_cache_key(prompt, params)
        # 1. In-process cache
        if cache_key in self.local_cache:
            return self.local_cache[cache_key]
        # 2. Redis cache
        cached = self.redis_client.get(cache_key)
        if cached:
            result = json.loads(cached)
            self.local_cache[cache_key] = result
            return result
        # 3. Cache miss: call the model
        result = self.llm.generate(prompt, **params)
        # 4. Store the result in both cache levels
        self.redis_client.setex(
            cache_key,
            3600,  # TTL: 1 hour
            json.dumps(result)
        )
        self.local_cache[cache_key] = result
        return result
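
A usage sketch follows; the `llm` object is a hypothetical client exposing `generate(prompt, **params)`, which is all the class above relies on:

# Hypothetical usage sketch; `llm` is any object with a generate(prompt, **params) method.
cache = LLMCache(llm, redis_host="localhost")
first = cache.cached_generate("Summarize this document ...", temperature=0.2)
# The repeated call is served from the in-process cache (then Redis), not the model.
second = cache.cached_generate("Summarize this document ...", temperature=0.2)
assert first == second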
3. Embedding Caching
class EmbeddingCache:
    """Cache for text embeddings."""
    def __init__(self, model):
        self.model = model
        self.cache = {}

    def get_embedding(self, text: str):
        """Return the cached embedding for a text, computing it on a miss."""
        if text not in self.cache:
            self.cache[text] = self.model.encode(text)
        return self.cache[text]

    def batch_encode(self, texts: list):
        """Batch encoding that only computes the texts not already cached."""
        uncached = []
        results = []
        for text in texts:
            if text in self.cache:
                results.append(self.cache[text])
            else:
                uncached.append(text)
                results.append(None)  # placeholder, filled in below
        # Encode all cache misses in one batch
        if uncached:
            new_embeddings = self.model.encode(uncached)
            j = 0
            for i, result in enumerate(results):
                if result is None:
                    results[i] = new_embeddings[j]
                    self.cache[texts[i]] = new_embeddings[j]
                    j += 1
        return results
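
As an example, the encoder could be a sentence-transformers model (the model name below is only an illustration; any object with an `encode` method fits):

# Hypothetical usage sketch with sentence-transformers as the encoder.
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")
emb_cache = EmbeddingCache(encoder)

vectors = emb_cache.batch_encode(["what is a kv cache?", "how does batching help?"])
# A second call reuses the cached vectors and encodes nothing new.
vectors_again = emb_cache.batch_encode(["what is a kv cache?", "how does batching help?"])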
Concurrency Optimization
1. Asynchronous Processing
import asyncio
from typing import List

import aiohttp

class AsyncLLM:
    """Asynchronous LLM API client."""
    def __init__(self, api_key: str, max_concurrent=10):
        self.api_key = api_key
        self.semaphore = asyncio.Semaphore(max_concurrent)

    async def generate_single(self, session, prompt):
        """Send a single asynchronous request."""
        async with self.semaphore:
            headers = {"Authorization": f"Bearer {self.api_key}"}
            data = {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": prompt}]
            }
            async with session.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=data
            ) as response:
                result = await response.json()
                return result["choices"][0]["message"]["content"]

    async def batch_generate(self, prompts: List[str]):
        """Generate for a batch of prompts concurrently."""
        async with aiohttp.ClientSession() as session:
            tasks = [
                self.generate_single(session, prompt)
                for prompt in prompts
            ]
            return await asyncio.gather(*tasks)

# Usage
async def main():
    llm = AsyncLLM(api_key="sk-xxx")
    prompts = ["Hello", "What is AI?", "How to code?"]
    results = await llm.batch_generate(prompts)
    print(results)
    # The 3 requests run concurrently, so total time ≈ the slowest single request

asyncio.run(main())
2. Batch Processing
class BatchProcessor:
    """Groups individual requests into batches before calling the model."""
    def __init__(self, model, batch_size=8, timeout=0.1):
        self.model = model
        self.batch_size = batch_size
        self.timeout = timeout  # max seconds to wait before flushing a partial batch
        self.queue = []

    async def add_request(self, request_id: str, prompt: str):
        """Enqueue a request and wait for its result."""
        future = asyncio.get_running_loop().create_future()
        self.queue.append({
            "id": request_id,
            "prompt": prompt,
            "future": future
        })
        if len(self.queue) >= self.batch_size:
            # Full batch: process immediately
            await self._process_batch()
        else:
            # Partial batch: flush after the timeout so requests never stall
            asyncio.get_running_loop().call_later(
                self.timeout,
                lambda: asyncio.ensure_future(self._process_batch())
            )
        return await future

    async def _process_batch(self):
        """Run batched inference on up to batch_size queued requests."""
        if not self.queue:
            return
        batch = self.queue[:self.batch_size]
        self.queue = self.queue[self.batch_size:]
        # Batched inference
        prompts = [item["prompt"] for item in batch]
        results = self.model.generate(prompts, max_length=100)
        # Hand each result back to its waiting caller
        for item, result in zip(batch, results):
            item["future"].set_result(result)
Streaming Optimization
Server-Sent Events (SSE)
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import json

app = FastAPI()

@app.get("/stream")
async def stream_llm(prompt: str):
    """Streaming API endpoint."""
    async def generate():
        # Stream chunks as the model produces them
        # (`llm` is assumed to be a streaming-capable client defined elsewhere)
        for chunk in llm.generate_stream(prompt):
            # SSE format: a "data: ..." line followed by a blank line
            data = json.dumps({"text": chunk})
            yield f"data: {data}\n\n"
        # End-of-stream marker
        yield "data: [DONE]\n\n"
    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )

# Client
async def stream_client():
    async with aiohttp.ClientSession() as session:
        async with session.get(
            "http://localhost:8000/stream",
            params={"prompt": "Tell me a story"}
        ) as response:
            async for line in response.content:
                if line.startswith(b"data: "):
                    payload = line[len(b"data: "):].strip()
                    if payload == b"[DONE]":
                        break
                    data = json.loads(payload)
                    print(data["text"], end="", flush=True)
Deployment Optimization
1. Model Warmup
import torch

class ModelWarmer:
    """Runs a few dummy forward passes so the first real request is not slow."""
    def warmup(self, model, num_iterations=3):
        """Warm up the model (kernel compilation, memory allocation, etc.)."""
        device = next(model.parameters()).device
        dummy_input = torch.randint(0, 1000, (1, 10), device=device)
        for _ in range(num_iterations):
            with torch.no_grad():
                _ = model(dummy_input)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the GPU to finish
        print("Model warmup complete")
2. Load Balancing
class LoadBalancer:
    """Simple round-robin load balancer."""
    def __init__(self, servers: List[str]):
        self.servers = servers
        self.current = 0

    def get_server(self) -> str:
        """Pick the next server in round-robin order."""
        server = self.servers[self.current]
        self.current = (self.current + 1) % len(self.servers)
        return server

    async def request(self, prompt: str):
        """Send a request through the load balancer."""
        server = self.get_server()
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{server}/generate",
                json={"prompt": prompt}
            ) as response:
                return await response.json()
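
A closing usage sketch; the URLs are placeholders for real inference replicas exposing a `/generate` endpoint:

# Hypothetical usage sketch; replace the URLs with your own replicas.
balancer = LoadBalancer([
    "http://10.0.0.1:8000",
    "http://10.0.0.2:8000",
    "http://10.0.0.3:8000",
])

async def ask_many(prompts: List[str]):
    # Each request goes to the next replica in round-robin order.
    return await asyncio.gather(*[balancer.request(p) for p in prompts])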