模型上下文长度测试工具

模型上下文长度测试工具

本文介绍一个用于探测大语言模型上下文长度的测试脚本。该脚本通过二分法发送递增长度的提示词,自动找出模型能处理的最大上下文长度。

1. 功能概述

本脚本的核心功能包括:

  • 二分探测:在指定范围内通过二分查找,自动定位模型的最大上下文长度
  • 双单位支持:支持以 token 近似填充或严格字符数两种方式构造测试提示词
  • 兼容性强:适配任何 OpenAI 兼容接口的 API
  • 智能错误判断:自动识别上下文长度超限错误,区分其他类型的请求失败
  • 结果输出:以 JSON 格式输出每次测试结果,最终汇总最大成功长度

2. 核心函数说明

2.1 Token 估算函数

estimate_tokens 函数用于估算文本的 token 数量。它优先使用 tiktoken 库进行精确估算,若未安装则返回 None,脚本仍可继续运行。

2.2 填充内容生成

make_filler 函数根据指定的单位生成填充文本:

  • tokens 模式:使用 " x" 重复填充,对多数 tokenizer 接近 1 token
  • chars 模式:严格按字符数生成

2.3 消息构造

build_messages 函数构造测试用的消息列表,包含 system 和 user 两条消息,其中 user 消息内嵌填充内容用于测试。

2.4 API 调用

call_chat_completions 函数负责发送 HTTP 请求到 OpenAI 兼容接口,并记录请求延迟。

2.5 错误判断

is_context_length_error 函数通过关键词匹配判断失败是否由上下文过长导致,支持多种服务商的错误文本格式。

2.6 测试执行

test_size 函数封装单次测试流程,包括消息构建、API 调用、结果记录和错误判断。

3. 命令行参数

脚本支持丰富的命令行参数,方便灵活配置:

参数 说明
--model 模型名(必填)
--base-url API 地址,默认读取环境变量 BASE_URLhttps://api.openai.com/v1
--api-key API 密钥,默认读取 API_KEYOPENAI_API_KEY
--unit 测试单位:tokens(近似 token)或 chars(严格字符数)
--low 二分下界,默认 1024
--high 二分上界,默认 200000
--output-tokens 输出 token 数,默认 1
--timeout 单次请求超时秒数,默认 120
--max-token-param 输出长度参数名:max_tokensmax_completion_tokens
--ignore-non-context-errors 将非上下文错误也视为失败继续二分

4. 使用示例

基本用法:

bash 复制代码
python context_probe.py --model gpt-4 --api-key sk-xxx

指定自定义 API 地址和范围:

bash 复制代码
python context_probe.py --model claude-3 --base-url https://api.anthropic.com/v1 --low 1000 --high 100000

使用字符数测试并忽略非上下文错误:

bash 复制代码
python context_probe.py --model gemini-pro --unit chars --ignore-non-context-errors

5. 完整源代码

python 复制代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import os
import time
import urllib.error
import urllib.request


def estimate_tokens(text: str, model: str | None = None) -> int | None:
    """
    尽量估算 token 数。
    如果没装 tiktoken,就返回 None,脚本仍然可以按长度测试。
    """
    try:
        import tiktoken

        try:
            enc = tiktoken.encoding_for_model(model or "")
        except Exception:
            enc = tiktoken.get_encoding("cl100k_base")

        return len(enc.encode(text))
    except Exception:
        return None


def make_filler(size: int, unit: str) -> str:
    """
    unit=tokens 时,使用 ' x' 重复。对很多 tokenizer 来说大致接近 1 token。
    unit=chars 时,严格按字符数生成。
    """
    if unit == "chars":
        return "x" * size

    # 近似 token 填充,不保证所有模型 tokenizer 都是 1:1
    return " x" * size


def build_messages(size: int, unit: str):
    filler = make_filler(size, unit)

    return [
        {
            "role": "system",
            "content": "You are a context length test model. Reply exactly: OK",
        },
        {
            "role": "user",
            "content": (
                "这是一个纯上下文长度测试。不需要理解内容,不需要检索信息。"
                "请只回复 OK。\n\n"
                "<BEGIN_LENGTH_TEST>\n"
                f"{filler}\n"
                "<END_LENGTH_TEST>"
            ),
        },
    ]


def call_chat_completions(
    base_url: str,
    api_key: str,
    model: str,
    messages: list[dict],
    output_tokens: int,
    timeout: int,
    max_token_param: str,
):
    url = base_url.rstrip("/") + "/chat/completions"

    payload = {
        "model": model,
        "messages": messages,
        "temperature": 0,
        max_token_param: output_tokens,
    }

    data = json.dumps(payload).encode("utf-8")

    req = urllib.request.Request(
        url=url,
        data=data,
        method="POST",
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )

    start = time.time()

    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            body = resp.read().decode("utf-8", errors="replace")
            latency = time.time() - start
            return True, resp.status, body, latency
    except urllib.error.HTTPError as e:
        latency = time.time() - start
        body = e.read().decode("utf-8", errors="replace")
        return False, e.code, body, latency
    except Exception as e:
        latency = time.time() - start
        return False, -1, repr(e), latency


