模型上下文长度测试工具
本文介绍一个用于探测大语言模型上下文长度的测试脚本。该脚本通过二分法发送递增长度的提示词,自动找出模型能处理的最大上下文长度。
1. 功能概述
本脚本的核心功能包括:
- 二分探测:在指定范围内通过二分查找,自动定位模型的最大上下文长度
- 双单位支持:支持以 token 近似填充或严格字符数两种方式构造测试提示词
- 兼容性强:适配任何 OpenAI 兼容接口的 API
- 智能错误判断:自动识别上下文长度超限错误,区分其他类型的请求失败
- 结果输出:以 JSON 格式输出每次测试结果,最终汇总最大成功长度
2. 核心函数说明
2.1 Token 估算函数
estimate_tokens 函数用于估算文本的 token 数量。它优先使用 tiktoken 库进行精确估算,若未安装则返回 None,脚本仍可继续运行。
2.2 填充内容生成
make_filler 函数根据指定的单位生成填充文本:
tokens模式:使用" x"重复填充,对多数 tokenizer 接近 1 tokenchars模式:严格按字符数生成
2.3 消息构造
build_messages 函数构造测试用的消息列表,包含 system 和 user 两条消息,其中 user 消息内嵌填充内容用于测试。
2.4 API 调用
call_chat_completions 函数负责发送 HTTP 请求到 OpenAI 兼容接口,并记录请求延迟。
2.5 错误判断
is_context_length_error 函数通过关键词匹配判断失败是否由上下文过长导致,支持多种服务商的错误文本格式。
2.6 测试执行
test_size 函数封装单次测试流程,包括消息构建、API 调用、结果记录和错误判断。
3. 命令行参数
脚本支持丰富的命令行参数,方便灵活配置:
| 参数 | 说明 |
|---|---|
--model |
模型名(必填) |
--base-url |
API 地址,默认读取环境变量 BASE_URL 或 https://api.openai.com/v1 |
--api-key |
API 密钥,默认读取 API_KEY 或 OPENAI_API_KEY |
--unit |
测试单位:tokens(近似 token)或 chars(严格字符数) |
--low |
二分下界,默认 1024 |
--high |
二分上界,默认 200000 |
--output-tokens |
输出 token 数,默认 1 |
--timeout |
单次请求超时秒数,默认 120 |
--max-token-param |
输出长度参数名:max_tokens 或 max_completion_tokens |
--ignore-non-context-errors |
将非上下文错误也视为失败继续二分 |
4. 使用示例
基本用法:
bash
python context_probe.py --model gpt-4 --api-key sk-xxx
指定自定义 API 地址和范围:
bash
python context_probe.py --model claude-3 --base-url https://api.anthropic.com/v1 --low 1000 --high 100000
使用字符数测试并忽略非上下文错误:
bash
python context_probe.py --model gemini-pro --unit chars --ignore-non-context-errors
5. 完整源代码
python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import time
import urllib.error
import urllib.request
def estimate_tokens(text: str, model: str | None = None) -> int | None:
"""
尽量估算 token 数。
如果没装 tiktoken,就返回 None,脚本仍然可以按长度测试。
"""
try:
import tiktoken
try:
enc = tiktoken.encoding_for_model(model or "")
except Exception:
enc = tiktoken.get_encoding("cl100k_base")
return len(enc.encode(text))
except Exception:
return None
def make_filler(size: int, unit: str) -> str:
"""
unit=tokens 时,使用 ' x' 重复。对很多 tokenizer 来说大致接近 1 token。
unit=chars 时,严格按字符数生成。
"""
if unit == "chars":
return "x" * size
# 近似 token 填充,不保证所有模型 tokenizer 都是 1:1
return " x" * size
def build_messages(size: int, unit: str):
filler = make_filler(size, unit)
return [
{
"role": "system",
"content": "You are a context length test model. Reply exactly: OK",
},
{
"role": "user",
"content": (
"这是一个纯上下文长度测试。不需要理解内容,不需要检索信息。"
"请只回复 OK。\n\n"
"<BEGIN_LENGTH_TEST>\n"
f"{filler}\n"
"<END_LENGTH_TEST>"
),
},
]
def call_chat_completions(
base_url: str,
api_key: str,
model: str,
messages: list[dict],
output_tokens: int,
timeout: int,
max_token_param: str,
):
url = base_url.rstrip("/") + "/chat/completions"
payload = {
"model": model,
"messages": messages,
"temperature": 0,
max_token_param: output_tokens,
}
data = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
url=url,
data=data,
method="POST",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
},
)
start = time.time()
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
body = resp.read().decode("utf-8", errors="replace")
latency = time.time() - start
return True, resp.status, body, latency
except urllib.error.HTTPError as e:
latency = time.time() - start
body = e.read().decode("utf-8", errors="replace")
return False, e.code, body, latency
except Exception as e:
latency = time.time() - start
return False, -1, repr(e), latency
def is_context_length_error(status: int, body: str) -> bool:
"""
判断失败是不是因为上下文太长。
不同服务商报错文本不一样,所以这里做宽松匹配。
"""
text = body.lower()
keywords = [
"context length",
"maximum context",
"max context",
"too many tokens",
"token limit",
"tokens exceed",
"exceeds the limit",
"request too large",
"payload too large",
"413",
]
if status == 413:
return True
return any(k in text for k in keywords)
def test_size(args, size: int):
messages = build_messages(size, args.unit)
full_prompt = "\n".join(m["content"] for m in messages)
est = estimate_tokens(full_prompt, args.model)
ok, status, body, latency = call_chat_completions(
base_url=args.base_url,
api_key=args.api_key,
model=args.model,
messages=messages,
output_tokens=args.output_tokens,
timeout=args.timeout,
max_token_param=args.max_token_param,
)
result = {
"size": size,
"unit": args.unit,
"chars": len(full_prompt),
"estimated_tokens": est,
"ok": ok,
"status": status,
"latency_sec": round(latency, 2),
}
if ok:
return True, result
context_err = is_context_length_error(status, body)
result["context_length_error"] = context_err
result["error_preview"] = body[:500].replace("\n", " ")
if not context_err and not args.ignore_non_context_errors:
print(json.dumps(result, ensure_ascii=False, indent=2))
raise RuntimeError(
"请求失败,但不像是上下文长度错误。"
"如果你想把所有失败都当作长度失败处理,加 --ignore-non-context-errors。"
)
return False, result
def main():
parser = argparse.ArgumentParser(
description="Probe model context length by sending increasingly long prompts. No needle test."
)
parser.add_argument("--model", required=True, help="模型名")
parser.add_argument(
"--base-url",
default=os.getenv("BASE_URL", "https://api.openai.com/v1"),
help="OpenAI-compatible API base URL,默认读取 BASE_URL 或 https://api.openai.com/v1",
)
parser.add_argument(
"--api-key",
default=os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY"),
help="API key,默认读取 API_KEY 或 OPENAI_API_KEY",
)
parser.add_argument(
"--unit",
choices=["tokens", "chars"],
default="tokens",
help="测试单位。tokens 是近似 token 填充;chars 是严格字符数。",
)
parser.add_argument("--low", type=int, default=1024, help="二分下界")
parser.add_argument("--high", type=int, default=200000, help="二分上界")
parser.add_argument("--output-tokens", type=int, default=1, help="输出 token 数,默认 1")
parser.add_argument("--timeout", type=int, default=120, help="单次请求超时秒数")
parser.add_argument(
"--max-token-param",
default="max_tokens",
choices=["max_tokens", "max_completion_tokens"],
help="不同接口使用的输出长度参数名",
)
parser.add_argument(
"--ignore-non-context-errors",
action="store_true",
help="把非上下文错误也当作失败继续二分",
)
args = parser.parse_args()
if not args.api_key:
raise SystemExit("缺少 API key。请设置 API_KEY / OPENAI_API_KEY,或传 --api-key。")
print("配置:")
print(f" model: {args.model}")
print(f" base_url: {args.base_url}")
print(f" unit: {args.unit}")
print(f" low: {args.low}")
print(f" high: {args.high}")
print()
best = None
left, right = args.low, args.high
while left <= right:
mid = (left + right) // 2
ok, result = test_size(args, mid)
status = "PASS" if ok else "FAIL"
print(json.dumps({"status": status, **result}, ensure_ascii=False))
if ok:
best = result
left = mid + 1
else:
right = mid - 1
print("\n最终结果:")
if best is None:
print("没有找到任何成功长度。请检查模型名、base-url、API key 或把 --low 调小。")
else:
print(json.dumps(best, ensure_ascii=False, indent=2))
print(
f"\n最大成功长度约为:{best['size']} {best['unit']},"
f"字符数 {best['chars']},"
f"估算 token 数 {best['estimated_tokens']}。"
)
if __name__ == "__main__":
main()