Evaluating the Wanyv-50M model on C-Eval with xiaothink

Install xiaothink from PyPI:

bash
pip install xiaothink==1.0.2

Download the model:
Wanyv-50M (万语-50M)
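
Before launching the full evaluation, it can help to load the checkpoint once and run a single-turn chat as a smoke test. This is a minimal sketch that reuses the same QianyanModel arguments as the evaluation script below; the checkpoint path is a placeholder, and the example question is just an arbitrary prompt of my own choosing.

python
from xiaothink.llm.inference.test_formal import QianyanModel

# Load the Wanyv-50M checkpoint (same MT value and ckpt_dir arguments as in the evaluation script below)
model = QianyanModel(MT=40.231,
                     ckpt_dir=r'path\to\wanyv\model\ckpt_test_40_2_3_1_formal_open')

# One single-turn query; generation stops at the first '。' (placeholder question)
print(model.chat_SingleTurn('什么是机器学习?', temp=0.3, loop=True, stop='。'))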


Start the evaluation (after updating the model path, the script below can be run directly; results are saved under the output folder):

python
import os
import json
import pandas as pd
import re
from tqdm import tqdm
import random
import time
import requests
from xiaothink.llm.inference.test_formal import *
model = QianyanModel(MT=40.231,
                     ckpt_dir=r'path\to\wanyv\model\ckpt_test_40_2_3_1_formal_open')


def chat_x(inp, temp=0.3):
    # Single-turn chat; generation stops at the first '。'
    return model.chat_SingleTurn(inp, temp=temp, loop=True, stop='。')


from collections import Counter

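# pre(): build a prompt from the question and options, query the model, extract the
# answer letter (majority vote over the collected responses), and log each
# prompt-template/response pair to ceval_text_sklm.txt.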
def pre(question: str, options_str: str) -> str:
    # Strip any stray "答案:" markers from the inputs
    question = question.replace('答案:', '')
    options_str = options_str.replace('答案:', '')

    # Include the options in the prompt only when they are not already embedded in the question text
    if 'A' not in question:
        prompt_template = '''题目:{question}\n{options_str}\n让我们首先一步步思考,最后在回答末尾给出一个字母作为你的答案(A或B或C或D)'''
        prompt_template2 = '''题目:{question}\n选项:{options_str}\n给出答案'''
        prompt_template3 = '''{question}\n{options_str}\n'''
        prompt_template4 = '''{question}\n{options_str}\n给出你的选择'''
        prompt_template5 = '''题目:{question}\n{options_str}\n答案:'''
    else:
        prompt_template = '''题目:{question}\n让我们首先一步步思考,最后在回答末尾给出一个字母作为你的答案(A或B或C或D)'''
        prompt_template2 = '''题目:{question}\n给出答案'''
        prompt_template3 = '''{question}\n'''
        prompt_template4 = '''{question}\n给出你的选择'''
        prompt_template5 = '''题目:{question}\n答案:'''

    ansd = {}  # maps each extracted option letter to the raw model response
    # Query the model and collect answers (one pass here; raise the range for majority voting)
    answers = []
    for _ in range(1):
        response = chat_x(prompt_template.format(question=question, options_str=options_str))
        #print(response)
        # Extract answer from response
        for option in 'ABCD':
            if option in response:
                answers.append(option)
                ansd[option]=response
                break
        else:
            print('AI选项检查:', repr(response))
            answers.append('A')  # Default to 'A' if no option found
            ansd['A']=''
    # Count occurrences of each answer
    answer_counts = Counter(answers)

    # Find the most common answer(s)
    most_common_answers = answer_counts.most_common()
    highest_frequency = most_common_answers[0][1]
    most_frequent_answers = [answer for answer, count in most_common_answers if count == highest_frequency]

    # Choose one of the most frequent answers (if there's a tie, choose the first alphabetically)
    final_answer = min(most_frequent_answers)

    # Log one instruction/response pair per prompt template as a JSON line
    templates = [prompt_template, prompt_template2, prompt_template3,
                 prompt_template4, prompt_template5]
    with open('ceval_text_sklm.txt', 'a', encoding='utf-8') as f:
        for tpl in templates:
            record = {"instruction": tpl.format(question=question, options_str=options_str),
                      "input": "",
                      "output": ansd[final_answer]}
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
        
    return final_answer

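# Per-subject evaluator: formats each question, queries the model through pre(),
# and tracks per-question correctness.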
class Llama_Evaluator:
    def __init__(self, choices, k):
        self.choices = choices
        self.k = k

    def eval_subject(self, subject_name,
                     test_df,
                     dev_df=None,
                     few_shot=False,
                     cot=False,
                     save_result_dir=None,
                     with_prompt=False,
                     constrained_decoding=False,
                     do_test=False):
        all_answers = {}
        correct_num = 0
        if save_result_dir:
            result = []
            score = []
        if few_shot:
            history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
        else:
            history = ''
        answers = ['NA'] * len(test_df) if do_test is True else list(test_df['answer'])
        for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
            question = self.format_example(row, include_answer=False, cot=cot, with_prompt=with_prompt)
            options_str = self.format_options(row)
            instruction = history + question + "\n选项:" + options_str
            ans = pre(instruction, options_str)

            if ans == answers[row_index]:
                correct_num += 1
                correct = 1
            else:
                correct = 0
            print(f"\n=======begin {str(row_index)}=======")
            print("question: ", question)
            print("options: ", options_str)
            print("ans: ", ans)
            print("ground truth: ", answers[row_index], "\n")
            if save_result_dir:
                result.append(ans)
                score.append(correct)
            print(f"=======end {str(row_index)}=======")

            all_answers[str(row_index)] = ans

        correct_ratio = 100 * correct_num / len(answers)

        if save_result_dir:
            test_df['model_output'] = result
            test_df['correctness'] = score
            test_df.to_csv(os.path.join(save_result_dir, f'{subject_name}_test.csv'))

        return correct_ratio, all_answers

    def format_example(self, line, include_answer=True, cot=False, with_prompt=False):
        example = line['question']
        for choice in self.choices:
            example += f'\n{choice}. {line[f"{choice}"]}'
        if include_answer:
            if cot:
                example += "\n答案:让我们一步一步思考,\n" + \
                           line["explanation"] + f"\n所以答案是{line['answer']}。\n\n"
            else:
                example += '\n答案:' + line["answer"] + '\n\n'
        else:
            if with_prompt is False:
                if cot:
                    example += "\n答案:让我们一步一步思考,\n1."
                else:
                    example += '\n答案:'
            else:
                if cot:
                    example += "\n答案是什么?让我们一步一步思考,\n1."
                else:
                    example += '\n答案是什么? '
        return example

    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
        prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"
        k = self.k
        if self.k == -1:
            k = dev_df.shape[0]
        for i in range(k):
            prompt += self.format_example(
                dev_df.iloc[i, :],
                include_answer=True,
                cot=cot
            )
        return prompt

    def format_options(self, line):
        options_str = ""
        for choice in self.choices:
            options_str += f"{choice}: {line[f'{choice}']} "
        return options_str

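# main(): walk every subject listed under data/val, evaluate it, and write
# submission.json plus summary.json (per-subject and grouped scores) under output/take{take}.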
def main(model_path, output_dir, take, few_shot=False, cot=False, with_prompt=False, constrained_decoding=False, do_test=False, n_times=1, do_save_csv=False):
    assert os.path.exists("subject_mapping.json"), "subject_mapping.json not found!"
    with open("subject_mapping.json") as f:
        subject_mapping = json.load(f)
    filenames = os.listdir("data/val")
    subject_list = [val_file.replace("_val.csv", "") for val_file in filenames]
    accuracy, summary = {}, {}

    run_date = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))
    save_result_dir = os.path.join(output_dir, f"take{take}")
    if not os.path.exists(save_result_dir):
        os.makedirs(save_result_dir, exist_ok=True)

    choices = ["A", "B", "C", "D"]  # C-Eval option letters
    evaluator = Llama_Evaluator(choices=choices, k=n_times)

    all_answers = {}
    for index, subject_name in tqdm(list(enumerate(subject_list)),desc='主进度'):
        print(f"{index / len(subject_list)} Inference starts at {run_date} on {model_path} with subject of {subject_name}!")
        val_file_path = os.path.join('data/val', f'{subject_name}_val.csv')
        dev_file_path = os.path.join('data/dev', f'{subject_name}_dev.csv')
        test_file_path = os.path.join('data/test', f'{subject_name}_test.csv')

        val_df = pd.read_csv(val_file_path) if not do_test else pd.read_csv(test_file_path)
        dev_df = pd.read_csv(dev_file_path) if few_shot else None

        correct_ratio, answers = evaluator.eval_subject(subject_name, val_df, dev_df,
                                                       save_result_dir=save_result_dir if do_save_csv else None,
                                                       few_shot=few_shot,
                                                       cot=cot,
                                                       with_prompt=with_prompt,
                                                       constrained_decoding=constrained_decoding,
                                                       do_test=do_test)
        print(f"Subject: {subject_name}")
        print(f"Acc: {correct_ratio}")
        accuracy[subject_name] = correct_ratio
        summary[subject_name] = {"score": correct_ratio,
                                 "num": len(val_df),
                                 "correct": correct_ratio * len(val_df) / 100}
        all_answers[subject_name] = answers

    with open(os.path.join(save_result_dir, 'submission.json'), 'w', encoding='utf-8') as f:
        json.dump(all_answers, f, ensure_ascii=False, indent=4)
    print("Accuracy:")
    for k, v in accuracy.items():
        print(k, ": ", v)

    total_num = 0
    total_correct = 0
    summary['grouped'] = {
        "STEM": {"correct": 0.0, "num": 0},
        "Social Science": {"correct": 0.0, "num": 0},
        "Humanities": {"correct": 0.0, "num": 0},
        "Other": {"correct": 0.0, "num": 0}
    }
    for subj, info in subject_mapping.items():
        group = info[2]
        summary['grouped'][group]["num"] += summary[subj]['num']
        summary['grouped'][group]["correct"] += summary[subj]['correct']
    for group, info in summary['grouped'].items():
        info['score'] = info["correct"] / info["num"]
        total_num += info["num"]
        total_correct += info["correct"]
    summary['All'] = {"score": total_correct / total_num, "num": total_num, "correct": total_correct}

    with open(os.path.join(save_result_dir, 'summary.json'), 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

# Example usage
if __name__ == "__main__":
    model_path = "path/to/model"
    output_dir = "output"
    take = 0
    few_shot = False
    cot = False
    with_prompt = False
    constrained_decoding = False
    do_test = True  # True: answer the unlabeled test split (for submission); False: score the val split
    n_times = 1
    do_save_csv = False

    main(model_path, output_dir, take, few_shot, cot, with_prompt, constrained_decoding, do_test, n_times, do_save_csv)
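
The script expects the official C-Eval data unpacked in the working directory (data/val, data/dev, data/test, plus subject_mapping.json). Once a run finishes, the aggregated scores can be read back from summary.json; the short sketch below assumes the defaults from the example usage above (output_dir = "output", take = 0), so adjust the path if you changed either.

python
import json

# Load the aggregated results written by main() (assumes output/take0 from the defaults above)
with open('output/take0/summary.json', encoding='utf-8') as f:
    summary = json.load(f)

print(summary['All'])       # overall score, question count, total correct
print(summary['grouped'])   # STEM / Social Science / Humanities / Other breakdown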