1.2 Kaggle大白话:Eedi竞赛Transformer框架解决方案02-GPT_4o生成训练集缺失数据

目录

    • [0. 本栏目竞赛汇总表](#0. 本栏目竞赛汇总表)
    • [1. 本文主旨](#1. 本文主旨)
    • [2. AI工程架构](#2. AI工程架构)
    • [3. 数据预处理模块](#3. 数据预处理模块)
      • [3.1 配置数据路径和处理参数](#3.1 配置数据路径和处理参数)
      • [3.2 配置API参数](#3.2 配置API参数)
      • [3.3 配置输出路径](#3.3 配置输出路径)
    • [4. AI并行处理模块](#4. AI并行处理模块)
      • [4.1 定义LLM客户端类](#4.1 定义LLM客户端类)
      • [4.2 定义数据处理函数](#4.2 定义数据处理函数)
      • [4.3 定义JSON保存函数](#4.3 定义JSON保存函数)
      • [4.4 定义数据分片函数](#4.4 定义数据分片函数)
      • [4.5 定义分片处理函数](#4.5 定义分片处理函数)
      • [4.5 定义文件名排序函数](#4.5 定义文件名排序函数)
    • [5. 数据整合模块](#5. 数据整合模块)
      • [5.1 加载数据并生成分片](#5.1 加载数据并生成分片)
      • [5.2 初始化LLM客户端并测试](#5.2 初始化LLM客户端并测试)
      • [5.3 并行处理数据生成](#5.3 并行处理数据生成)
      • [5.4 合并处理结果](#5.4 合并处理结果)
      • [5.5 保存最终结果](#5.5 保存最终结果)

0. 本栏目竞赛汇总表

Kaggle竞赛汇总

1. 本文主旨

  • 大白话:由于在上一篇文章的数据探索中,我们发现了部分训练数据的错误解释存在缺失,因此直接使用GPT_4o+人设提示词工程,对训练集数据存在的错误解释缺失问题的处理。
  • 通过本文可收获技能:API调用AI接口、人设提示词工程案例、复杂的数据处理与缓存处理。
  • 上文回顾Eedi大模型蒸馏方案01-竞赛信息解读与数据理解

2. AI工程架构

数据整合模块 初始化客户端 加载数据 并行处理生成 合并结果 保存CSV AI并行处理模块 定义数据处理函数 定义LLM客户端 定义JSON保存函数 定义分片函数 定义排序函数 数据预处理模块 配置路径和参数 导入依赖库 配置API和输出

3. 数据预处理模块

3.1 配置数据路径和处理参数

python 复制代码
data_path = "~/work/eedi_synthetic_data/MalAlgoQA_format.csv"
index_start = 0
index_end = len(df)
step = 100
max_workers = 2

3.2 配置API参数

python 复制代码
model_config = dict(
    openai_api_base = "https://testshellapi.kimi.asia/v1", 
    api_key = "****",
    model = "gpt-4o",
    default_system_prompt = """
        ##Task
        You are a Mathematics teacher. Your task is to reason and identify the ConstructName and SubjectName and then the misconception behind the user input Incorrect Answers with the Question.
        ConstructName is Most granular level of knowledge related to question, appears to describe the specific mathematical method or procedure used to solve the question. It explains the technique or approach needed to reach the answer.
        SubjectName is More general context than the construct, represents the broader mathematical topic or category that the question belongs to.
        Misconceptions are a mistake in conceptual understanding and they have relations with all the applications of those concepts. For example, a single misconception on the connections among proportional relationships (part/whole, part/part, whole/part) can cause problems in identifying those patterns in drawings and can be the cause of failing to realize all parts must be of equal size, therefore associating the denominator of the fraction with the total number of parts regardless their size.
        Answer concisely what misconception it is to lead to getting the incorrect answer.
        Do not use "The misconception is" to start your answers.
        Do not mention the concrete details of the question or answers. 

        ##User input
        Question: The question text
        A: multiple choice answer A text
        B: multiple choice answer B text
        C: multiple choice answer C text
        D: multiple choice answer D text
        Correct Answer: The correct answer text

        ##You should answer in the following JSON format
        {
            "ConstructName": "here writes the constructName",
            "SubjectName": "here writes the SubjectName"
            "MisconceptionAName": "here writes the answer A's misconception.",
            "MisconceptionBName": "here writes the answer B's misconception.",
            "MisconceptionCName": "here writes the answer C's misconception.",
            "MisconceptionDName": "here writes the answer D's misconception.",
        }
        """, # system prompt,
    default_temperature = 0.5,
    max_tokens = 256,
)

3.3 配置输出路径

python 复制代码
cache_folder = f"./cache_{model_config['model']}_model_misconceptions_result"
if not os.path.exists(cache_folder):
    os.makedirs(cache_folder)
output_data_path = f"misconception_data_{os.path.splitext(os.path.basename(data_path))[0]}_{model_config['model']}.csv"

4. AI并行处理模块

4.1 定义LLM客户端类

python 复制代码
class LLMChat:
    def __init__(self, openai_api_base, api_key, model, default_temperature, default_system_prompt, max_tokens=512):
        self.client = OpenAI(
            api_key = api_key,
            base_url=openai_api_base,
        )
        self.model = model
        self.default_temperature = default_temperature
        self.default_system_prompt = default_system_prompt
        self.max_tokens = max_tokens
    
    def chat(self, user_prompt, system_prompt=None, temperature=None):
        if not system_prompt:
            system_prompt = self.default_system_prompt
            
        if not temperature:
            temperature = self.default_temperature

        chat_response = self.client.chat.completions.create(
            model=self.model,
            temperature=temperature,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=self.max_tokens,
            response_format={"type": "json_object"}
        )
        return chat_response.choices[0].message.content

4.2 定义数据处理函数

python 复制代码
def process_row(args, debug=False):
    user_prompt = """
    Question: {question}
    A: {answer_a}
    B: {answer_b}
    C: {answer_c}
    D: {answer_d}
    Correct Answer: {correct_answer}
    """
    index, row = args
    ca = row["CorrectAnswer"]
    correctanswer = row[f"Answer{ca}Text"]
    input_user_prompt = user_prompt.format(
        question=row['QuestionText'],
        answer_a=row['AnswerAText'],
        answer_b=row['AnswerBText'],
        answer_c=row['AnswerCText'],
        answer_d=row['AnswerDText'],
        correct_answer=correctanswer,
    )
    ret_data = {}
    try:
        ret_data = vc.chat(input_user_prompt)
        if debug:
            print(ret_data+'\n')
    except Exception as e:
        print(f'An exception occur {str(e)}')
        ret_data['error'] = str(e)
        pass
    if debug:
        print('system: ', model_config['default_system_prompt'])
        print('>'* 50)
        print('user_input: ', input_user_prompt)
        print('>'* 50)
        print('assistant: ', ret_data)
    return ret_data

4.3 定义JSON保存函数

python 复制代码
def save_json(fn, obj):
    with open(fn, 'w') as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)
    print(f"save file to {fn}")

4.4 定义数据分片函数

python 复制代码
def slice_range(start, end, step):
    if step <= 0:
        raise ValueError("步长必须大于0")
    
    result = []
    while start <= end:
        result.append(start)
        start += step
    if result[-1] < end:
        result.append(end)
    return result

4.5 定义分片处理函数

python 复制代码
def process_pairs(sliced_range):
    slices = []
    for first, second in zip(sliced_range, sliced_range[1:]):
        slices.append([first, second])
    return slices

4.5 定义文件名排序函数

python 复制代码
def natural_sort_key(filename):
    parts = re.findall(r'\d+', filename)
    return tuple(map(int, parts))

5. 数据整合模块

5.1 加载数据并生成分片

python 复制代码
df = pd.read_csv(data_path)
df.head()
sliced_range = process_pairs(slice_range(index_start, index_end, step))

df数据检查:

5.2 初始化LLM客户端并测试

python 复制代码
vc = LLMChat(**model_config)
r = process_row((7, df.iloc[7]), debug=True)

5.3 并行处理数据生成

python 复制代码
for slices in tqdm(sliced_range, total=len(sliced_range)):
    output_filepath = f'{cache_folder}/cache_res_{slices[0]}.json'
    if os.path.exists(output_filepath):
        print(f'cache file exists, skip {output_filepath}')
        continue
    df_tasks = df.iloc[slices[0]:slices[1]]
    results = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(process_row, df_tasks.iterrows()), total=len(df_tasks)))
    save_json(output_filepath, results)

5.4 合并处理结果

python 复制代码
f_names = glob.glob(f'{cache_folder}/*.json')
sorted_filenames = sorted(f_names, key=natural_sort_key)
f_names = sorted_filenames

results = []
for fn in f_names:
    with open(fn, 'r') as f:
        batch_results = json.load(f)
    results.extend(batch_results)

l = len(results)
results = [json.loads(r) for r in results]

5.5 保存最终结果

python 复制代码
df = df.iloc[:l]
gen_df = pd.DataFrame(results)
df = pd.concat([df, gen_df], axis=1)
df.to_csv(output_data_path, index=False)

(To be continued)

相关推荐
ZKNOW甄知科技2 分钟前
数智同行:甄知科技2026年Q1季度回顾
运维·服务器·人工智能·科技·程序人生·安全·自动化
呆呆敲代码的小Y2 分钟前
【Unity工具篇】| 游戏完整资源热更新流程,YooAsset官方示例项目
人工智能·游戏·unity·游戏引擎·热更新·yooasset·免费游戏
jikemaoshiyanshi3 分钟前
B2B企业GEO服务商哪家好?深度解析径硕科技(JINGdigital)及其JINGEO产品为何是首选
大数据·运维·人工智能·科技
Lab_AI3 分钟前
浩天药业携手创腾科技,开启研发数字化新篇章!电子实验记录本(ELN)落地浩天药业
人工智能
m0_738120724 分钟前
网络安全编程——Python编写基于UDP的主机发现工具(解码IP header)
python·网络协议·tcp/ip·安全·web安全·udp
supericeice5 分钟前
大模型建筑隐患管理方案怎么做?创邻科技用知识图谱、图数据库和企业AI大脑打通隐患问答、整改与推荐
人工智能·科技·知识图谱
北冥有羽Victoria8 分钟前
OpenCLI 操作网页 从0到1完整实操指南
vscode·爬虫·python·github·api·ai编程·opencli
蕤葳-9 分钟前
非编程背景学习AI的方法
人工智能
handsomestWei10 分钟前
scikit-learn数据预处理模块
python·机器学习·scikit-learn
北京耐用通信12 分钟前
不换设备、不重写程序:耐达讯自动化网关如何实现CC-Link IE转Modbus TCP的高效互通?
人工智能·科技·物联网·网络协议·自动化·信息与通信