def is_context_length_error(status: int, body: str) -> bool:
    """
    判断失败是不是因为上下文太长。
    不同服务商报错文本不一样,所以这里做宽松匹配。
    """
    text = body.lower()

    keywords = [
        "context length",
        "maximum context",
        "max context",
        "too many tokens",
        "token limit",
        "tokens exceed",
        "exceeds the limit",
        "request too large",
        "payload too large",
        "413",
    ]

    if status == 413:
        return True

    return any(k in text for k in keywords)


def test_size(args, size: int):
    messages = build_messages(size, args.unit)
    full_prompt = "\n".join(m["content"] for m in messages)
    est = estimate_tokens(full_prompt, args.model)

    ok, status, body, latency = call_chat_completions(
        base_url=args.base_url,
        api_key=args.api_key,
        model=args.model,
        messages=messages,
        output_tokens=args.output_tokens,
        timeout=args.timeout,
        max_token_param=args.max_token_param,
    )

    result = {
        "size": size,
        "unit": args.unit,
        "chars": len(full_prompt),
        "estimated_tokens": est,
        "ok": ok,
        "status": status,
        "latency_sec": round(latency, 2),
    }

    if ok:
        return True, result

    context_err = is_context_length_error(status, body)
    result["context_length_error"] = context_err
    result["error_preview"] = body[:500].replace("\n", " ")

    if not context_err and not args.ignore_non_context_errors:
        print(json.dumps(result, ensure_ascii=False, indent=2))
        raise RuntimeError(
            "请求失败,但不像是上下文长度错误。"
            "如果你想把所有失败都当作长度失败处理,加 --ignore-non-context-errors。"
        )

    return False, result


def main():
    parser = argparse.ArgumentParser(
        description="Probe model context length by sending increasingly long prompts. No needle test."
    )

    parser.add_argument("--model", required=True, help="模型名")
    parser.add_argument(
        "--base-url",
        default=os.getenv("BASE_URL", "https://api.openai.com/v1"),
        help="OpenAI-compatible API base URL,默认读取 BASE_URL 或 https://api.openai.com/v1",
    )
    parser.add_argument(
        "--api-key",
        default=os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY"),
        help="API key,默认读取 API_KEY 或 OPENAI_API_KEY",
    )
    parser.add_argument(
        "--unit",
        choices=["tokens", "chars"],
        default="tokens",
        help="测试单位。tokens 是近似 token 填充;chars 是严格字符数。",
    )
    parser.add_argument("--low", type=int, default=1024, help="二分下界")
    parser.add_argument("--high", type=int, default=200000, help="二分上界")
    parser.add_argument("--output-tokens", type=int, default=1, help="输出 token 数,默认 1")
    parser.add_argument("--timeout", type=int, default=120, help="单次请求超时秒数")
    parser.add_argument(
        "--max-token-param",
        default="max_tokens",
        choices=["max_tokens", "max_completion_tokens"],
        help="不同接口使用的输出长度参数名",
    )
    parser.add_argument(
        "--ignore-non-context-errors",
        action="store_true",
        help="把非上下文错误也当作失败继续二分",
    )

    args = parser.parse_args()

    if not args.api_key:
        raise SystemExit("缺少 API key。请设置 API_KEY / OPENAI_API_KEY,或传 --api-key。")

    print("配置:")
    print(f"  model: {args.model}")
    print(f"  base_url: {args.base_url}")
    print(f"  unit: {args.unit}")
    print(f"  low: {args.low}")
    print(f"  high: {args.high}")
    print()

    best = None
    left, right = args.low, args.high

    while left <= right:
        mid = (left + right) // 2

        ok, result = test_size(args, mid)

        status = "PASS" if ok else "FAIL"
        print(json.dumps({"status": status, **result}, ensure_ascii=False))

        if ok:
            best = result
            left = mid + 1
        else:
            right = mid - 1

    print("\n最终结果:")
    if best is None:
        print("没有找到任何成功长度。请检查模型名、base-url、API key 或把 --low 调小。")
    else:
        print(json.dumps(best, ensure_ascii=False, indent=2))
        print(
            f"\n最大成功长度约为:{best['size']} {best['unit']},"
            f"字符数 {best['chars']},"
            f"估算 token 数 {best['estimated_tokens']}。"
        )


if __name__ == "__main__":
    main()
相关推荐
ZzT2 小时前
各大 AI 的系统提示词被扒光了,我从里面学到了写指令的功夫
ai编程·claude
武子康3 小时前
调查研究-212 智谱 ZCode Harness for GLM-5.2:国产 Coding Agent 从“模型能力“走向“工程执行环境“
大数据·人工智能·深度学习·llm·claude·glm·智谱
AlfredZhao5 小时前
AI 编程变更记录:知识加工模块与博客工厂模块的状态重新定义
codex·ai_coding
老程序猿5 小时前
一个撇号里,藏得下 3 个 bit——system prompt 隐写手法拆解
ai编程·claude
L3S6 小时前
你的 Agent 为什么总失忆?—— Memory 设计从入门到 Claude Code
agent·claude
Awu12271 天前
⚡从零开发 Agent CLI(五)实现一个可治理、可扩展的工具系统
前端·人工智能·claude
wok1571 天前
Claude Code 自动更新权限问题解决
claude
Java陈序员1 天前
一站式本地监控!一款开源的 Token 用量监控分析工具!
ai编程·claude·cursor
小碗细面1 天前
让 AI Agent 真正读懂你的资料:我开源了 source-skill-pipeline
aigc·ai编程·claude
uccs1 天前
AI Agent 系统的容错设计实践
agent·ai编程·claude