Evaluating LLMs on HLE (Humanity's Last Exam)

Below is a single-file, ready-to-run Python evaluation script. It works with any OpenAI-compatible API and supports:

  • Pulling the cais/hle dataset
  • Sending batched (concurrent) requests to the model
  • Saving prediction results
  • Optionally grading with a judge model
  • Reporting accuracy and calibration error

python
# hle_eval_single.py
# -*- coding: utf-8 -*-

import os
import re
import json
import math
import time
import copy
import argparse
import asyncio
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio


SYSTEM_PROMPT = """Your response should be in the following format:
Explanation: {your explanation for your answer choice}
Answer: {your chosen answer}
Confidence: {your confidence score between 0% and 100% for your answer}
"""

JUDGE_PROMPT = """Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below.

[question]:
{question}

[response]:
{response}

Your judgement must be a JSON object with the following keys:
- extracted_final_answer: The final exact answer extracted from [response]. Put "None" if there is no exact final answer.
- reasoning: Explain only whether the extracted_final_answer matches [correct_answer]. Do not solve the question again.
- correct: "yes" if extracted_final_answer matches [correct_answer], or is within a small margin of error for numerical problems. Otherwise "no".
- confidence: The extracted confidence score between 0 and 100 from [response]. Put 100 if there is no confidence score available.

[correct_answer]:
{correct_answer}
"""


def normalize_text(s: Optional[str]) -> str:
    if s is None:
        return ""
    s = str(s).strip().lower()
    s = s.replace("%", "%")  # normalize full-width percent signs to ASCII
    s = re.sub(r"\s+", " ", s)
    return s


def normalize_answer_for_match(s: Optional[str]) -> str:
    s = normalize_text(s)
    s = s.strip(" .,:;!?'\"()[]{}")
    return s


def extract_answer_and_confidence(response_text: str) -> Tuple[str, int]:
    """
    从模型输出中提取:
    - Answer: ...
    - Confidence: ...
    """
    answer = ""
    confidence = 100

    # Extract the Answer line
    m_answer = re.search(
        r"(?im)^\s*answer\s*:\s*(.+?)\s*$",
        response_text,
    )
    if m_answer:
        answer = m_answer.group(1).strip()
    else:
        # Fallback: use the last non-empty line
        lines = [x.strip() for x in response_text.splitlines() if x.strip()]
        if lines:
            answer = lines[-1]

    # Extract the Confidence value
    m_conf = re.search(
        r"(?im)^\s*confidence\s*:\s*([0-9]{1,3})(?:\s*%|\b)",
        response_text,
    )
    if m_conf:
        confidence = int(m_conf.group(1))
        confidence = max(0, min(100, confidence))
    else:
        # Last resort: search the whole text for a percentage
        m_conf2 = re.search(r"([0-9]{1,3})\s*%", response_text)
        if m_conf2:
            confidence = int(m_conf2.group(1))
            confidence = max(0, min(100, confidence))

    return answer, confidence


def try_mcq_match(pred_answer: str, gold_answer: str, sample: Dict[str, Any]) -> bool:
    """
    对多选题做一个较宽松的本地匹配:
    1) 直接和 gold answer 对比
    2) 如果 pred 是 A/B/C/D,而 gold 是选项全文,则尝试映射
    """
    pa = normalize_answer_for_match(pred_answer)
    ga = normalize_answer_for_match(gold_answer)

    if pa == ga:
        return True

    # Common letter formats, e.g. "A", "(a)", "a)"
    pa_letter = None
    m = re.match(r"^\(?([a-z])\)?(?:[.:)\- ]|$)", pa)
    if m:
        pa_letter = m.group(1).upper()
    elif len(pa) == 1 and pa.isalpha():
        pa_letter = pa.upper()

    # Try to read an options field from the sample
    options = None
    for key in ["options", "choices", "answer_choices", "candidate_answers"]:
        if key in sample and sample[key]:
            options = sample[key]
            break

    if options and pa_letter:
        # Supports list[str] or dict
        if isinstance(options, list):
            idx = ord(pa_letter) - ord("A")
            if 0 <= idx < len(options):
                opt_text = normalize_answer_for_match(str(options[idx]))
                if opt_text == ga:
                    return True
        elif isinstance(options, dict):
            # e.g. {"A": "...", "B": "..."}
            if pa_letter in options:
                opt_text = normalize_answer_for_match(str(options[pa_letter]))
                if opt_text == ga:
                    return True

    return False


def calib_err(confidence: np.ndarray, correct: np.ndarray, p: str = "2", beta: int = 100) -> float:
    """
    与官方 judge 脚本同风格的 calibration error 计算。
    """
    if len(confidence) == 0:
        return 0.0
    idxs = np.argsort(confidence)
    confidence = confidence[idxs]
    correct = correct[idxs]

    if len(confidence) < beta:
        beta = max(1, len(confidence))

    bins = [[i * beta, (i + 1) * beta] for i in range(max(1, len(confidence) // beta))]
    bins[-1] = [bins[-1][0], len(confidence)]

    cerr = 0.0
    total_examples = len(confidence)

    for i in range(len(bins)):
        start, end = bins[i]
        bin_confidence = confidence[start:end]
        bin_correct = correct[start:end]
        num_examples_in_bin = len(bin_confidence)

        if num_examples_in_bin > 0:
            difference = abs(np.nanmean(bin_confidence) - np.nanmean(bin_correct))
            if p == "2":
                cerr += num_examples_in_bin / total_examples * (difference ** 2)
            elif p == "1":
                cerr += num_examples_in_bin / total_examples * difference
            elif p in ("infty", "infinity", "max"):
                cerr = max(cerr, difference)
            else:
                raise ValueError("p must be '1', '2', or 'infty'")

    if p == "2":
        cerr = math.sqrt(cerr)
    return float(cerr)


class HLEEvaluator:
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.client = AsyncOpenAI(
            api_key=args.api_key or os.getenv("OPENAI_API_KEY"),
            base_url=args.base_url or os.getenv("OPENAI_BASE_URL"),
            timeout=args.timeout,
            max_retries=1,
        )
        self.judge_client = AsyncOpenAI(
            api_key=args.judge_api_key or args.api_key or os.getenv("OPENAI_API_KEY"),
            base_url=args.judge_base_url or args.base_url or os.getenv("OPENAI_BASE_URL"),
            timeout=args.timeout,
            max_retries=1,
        )

    def load_questions(self) -> List[Dict[str, Any]]:
        ds = load_dataset(
            self.args.dataset,
            split=self.args.split,
            token=self.args.hf_token or os.getenv("HF_TOKEN"),
        ).to_dict()
        questions = [dict(zip(ds.keys(), values)) for values in zip(*ds.values())]
        if self.args.max_samples:
            questions = questions[: self.args.max_samples]
        return questions

    def build_messages(self, q: Dict[str, Any]) -> List[Dict[str, Any]]:
        question_text = q["question"]

        text_content = {"type": "text", "text": question_text}
        content = [text_content]

        # In the official dataset, the image field may be an empty string
        image_url = q.get("image")
        if image_url:
            content.append({"type": "image_url", "image_url": {"url": image_url}})

        system_role = "user" if "o1" in self.args.model else "system"

        return [
            {"role": system_role, "content": SYSTEM_PROMPT},
            {"role": "user", "content": content},
        ]

    async def call_model_once(self, q: Dict[str, Any]) -> Optional[Tuple[str, Dict[str, Any]]]:
        try:
            response = await self.client.chat.completions.create(
                model=self.args.model,
                messages=self.build_messages(q),
                max_completion_tokens=self.args.max_completion_tokens,
                temperature=self.args.temperature if "o1" not in self.args.model else None,
                stream=False,
            )
            content = response.choices[0].message.content or ""
            usage = {}
            if getattr(response, "usage", None) is not None:
                try:
                    usage = response.usage.model_dump()  # pydantic v2 API; .json() is deprecated
                except Exception:
                    usage = {
                        "prompt_tokens": getattr(response.usage, "prompt_tokens", None),
                        "completion_tokens": getattr(response.usage, "completion_tokens", None),
                        "total_tokens": getattr(response.usage, "total_tokens", None),
                    }

            pred_answer, pred_conf = extract_answer_and_confidence(content)

            result = {
                "id": q["id"],
                "model": self.args.model,
                "response": content,
                "pred_answer": pred_answer,
                "pred_confidence": pred_conf,
                "usage": usage,
            }
            return q["id"], result
        except Exception as e:
            print(f"[ERROR] model call failed for {q.get('id')}: {e}")
            return None

    async def predict_all(self, questions: List[Dict[str, Any]]) -> Dict[str, Any]:
        save_path = self.args.predictions_out
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                predictions = json.load(f)
        else:
            predictions = {}

        pending = [q for q in questions if q["id"] not in predictions]

        sem = asyncio.Semaphore(self.args.num_workers)

        async def bound_call(q: Dict[str, Any]):
            async with sem:
                return await self.call_model_once(q)

        tasks = [bound_call(q) for q in pending]
        results = await tqdm_asyncio.gather(*tasks)

        for item in results:
            if item is None:
                continue
            qid, result = item
            predictions[qid] = result

        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(predictions, f, ensure_ascii=False, indent=2)

        return predictions

    async def judge_one_with_llm(self, q: Dict[str, Any], pred: Dict[str, Any]) -> Dict[str, Any]:
        prompt = JUDGE_PROMPT.format(
            question=q["question"],
            response=pred["response"],
            correct_answer=q["answer"],
        )

        try:
            response = await self.judge_client.chat.completions.create(
                model=self.args.judge_model,
                messages=[{"role": "user", "content": prompt}],
                max_completion_tokens=2048,
                temperature=0,
                stream=False,
            )
            text = response.choices[0].message.content or ""

            # Try to pull a JSON object out of the reply
            json_text = text.strip()
            m = re.search(r"\{.*\}", text, flags=re.S)
            if m:
                json_text = m.group(0)

            obj = json.loads(json_text)

            correct = str(obj.get("correct", "no")).strip().lower()
            confidence = obj.get("confidence", pred.get("pred_confidence", 100))
            try:
                confidence = int(confidence)
            except Exception:
                confidence = pred.get("pred_confidence", 100)
            confidence = max(0, min(100, confidence))

            return {
                "correct_answer": q["answer"],
                "model_answer": obj.get("extracted_final_answer", pred.get("pred_answer", "")),
                "reasoning": obj.get("reasoning", ""),
                "correct": "yes" if correct == "yes" else "no",
                "confidence": confidence,
            }
        except Exception as e:
            print(f"[ERROR] judge failed for {q.get('id')}: {e}")
            # On judge failure, fall back to local exact match
            local_correct = self.local_judge(q, pred)
            return {
                "correct_answer": q["answer"],
                "model_answer": pred.get("pred_answer", ""),
                "reasoning": "fallback local judge",
                "correct": "yes" if local_correct else "no",
                "confidence": pred.get("pred_confidence", 100),
            }

    def local_judge(self, q: Dict[str, Any], pred: Dict[str, Any]) -> bool:
        gold = q["answer"]
        pa = pred.get("pred_answer", "")

        # Try MCQ rules first
        if try_mcq_match(pa, gold, q):
            return True

        # Then plain normalized exact match
        return normalize_answer_for_match(pa) == normalize_answer_for_match(gold)

    async def judge_all(self, questions: List[Dict[str, Any]], predictions: Dict[str, Any]) -> Dict[str, Any]:
        save_path = self.args.judged_out
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                judged = json.load(f)
        else:
            judged = {}

        qmap = {q["id"]: q for q in questions}
        pending_ids = [qid for qid in predictions if qid in qmap and qid not in judged]

        if not self.args.judge_model:
            for qid in pending_ids:
                q = qmap[qid]
                pred = predictions[qid]
                correct = self.local_judge(q, pred)
                judged[qid] = copy.deepcopy(pred)
                judged[qid]["judge_response"] = {
                    "correct_answer": q["answer"],
                    "model_answer": pred.get("pred_answer", ""),
                    "reasoning": "local exact/MCQ match",
                    "correct": "yes" if correct else "no",
                    "confidence": pred.get("pred_confidence", 100),
                }

            with open(save_path, "w", encoding="utf-8") as f:
                json.dump(judged, f, ensure_ascii=False, indent=2)
            return judged

        sem = asyncio.Semaphore(self.args.num_workers)

        async def bound_judge(qid: str):
            async with sem:
                q = qmap[qid]
                pred = predictions[qid]
                jr = await self.judge_one_with_llm(q, pred)
                out = copy.deepcopy(pred)
                out["judge_response"] = jr
                return qid, out

        tasks = [bound_judge(qid) for qid in pending_ids]
        results = await tqdm_asyncio.gather(*tasks)

        for qid, item in results:
            judged[qid] = item

        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(judged, f, ensure_ascii=False, indent=2)

        return judged

    @staticmethod
    def dump_metrics(judged: Dict[str, Any], n_total: int) -> None:
        correct = []
        confidence = []

        for _, item in judged.items():
            jr = item.get("judge_response", {})
            correct.append(1 if jr.get("correct") == "yes" else 0)
            confidence.append(jr.get("confidence", 100) / 100.0)

        correct = np.array(correct, dtype=np.float32)
        confidence = np.array(confidence, dtype=np.float32)

        n_pred = len(correct)
        if n_pred == 0:
            print("No judged predictions found.")
            return

        accuracy = round(100.0 * float(correct.sum()) / float(n_total), 2)
        half_width = round(1.96 * math.sqrt(accuracy * (100 - accuracy) / max(1, n_total)), 2)
        cal_error = round(100.0 * calib_err(confidence, correct, p="2", beta=min(100, max(1, n_pred))), 2)

        print("\n*** Metrics ***")
        print(f"Available judged predictions: {n_pred} / total questions: {n_total}")
        print(f"Accuracy: {accuracy}% +/- {half_width}% | n = {n_total}")
        print(f"Calibration Error: {cal_error}")

    async def run(self) -> None:
        t0 = time.time()

        questions = self.load_questions()
        print(f"Loaded {len(questions)} questions from {self.args.dataset}:{self.args.split}")

        predictions = await self.predict_all(questions)
        print(f"Saved predictions -> {self.args.predictions_out}")

        judged = await self.judge_all(questions, predictions)
        print(f"Saved judged results -> {self.args.judged_out}")

        self.dump_metrics(judged, n_total=len(questions))

        print(f"\nDone in {time.time() - t0:.1f}s")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Single-file HLE evaluator")

    # Dataset
    parser.add_argument("--dataset", type=str, default="cais/hle")
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--hf_token", type=str, default=None)
    parser.add_argument("--max_samples", type=int, default=None)

    # Model
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--api_key", type=str, default=None)
    parser.add_argument("--base_url", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--max_completion_tokens", type=int, default=8192)
    parser.add_argument("--timeout", type=float, default=600.0)
    parser.add_argument("--num_workers", type=int, default=16)

    # Judge
    parser.add_argument("--judge_model", type=str, default=None,
                        help="例如 o3-mini-2025-01-31;不填则用本地 exact/MCQ match")
    parser.add_argument("--judge_api_key", type=str, default=None)
    parser.add_argument("--judge_base_url", type=str, default=None)

    # Outputs
    parser.add_argument("--predictions_out", type=str, default="hle_predictions.json")
    parser.add_argument("--judged_out", type=str, default="hle_judged.json")

    return parser.parse_args()


def main():
    args = parse_args()
    evaluator = HLEEvaluator(args)
    asyncio.run(evaluator.run())


if __name__ == "__main__":
    main()

Installation

bash
pip install datasets openai tqdm numpy

Usage examples

1) Predictions only + simple local grading
bash
export OPENAI_API_KEY=your_key
python hle_eval_single.py \
  --model gpt-4o-2024-11-20 \
  --max_samples 50 \
  --num_workers 8
2) Predictions + judge-model grading
bash
export OPENAI_API_KEY=your_key
python hle_eval_single.py \
  --model gpt-4o-2024-11-20 \
  --judge_model o3-mini-2025-01-31 \
  --max_samples 50 \
  --num_workers 8
3) Against a local vLLM / OpenAI-compatible server
bash
python hle_eval_single.py \
  --model your-model-name \
  --base_url http://127.0.0.1:8000/v1 \
  --api_key EMPTY \
  --judge_model your-judge-model-name \
  --judge_base_url http://127.0.0.1:8000/v1 \
  --judge_api_key EMPTY \
  --max_samples 50
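
Before a full run, the Answer/Confidence parsing can be checked in isolation. A minimal sketch that reuses the same regular expressions as `extract_answer_and_confidence` above:

```python
import re

# sample reply in the format requested by SYSTEM_PROMPT
sample = """Explanation: 2 + 2 = 4 by the definition of addition.
Answer: 4
Confidence: 95%"""

# the same regexes used by extract_answer_and_confidence
m_ans = re.search(r"(?im)^\s*answer\s*:\s*(.+?)\s*$", sample)
m_conf = re.search(r"(?im)^\s*confidence\s*:\s*([0-9]{1,3})(?:\s*%|\b)", sample)
print(m_ans.group(1), m_conf.group(1))  # 4 95
```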

Notes

This script is a "single-file" reorganization of the official evaluation pipeline:

  • The prediction prompt follows the official format

  • Supports the question / image / answer / id fields

  • Grading prefers the LLM judge; if no judge model is configured, it falls back to local exact match / multiple-choice matching

  • Produces two output files:

    • hle_predictions.json
    • hle_judged.json
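
As a quick sanity check on hle_judged.json, accuracy can be recomputed directly from the saved layout (a minimal sketch; `accuracy_from_judged` is an illustrative helper, not part of the script):

```python
# recompute accuracy from the judged-results layout the script saves
def accuracy_from_judged(judged: dict) -> float:
    if not judged:
        return 0.0
    yes = sum(1 for v in judged.values()
              if v.get("judge_response", {}).get("correct") == "yes")
    return yes / len(judged)

# tiny inline example mirroring the saved structure
example = {
    "q1": {"judge_response": {"correct": "yes", "confidence": 90}},
    "q2": {"judge_response": {"correct": "no", "confidence": 60}},
}
print(accuracy_from_judged(example))  # 0.5
```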

For a more rigorous, formal reproduction, it is still recommended to run the two-step scripts from the official repository; the official README provides the standard commands. ([GitHub][1])
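
For intuition on the reported calibration error: it is the bin-weighted RMS gap between mean confidence and mean accuracy, with confidence rescaled to [0, 1]. A standalone toy sketch of the same computation, condensed from `calib_err` above (p="2" only):

```python
import math

import numpy as np

# condensed version of calib_err above; beta = examples per bin
def calib_err(confidence: np.ndarray, correct: np.ndarray, beta: int = 2) -> float:
    idxs = np.argsort(confidence)
    confidence, correct = confidence[idxs], correct[idxs]
    bins = [[i * beta, (i + 1) * beta] for i in range(max(1, len(confidence) // beta))]
    bins[-1][1] = len(confidence)  # last bin absorbs the remainder
    cerr = 0.0
    for start, end in bins:
        diff = abs(np.nanmean(confidence[start:end]) - np.nanmean(correct[start:end]))
        cerr += (end - start) / len(confidence) * diff ** 2
    return math.sqrt(cerr)

conf = np.array([0.9, 0.9, 0.6, 0.6], dtype=np.float32)
corr = np.array([1.0, 1.0, 1.0, 0.0], dtype=np.float32)
# each 2-example bin is off by 0.1, so the weighted RMS error is 0.1
print(calib_err(conf, corr))  # ~0.1
```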

References:

[1]: https://github.com/centerforaisafety/hle "GitHub - centerforaisafety/hle: Humanity's Last Exam"
