大语言模型(LLMs)数学推理的经验技巧【思维链CoT的应用方法】

引言:2024年暑假期间,我参加了华为的一个比赛,关于提升数学推理能力的,赛后官方公布了优秀队伍的解题方案,我觉得蛮不错的,记录并学习一下。


✅ NLP 研 2 选手的学习笔记,2024年最后一篇

笔者简介:Wang Linyong,NPU,2023级,计算机技术

研究方向:文本生成、大语言模型

大赛链接:https://www.hiascend.com/zh/developer/contests/details/3e40e521c6304c4294317c0c1a70ff5e,2024 华为技术有限公司

项目链接:https://github.com/mindspore-courses/competition

赛事标题:昇思MindSpore模型开发挑战赛【模型微调赛题】


文章目录

  • [1 赛题简介](#1 赛题简介)
  • [2 获奖公示](#2 获奖公示)
  • [3 优秀方案学习 - 第1名队伍 debias_world1 的总体方案](#3 优秀方案学习 - 第1名队伍 debias_world1 的总体方案)
  • [4 优秀方案学习 - 第1名队伍 debias_world1 的详细步骤](#4 优秀方案学习 - 第1名队伍 debias_world1 的详细步骤)
  • [5 参考文献](#5 参考文献)

1 赛题简介

● 赛题要求基于开源中英文混合数学运算数据集,跑通 baseline,并对 MindFormers 中 LLama3-8b 模型进行微调(LoRA 或其他微调算法)。微调后的模型在原有能力不丢失的前提下(需保持在原能力的 90% 及以上),回答数学运算准确率相对 baseline 有提升,按照低参比例及准确率进行综合排名,评选出 20 个优秀团队,获得入围奖。

  1. 模型原有能力以其在 SQUAD 数据集上的阅读理解能力为准,评价标准为 F1 ScoreEm Score,要求微调后两项评价指标需要给定阈值以上方可算作有效作品,如何进行原有能力评估,以及 F1 ScoreEm Score 的参考阈值,请参考指导手册。

  2. 运算准确率评价标准:模型基于测试数据集(不公开,与训练数据集格式相同,为数道中英文数学运算题)进行推理,生成数学运算结果,如计算结果(数值)与正确答案相同,则视为本题正确,最终统计在测试数据集上回答正确的题目数量占比。
    >>> 运算准确率 = 正确运算题目数/测试集总题目数

  3. 低参比例:低参比例为微调参数量在总参数量的占比,选手在提交作品时需提供低参比例的计算结果,如何进行低参比例详见下方-低参比例运算。
    >>> 低参比例 = 参与微调的参数量/模型总参数量

  4. 低参比例和运算准确率综合排名:低参比例越低越好,运算准确率越高越好,按照如下加权进行运算。
    >>> 总分数:(100%-低参比例)×0.3+运算准确率×0.7

  5. 本题目共提供 80万 条中英文混合题目作为训练数据集(下载链接),选手可根据自己的实际情况调整数据集规模,建议综合在微调及推理时长、算力需求、维持模型原有能力、模型运算准备率提升等多方面因素进行训练数据集规模的评估。


2 获奖公示

● 我当时想着就试一试,只弄了三天,最后竟然入围了前 20,哈哈哈。


3 优秀方案学习 - 第1名队伍 debias_world1 的总体方案

● 这一队参考了 《Goat: Fine-tuned LLaMA Outperforms GPT-4 on Arithmetic Tasks》这篇论文。主要在于 思维链(CoT) 的运用,即把数学运算的每一个小步骤都写出来【数据集的预处理工作 】。然后把其他关于数学推理问题都转化为 乘法问题 或者 除法问题 ,应用了 归纳思想,比如:

python 复制代码
① 乘法运算:
141.67 * -3138.39 = -444615.7113

思维链为 ↓
141.67 * 3138.39 
= 141.67 * (3000 + 100 + 30 + 8 + 0.3 + 0.09) 
= 141.67 * 3000 + 141.67 * 100 + 141.67 * 30 + 141.67 * 8 + 141.67 * 0.3 + 141.67 * 0.09 
= 425010.00 + 14167.00 + 4250.10 + 1133.36 + 42.501 + 12.7503 
= 439177.00 + 4250.10 + 1133.36 + 42.501 + 12.7503 
= 443427.10 + 1133.36 + 42.501 + 12.7503 
= 444560.46 + 42.501 + 12.7503 
= 444602.961 + 12.7503 
= 444615.7113
final, 141.67 * -3138.39 = -444615.7113

---------------------------------------------------------------------------------------------
② 除法运算:
-9703.42 / 2833.26 = -3.424825113120574885467623868

思维链为 ↓
9703420000.00 - 2833.26 * 3000000 = 9703420000.00 - 8499780000.00 = 1203640000.00
1203640000.00 - 2833.26 * 400000 = 1203640000.00 - 1133304000.00 = 70336000.00
70336000.00 - 2833.26 * 20000 = 70336000.00 - 56665200.00 = 13670800.00
13670800.00 - 2833.26 * 4000 = 13670800.00 - 11333040.00 = 2337760.00
2337760.00 - 2833.26 * 800 = 2337760.00 - 2266608.00 = 71152.00
71152.00 - 2833.26 * 20 = 71152.00 - 56665.20 = 14486.80
14486.80 - 2833.26 * 5 = 14486.80 - 14166.30 = 320.50
Therefore, 9703420000.00 / 2833.26 = 3424825 R 320.50
final, -9703.42 / 2833.26 ~= -3.42483

---------------------------------------------------------------------------------------------
③ 面积计算
问题:一个长方形的长为 8 厘米,宽为 87 厘米,请计算其面积。
答案:面积为 696 平方厘米

思维链为 ↓
8 * 87 = 8 * (80 + 7) = 8 * 80 + 8 * 7 = 640 + 56 = 696
final, 8 * 87 = 696 平方厘米

---------------------------------------------------------------------------------------------
④ 平均值计算
问题:求以下数据的平均值:[91, 27, 4, 4, 9, 63, 67, 45, 7]
答案:平均值为 35.22222222222222

思维链为 ↓
91 + 27 + 4 + 4 + 9 + 63 + 67 + 45 + 7 
= 118 + 4 + 4 + 9 + 63 + 67 + 45 + 7 
= 122 + 4 + 9 + 63 + 67 + 45 + 7 
= 126 + 9 + 63 + 67 + 45 + 7 
= 135 + 63 + 67 + 45 + 7 
= 198 + 67 + 45 + 7 
= 265 + 45 + 7 
= 310 + 7 
= 317
计算 317 / 9
317000000 - 9 * 30000000 = 317000000 - 270000000 = 47000000
47000000 - 9 * 5000000 = 47000000 - 45000000 = 2000000
2000000 - 9 * 200000 = 2000000 - 1800000 = 200000
200000 - 9 * 20000 = 200000 - 180000 = 20000
20000 - 9 * 2000 = 20000 - 18000 = 2000
2000 - 9 * 200 = 2000 - 1800 = 200
200 - 9 * 20 = 200 - 180 = 20
20 - 9 * 2 = 20 - 18 = 2
Therefore, 317000000 / 9 = 35222222 R 2
final, 317 / 9 ~= 35.22222

● 而对于大语言模型的 LoRA 微调,其参数和 baseline 一致,知识把微调的 seq_len 由原来的 256 修改为 768。微调参数比例:54525952 / 8030000000 = 0.679%(微调大模型为 Llama3-8B)。

python 复制代码
pet_config:
      pet_type: lora
      # configuration of lora
      lora_rank: 64
      lora_alpha: 64
      lora_dropout: 0.05
      target_modules: '.*wq|.*wk|.*wv|.*wo'

● 微调样本量及训练时间:4.2万+ 样本,4batch_size = 16(单卡 batch_size = 4)配置下,微调 5epoch 花费 9小时+


4 优秀方案学习 - 第1名队伍 debias_world1 的详细步骤

1 步: 读取 train.json 文件的数据,并将其转换为 pandas.DataFrame 的格式。

python 复制代码
work_path = os.path.dirname(os.path.abspath(__file__))  # 工作路径
parser = argparse.ArgumentParser()  # 命令行参数解析器
parser.add_argument(  # 训练集路径
    '--data_path',
    default="../train.json",
    required=False,
    help='org data path')
parser.add_argument(  # 训练集相关参数
    '--train_len', default=40960, type=int, required=False,
    help='sample train data number, if train_len == -1 ,all data use for train')

parser.add_argument(  # 验证集相关参数
    '--valid_len', default=2000, type=int, required=False,
    help='valid data number, if valid_len == -1 ,all data except train use for valid')

parser.add_argument(  # 处理后的数据集保存路径
    '--out_dir',
    default="./",
    required=False,
    help='output dir for train/valid data')

args_, rest_args_ = parser.parse_known_args()  # 解析命令行参数

data = [json.loads(line) for line in open(args_.data_path, 'r', encoding='utf-8').readlines()]  # 读取数据
df = pd.DataFrame.from_records(data)  # 转换为 DataFrame 格式
print('原始数据数量:', df.shape)

运行结果展示:

2 步: 对数据集中每个问题进行 答案提取 。即对数据框 df 的每一行,基于其内容计算获取标准答案,并将其存入新的列 answer 中,便于后续操作。

python 复制代码
# 获取原始solution中数字结果(暂时不包括科学计数法数字)
def get_clean_answer(row):
    if re.search(r'计算\s?-?\d+\.?\d*\s?\+\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\-\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\*\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\/\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None \
        or re.search(r'一个长方形的长为 \d+ 厘米,宽为 \d+ 厘米,请计算其面积', row['problem']) is not None \
        or re.search(r'某物体的密度为 \d+ 克/立方厘米,体积为 \d+ 立方厘米,请计算该物体的质量', row['problem']) is not None \
        or re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+.?\d*\s?的\s?\d+\s?次方?', row['problem']) is not None \
        or re.search(r'计算\s?\d+.?\d*\s?的平方根', row['problem']) is not None \
        or re.search(r'求以下数据的平均值:\[.+\]', row['problem']) is not None:
        return re.findall('-?\d+\.?\d*',row['solution'])[-1]
    elif re.search(r'将分数 \d+/\d+ 进行简化', row['problem']) is not None \
        or re.search(r'当 x = \d+.?\d* 时,求函数 y = \d*x\^\d+ 的值', row['problem']) is not None:
        return row['solution'].split(':')[-1]
    else: 
        return row['solution']

# ... 第1步的代码 ...
df['id'] = df.index
df['answer'] = df.apply(lambda x:get_clean_answer(x), axis = 1)  # axis=1 表示按行应用函数(每次传入 x 是数据框的一行)

运行结果展示:

3 步: 对数据集中每个问题进行 可学习样本标记 。即通过检查 problem 列中的内容,判断该问题是否符合某些特定的学习问题(比如计算四则运算、解方程、几何计算等),并基于此判断是否可以学习这个问题。最终,将判断结果存入新列 can_learn

python 复制代码
# 判断样本在目前的方案下模型是否可以学习
def can_learn(row):
    if re.search(r'计算\s?-?\d+\.?\d*\s?\+\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\-\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\*\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\/\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None \
        or re.search(r'一个长方形的长为 \d+ 厘米,宽为 \d+ 厘米,请计算其面积', row['problem']) is not None \
        or re.search(r'某物体的密度为 \d+ 克/立方厘米,体积为 \d+ 立方厘米,请计算该物体的质量', row['problem']) is not None \
        or re.search(r'商品原价为 \d+ 元,打折后的价格为 \d+ 元,请计算打折的折扣比例', row['problem']) is not None \
        or re.search(r'去年销售额为 \d+ 万元,今年销售额增加了 \d+%,请计算今年的销售额', row['problem']) is not None \
        or re.search(r'求以下数据的平均值:\[.+\]', row['problem']) is not None \
        or re.search(r'将分数 \d+/\d+ 进行简化', row['problem']) is not None:
        return True
    elif re.search(r'计算\s?-?\d+.?\d*\s?的\s?\d+\s?次方?', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])   # 如果问题包含次方计算,findall() 提取所有数字,将它们存储到 result_list 中。比如,对于 计算 -3.2 的 2 次方,result_list 将会是 ['-3.2', '2']。
        _, b = result_list
        if Decimal(b) <= Decimal(1):  # 如果指数(b)小于等于 1,则返回 True,表示该问题可以学习。比如,如果是 "2 的 1 次方" 或 "3 的 0 次方",这些可以认为是简单的计算问题。
            return True
    return False


# ... 第1步的代码 ...
# ... 第2步的代码 ...
df['can_learn'] = df.apply(lambda x:can_learn(x), axis = 1)

运行结果展示:

4 步: 对数据集中每个问题进行 思维链(CoT)生成 。通过分析问题类型,为每个问题生成详细的思维链解释(CoT),便于展示逐步解题的过程。输出的解释被存储在数据框的 output 列中,供后续使用。

python 复制代码
# 获取保留小数点后最多6位的字符串
def get_precision_str(answer):
    answer = Decimal(answer)
    answer_str = str(answer)
    splits = str(answer).split('.')
    if len(splits) == 2:
        if len(splits[1]) < 6:
            answer_str = splits[0] + '.' + splits[1][:5]
        else:  # 如果超过 6 位,小数部分四舍五入
            splits = str(answer + np.sign(answer) * Decimal('0.000005')).split('.')
            answer_str = splits[0] + '.' + splits[1][:5]

    return answer_str


#分解一个数字到非0的有效位
def get_number_split(input_num):
    splits =  str(input_num).split('.')
    if len(splits) == 2:
        integer_part, decimal_part = splits
    else:
        integer_part, decimal_part = splits[0], '0'

    integer_len = len(integer_part)
    integer_split = [Decimal(10**(integer_len-i-1))*Decimal(integer_part[i]) for i in range(integer_len)]
    decimal_split = [Decimal('0.'+'0'*i + decimal_part[i]) for i in range(len(decimal_part))]
    integer_split.extend(decimal_split)
    integer_split = list(filter(lambda x:x != 0, integer_split))
    return integer_split

# 乘法CoT
def get_mul_cot(num1, num2, answer):
	"""
	作用:生成乘法问题的思维链解释。
	逻辑:
	1. 如果 num2 由多部分构成(例如:123.45 可分解为 100 + 20 + 3 + 0.4 + 0.05),按分解展开逐步计算。
	2. 使用中间结果构造逐步累加的计算过程。
	3. 返回最终的步骤化解释。
	"""
    num1, num2, answer = np.abs(num1), np.abs(num2), np.abs(answer)
    question = f"{num1} * {num2}"
    num2_split = get_number_split(num2)
    
    if len(num2_split) == 1:
        cot = question + " = " + str(answer)
    else:
        split = f"""{num1} * ({" + ".join(str(x) for x in num2_split)})"""
        expansion = " + ".join([f"{num1} * {x}" for x in num2_split])
        summation_terms = [num1 * x for x in num2_split]
        summation = " + ".join(str(x) for x in summation_terms)
        step = ""
        while summation_terms:
            first = summation_terms.pop(0)
            if not summation_terms:
                output = first
                break
            summation_terms[0] = first + summation_terms[0]
            if len(summation_terms) == 1:
                summation_terms = [str(answer)]
            step = step + " + ".join([f"{x}" for x in summation_terms]) 
            if len(summation_terms)>=2:
                step = step + " = "
    
        cot = question + " = " + f"{split} = {expansion} = {summation} = " + step
    return cot

#除法CoT
def get_div_cot(num1, num2, answer = None):
	"""
	作用:生成除法问题的思维链解释。
	逻辑:
	1. 通过模拟逐步的长除法,计算商和余数。
	2. 构造逐步的减法过程,展示除法步骤。
	3. 返回最终的步骤化解释。
	"""
    base = Decimal('10')**6
    num1, num2 = np.abs(num1), np.abs(num2)
    num1 = num1*base
    quotient = num1 // num2
    remainder = num1 % num2
    if quotient == 0:
        cot = f"{num1} / {num2} = {quotient} R {remainder}"
    elif num1 == num2:
        cot = f"{num1} / {num2} = {quotient}"
    else:
        step = ""
        cot = ""
        left = num1
        i = 0
        computed_q = 0
        while left>=num2:
            if int(str(quotient)[i])!=0:
                intermediate = int(str(quotient)[i] + "0" * (len(str(quotient))-1-i))
                answer = num2 * intermediate
                new_left = left - answer
                step = f"{left} - {num2} * {intermediate} = {left} - {answer} = {new_left}\n"
                cot = cot + step
                left = new_left
                computed_q = computed_q + intermediate
            i = i+1
        assert(left == remainder)
        assert(computed_q == quotient)
    
        if remainder!=0:
            cot = cot + f"Therefore, {num1} / {num2} = {quotient} R {remainder}"
        else:
            cot = cot + f"Therefore, {num1} / {num2} = {quotient}"
    return cot

def apply_cot(row):
	"""
	作用: 这个函数对不同类型的问题进行分类处理,生成思维链解释。
	逻辑: 使用正则表达式 re.search() 检测 problem 中的问题类型,支持以下问题类型:
		1. 乘法 (计算 ... * ...)
		2. 除法 (计算 ... / ...)
		3. 解方程 (解方程 -ax + b = 0)
		4. 折扣比例计算 (商品原价为 ... 元,打折后的价格为 ... 元)
		5. 几何问题(矩形面积) (一个长方形的长为 ...,宽为 ...)
		6. 物理问题(质量计算) (某物体的密度为 ...,体积为 ...)
		7. 销售额增长 (去年销售额为 ...,今年增长了 ...%)
		8. 平均值计算 (求以下数据的平均值:[...])
	"""
    if re.search(r'计算\s?-?\d+\.?\d*\s?\*\s?-?\d+\.?\d*', row['problem']) is not None:
    	"""
    	匹配问题:计算 3 * 4
		提取数字:从 solution 中提取操作数和结果。
		使用 get_mul_cot() 生成逐步计算过程。
		最终添加总结:final, 3 * 4 = 12
    	"""
        result_list =  re.findall('-?\d+\.?\d*',row['solution'])
        result_list = [Decimal(x) for x in result_list]
        num1, num2, answer = result_list
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer}"
        return cot
    elif re.search(r'计算\s?-?\d+\.?\d*\s?\/\s?-?\d+\.?\d*', row['problem']) is not None:
	    """
	    匹配问题:计算 8 / 3
		提取数字:从 solution 中提取操作数和结果。
		使用 get_div_cot() 模拟长除法生成计算步骤。
		最终添加总结:final, 8 / 3 ~= 2.666667
	    """
        result_list =  re.findall('-?\d+\.?\d*',row['solution'])
        result_list = [Decimal(x) for x in result_list]
        num1, num2, answer = result_list
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
        cot = get_div_cot(num1, num2, None)
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str}"
        return cot
    elif re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None:
	    """
	    匹配问题:解方程 -3x + 6 = 0
		提取系数和常数:通过正则解析出 a 和 b。
		手动计算解:x = -b / a
		使用 get_div_cot() 展示逐步计算。
	    """
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        a, b, c = result_list
        num1, num2 = -b + c, a
        answer = num1/num2
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
        cot = get_div_cot(num1, num2, None)
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str}"
        return cot
    elif re.search(r'商品原价为 \d+ 元,打折后的价格为 \d+ 元,请计算打折的折扣比例', row['problem']) is not None:
	    """
		匹配问题:商品原价为 100 元,打折后的价格为 80 元
		提取价格:通过正则提取原价和折后价。
		手动计算:(原价 - 折后价) / 原价
		使用 get_div_cot() 展示折扣比例的逐步计算过程。
	    """
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        a, b = result_list
        cot_head = f"({a} - {b}) / {a} = "
        num1, num2 = a - b, a
        answer = num1*100/num2
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
        cot = cot_head + get_div_cot(num1, num2, None)
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str} 折"
        return cot
    elif re.search(r'一个长方形的长为 \d+ 厘米,宽为 \d+ 厘米,请计算其面积', row['problem']) is not None:
	    """
	    匹配问题:一个长方形的长为 5 厘米,宽为 3 厘米
		提取长和宽。
		使用 get_mul_cot() 展示逐步计算过程。
	    """
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        num1, num2 = result_list
        answer = num1*num2
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer} 平方厘米"
        return cot
    elif re.search(r'某物体的密度为 \d+ 克/立方厘米,体积为 \d+ 立方厘米,请计算该物体的质量', row['problem']) is not None:
	    """
	    匹配问题:密度为 2 克/立方厘米,体积为 5 立方厘米
		提取密度和体积。
		使用 get_mul_cot() 展示逐步计算过程。
	    """
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        num1, num2 = result_list
        answer = num1*num2
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer} 克"
        return cot
    elif re.search(r'去年销售额为 \d+ 万元,今年销售额增加了 \d+%,请计算今年的销售额', row['problem']) is not None:
	    """
	    匹配问题:去年销售额为 100 万元,今年增长了 10%
		提取去年销售额和增长率。
		手动计算:去年销售额 × (1 + 增长率)
		使用 get_mul_cot() 展示逐步计算过程。
	    """
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        a, b = result_list
        num1, num2 = a, Decimal('1') + b/Decimal('100')
        answer = num1*num2
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer} 万元"
        return cot
    elif re.search(r'求以下数据的平均值:\[.+\]', row['problem']) is not None:
	    """
	    匹配问题:求以下数据的平均值:[1, 2, 3, 4]
		提取数据列表,逐步求和并计算平均值。
		使用 get_div_cot() 展示平均值计算过程。
	    """
        value_list = re.findall(r'\[.*\]' ,row['problem'])[0]
        value_list = eval(value_list)
        value_len = len(value_list)
        cum_list = list(np.cumsum(value_list))
        step = ""
        cot = ' + '.join([str(x) for x in value_list])
        if len(value_list) == 1:
            cot = str(value_list[0])
            return cot
        for idx in range(1, value_len):
            cot += ' = ' + ' + '.join([str(x) for x in ([cum_list[idx]] + value_list[idx+1:])])
        
        num1, num2 = Decimal(str(cum_list[-1])), Decimal(str(value_len))
        answer = num1/num2
        cot += f'\n计算 {cum_list[-1]} / {value_len}'
        cot += '\n' + get_div_cot(Decimal(str(cum_list[-1])), Decimal(str(value_len)), None)
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str}"
        return cot
    else:
	    """
	    如果问题未被识别,则直接返回原始答案。
	    """
        return row['answer']


# ... 第1步的代码 ...
# ... 第2步的代码 ...
# ... 第3步的代码 ...
df['output'] = df.apply(lambda row:apply_cot(row), axis = 1)

运行结果展示:

5 步: 对数据集进行去重和可学习样本提取操作。

python 复制代码
def is_number(s):
	"""
	作用:判断一个字符串是否为数值或可以表示为数值(包括分数形式,如 4/5)。
	逻辑:
	尝试将字符串转换为浮点数:
	如果转换成功,返回 True。
	处理分数形式:
	如果字符串中包含 /,去掉 / 后检查是否为纯数字。如果是纯数字,返回 True。
	否则返回 False。
	"""
    try:
        float(s)
        return True
    except ValueError:
        if s.replace('/', '').isnumeric():
            return True
        return False

# ... 第1步的代码 ...
# ... 第2步的代码 ...
# ... 第3步的代码 ...
# ... 第4步的代码 ...
org_df = df
df = df.drop_duplicates(['problem'])  # 对 problem 列进行去重,保留每个问题的第一条记录,删除重复的行。
print('去重后样本数量:', df.shape)
df['is_number'] = df['answer'].apply(lambda x:is_number(x))
df = df[df['is_number']]  # 仅保留 is_number 为 True 的行,即答案是数值的样本。
learn_df = df[df['can_learn']]  # 仅保留 can_learn 为 True 的行,即模型可以学习的样本。
print('去重后认为模型可以学习的样本数量:', learn_df.shape)

运行结果展示:

python 复制代码
原始数据数量: (809993, 2)
根据目前规则认为模型可以学习的样本数量: (713312, 5)
去重后样本数量: (548942, 6)
去重后认为模型可以学习的样本数量: (453521, 7)

6 步: 采样训练集和验证集,并保存。

python 复制代码
# ... 第1步的代码 ...
# ... 第2步的代码 ...
# ... 第3步的代码 ...
# ... 第4步的代码 ...
# ... 第5步的代码 ...
if args_.train_len > 0:
    train_df =  learn_df.sample(args_.train_len)  # 随机采样 args_.train_len 行
else:
    train_df = learn_df
if args_.valid_len > 0:
    valid_df = df[~df['id'].isin(train_df['id'])].sample(args_.valid_len)  # 从完整数据中(df)选择一定数量的样本作为验证集,且保证验证集样本不与训练集重叠。
else:
    valid_df = df[~df['id'].isin(train_df['id'])]
if not os.path.exists(args_.out_dir):  # 检查输出目录是否存在,如果不存在则创建
    os.makedirs(args_.out_dir, exist_ok = True) 
print('train_len:',train_df.shape)
print('valid_len:', valid_df.shape)
"""
目标:将训练集保存为 JSON 文件。
train_df[['id', 'problem', 'solution', 'answer', 'output']]:
1. id:样本的唯一标识。
2. problem:问题描述。
3. solution:答案。
4. answer:数值答案。
5. output:思维链(CoT)生成的解释。
"""
json.dump(train_df[['id', 'problem', 'solution', 'answer', 'output']].to_dict(orient='records'), open(os.path.join(args_.out_dir, 'train-data.json'), 'w'), indent=2)
json.dump(valid_df[['id', 'problem', 'solution', 'answer']].to_dict(orient='records'), open(os.path.join(args_.out_dir,'valid-data.json'), 'w'), indent=2)

valid_ms_data = valid_df.to_dict(orient='records')  # to_dict(orient='records'):将数据框转换为记录列表形式(每行数据为一个字典)
# 遍历 valid_ms_data,将每条记录写入文件 valid-data-list.json,每条记录占一行
with open(os.path.join(args_.out_dir,'valid-data-list.json'), 'w') as f:
    for line in valid_ms_data:
        json.dump(line, f)
        f.write('\n')

完整代码:

python 复制代码
import argparse
import os
import pandas as pd
import json
import re
import numpy as np
from decimal import Decimal


def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        if s.replace('/', '').isnumeric():
            return True
        return False


# 获取原始solution中数字结果(暂时不包括科学计数法数字)
def get_clean_answer(row):
    if re.search(r'计算\s?-?\d+\.?\d*\s?\+\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\-\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\*\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\/\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None \
        or re.search(r'一个长方形的长为 \d+ 厘米,宽为 \d+ 厘米,请计算其面积', row['problem']) is not None \
        or re.search(r'某物体的密度为 \d+ 克/立方厘米,体积为 \d+ 立方厘米,请计算该物体的质量', row['problem']) is not None \
        or re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+.?\d*\s?的\s?\d+\s?次方?', row['problem']) is not None \
        or re.search(r'计算\s?\d+.?\d*\s?的平方根', row['problem']) is not None \
        or re.search(r'求以下数据的平均值:\[.+\]', row['problem']) is not None:
        return re.findall('-?\d+\.?\d*',row['solution'])[-1]
    elif re.search(r'将分数 \d+/\d+ 进行简化', row['problem']) is not None \
        or re.search(r'当 x = \d+.?\d* 时,求函数 y = \d*x\^\d+ 的值', row['problem']) is not None:
        return row['solution'].split(':')[-1]
    else: 
        return row['solution']

# 判断样本在目前的方案下模型是否可以学习
def can_learn(row):
    if re.search(r'计算\s?-?\d+\.?\d*\s?\+\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\-\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\*\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'计算\s?-?\d+\.?\d*\s?\/\s?-?\d+\.?\d*', row['problem']) is not None \
        or re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None \
        or re.search(r'一个长方形的长为 \d+ 厘米,宽为 \d+ 厘米,请计算其面积', row['problem']) is not None \
        or re.search(r'某物体的密度为 \d+ 克/立方厘米,体积为 \d+ 立方厘米,请计算该物体的质量', row['problem']) is not None \
        or re.search(r'商品原价为 \d+ 元,打折后的价格为 \d+ 元,请计算打折的折扣比例', row['problem']) is not None \
        or re.search(r'去年销售额为 \d+ 万元,今年销售额增加了 \d+%,请计算今年的销售额', row['problem']) is not None \
        or re.search(r'求以下数据的平均值:\[.+\]', row['problem']) is not None \
        or re.search(r'将分数 \d+/\d+ 进行简化', row['problem']) is not None:
        return True
    elif re.search(r'计算\s?-?\d+.?\d*\s?的\s?\d+\s?次方?', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        _, b = result_list
        if Decimal(b) <= Decimal(1):  # 如果指数(b)小于等于 1,则返回 True,表示该问题可以学习。比如,如果是 "2 的 1 次方" 或 "3 的 0 次方",这些可以认为是简单的计算问题。
            return True
    return False


# 获取保留小数点后最多6位的字符串
def get_precision_str(answer):
    answer = Decimal(answer)
    answer_str = str(answer)
    splits = str(answer).split('.')
    if len(splits) == 2:
        if len(splits[1]) < 6:
            answer_str = splits[0] + '.' + splits[1][:5]
        else:
            splits = str(answer + np.sign(answer) * Decimal('0.000005')).split('.')
            answer_str = splits[0] + '.' + splits[1][:5]

    return answer_str


#分解一个数字到非0的有效位
def get_number_split(input_num):
    splits =  str(input_num).split('.')
    if len(splits) == 2:
        integer_part, decimal_part = splits
    else:
        integer_part, decimal_part = splits[0], '0'
    integer_len = len(integer_part)
    integer_split = [Decimal(10**(integer_len-i-1))*Decimal(integer_part[i]) for i in range(integer_len)]
    decimal_split = [Decimal('0.'+'0'*i + decimal_part[i]) for i in range(len(decimal_part))]
    integer_split.extend(decimal_split)
    integer_split = list(filter(lambda x:x != 0, integer_split))
    return integer_split

# 乘法CoT
def get_mul_cot(num1, num2, answer):
    num1, num2, answer = np.abs(num1), np.abs(num2), np.abs(answer)
    question = f"{num1} * {num2}"
    num2_split = get_number_split(num2)
    
    if len(num2_split) == 1:
        cot = question + " = " + str(answer)
    else:
        split = f"""{num1} * ({" + ".join(str(x) for x in num2_split)})"""
        expansion = " + ".join([f"{num1} * {x}" for x in num2_split])
        summation_terms = [num1 * x for x in num2_split]
        summation = " + ".join(str(x) for x in summation_terms)
        step = ""
        while summation_terms:
            first = summation_terms.pop(0)
            if not summation_terms:
                output = first
                break
            summation_terms[0] = first + summation_terms[0]
            if len(summation_terms) == 1:
                summation_terms = [str(answer)]
            step = step + " + ".join([f"{x}" for x in summation_terms]) 
            if len(summation_terms)>=2:
                step = step + " = "
        cot = question + " = " + f"{split} = {expansion} = {summation} = " + step
    return cot

# 除法CoT
def get_div_cot(num1, num2, answer = None):
    base = Decimal('10')**6
    num1, num2 = np.abs(num1), np.abs(num2)
    num1 = num1*base
    quotient = num1 // num2
    remainder = num1 % num2
    if quotient == 0:
        cot = f"{num1} / {num2} = {quotient} R {remainder}"
    elif num1 == num2:
        cot = f"{num1} / {num2} = {quotient}"
    else:
        step = ""
        cot = ""
        left = num1
        i = 0
        computed_q = 0
        while left>=num2:
            if int(str(quotient)[i])!=0:
                intermediate = int(str(quotient)[i] + "0" * (len(str(quotient))-1-i))
                answer = num2 * intermediate
                new_left = left - answer
                step = f"{left} - {num2} * {intermediate} = {left} - {answer} = {new_left}\n"
                cot = cot + step
                left = new_left
                computed_q = computed_q + intermediate
            i = i+1
        assert(left == remainder)
        assert(computed_q == quotient)
    
        if remainder!=0:
            cot = cot + f"Therefore, {num1} / {num2} = {quotient} R {remainder}"
        else:
            cot = cot + f"Therefore, {num1} / {num2} = {quotient}"
    return cot

def apply_cot(row):
    if re.search(r'计算\s?-?\d+\.?\d*\s?\*\s?-?\d+\.?\d*', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['solution'])
        result_list = [Decimal(x) for x in result_list]
        num1, num2, answer = result_list
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer}"
        return cot
    elif re.search(r'计算\s?-?\d+\.?\d*\s?\/\s?-?\d+\.?\d*', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['solution'])
        #print(result_list)
        result_list = [Decimal(x) for x in result_list]
        num1, num2, answer = result_list
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
        cot = get_div_cot(num1, num2, None)
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str}"
        return cot
    elif re.search(r'解方程 -?\d+x \+ -?\d+ = 0', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        a, b, c = result_list
        num1, num2 = -b + c, a
        answer = num1/num2
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
      
        cot = get_div_cot(num1, num2, None)
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str}"
        return cot
    elif re.search(r'商品原价为 \d+ 元,打折后的价格为 \d+ 元,请计算打折的折扣比例', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        a, b = result_list
        cot_head = f"({a} - {b}) / {a} = "
        num1, num2 = a - b, a
        answer = num1*100/num2
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
        cot = cot_head + get_div_cot(num1, num2, None)
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str} 折"
        return cot
    elif re.search(r'一个长方形的长为 \d+ 厘米,宽为 \d+ 厘米,请计算其面积', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        num1, num2 = result_list
        answer = num1*num2
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer} 平方厘米"
        return cot
    elif re.search(r'某物体的密度为 \d+ 克/立方厘米,体积为 \d+ 立方厘米,请计算该物体的质量', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        num1, num2 = result_list
        answer = num1*num2
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer} 克"
        return cot
    elif re.search(r'去年销售额为 \d+ 万元,今年销售额增加了 \d+%,请计算今年的销售额', row['problem']) is not None:
        result_list =  re.findall('-?\d+\.?\d*',row['problem'])
        result_list = [Decimal(x) for x in result_list]
        a, b = result_list
        num1, num2 = a, Decimal('1') + b/Decimal('100')
        answer = num1*num2
        cot =  get_mul_cot(num1, num2, answer)
        cot = cot+ f"\nfinal, {num1} * {num2} = {answer} 万元"
        return cot
    elif re.search(r'求以下数据的平均值:\[.+\]', row['problem']) is not None:
        value_list = re.findall(r'\[.*\]' ,row['problem'])[0]
        value_list = eval(value_list)
        value_len = len(value_list)
        cum_list = list(np.cumsum(value_list))
        step = ""
        cot = ' + '.join([str(x) for x in value_list])
        if len(value_list) == 1:
            cot = str(value_list[0])
            return cot
        for idx in range(1, value_len):
            cot += ' = ' + ' + '.join([str(x) for x in ([cum_list[idx]] + value_list[idx+1:])])
        
        num1, num2 = Decimal(str(cum_list[-1])), Decimal(str(value_len))
        answer = num1/num2
        cot += f'\n计算 {cum_list[-1]} / {value_len}'
        cot += '\n' + get_div_cot(Decimal(str(cum_list[-1])), Decimal(str(value_len)), None)
        splits = str(answer).split('.')
        answer_str = str(answer)
        if len(splits) == 2:
            answer_str = splits[0] + '.' + splits[1][:6]
        answer_str = get_precision_str(answer_str)
        cot = cot + f"\nfinal, {num1} / {num2} ~= {answer_str}"
        return cot
    else:
        return row['answer']

if __name__ == "__main__":
    # python data_process.py --data_path '../train.json' --out_dir './' --train_len 40960  --valid_len 2000
    work_path = os.path.dirname(os.path.abspath(__file__))  # 工作路径
    parser = argparse.ArgumentParser()  # 命令行参数解析器
    parser.add_argument(  # 输入数据路径
        '--data_path',
        default="../train.json",
        required=False,
        help='org data path')
    parser.add_argument(
        '--train_len', default=40960, type=int,required=False,
        help='sample train data number, if train_len == -1 ,all data use for train')

    parser.add_argument(
        '--valid_len', default=2000, type=int,required=False,
        help='valid data number, if valid_len == -1 ,all data except train use for valid')
    
    parser.add_argument(
        '--out_dir',
        default="./",
        required=False,
        help='output dir for train/valid data')
    args_, rest_args_ = parser.parse_known_args()   # 解析命令行参数
    
    data = [json.loads(line) for line in open(args_.data_path, 'r', encoding='utf-8').readlines()]  # 读取数据
    df = pd.DataFrame.from_records(data)  # 转换为DataFrame
    print('原始数据数量:', df.shape)
    df['id'] = df.index
    df['answer'] = df.apply(lambda x:get_clean_answer(x), axis = 1)
    df['can_learn'] = df.apply(lambda x:can_learn(x), axis = 1)
    print('根据目前规则认为模型可以学习的样本数量:', df[df['can_learn']].shape)
    
    df['output'] = df.apply(lambda row:apply_cot(row), axis = 1)
    org_df = df
    df = df.drop_duplicates(['problem'])
    print('去重后样本数量:', df.shape)
    df['is_number'] = df['answer'].apply(lambda x:is_number(x))
    df = df[df['is_number']]
    learn_df = df[df['can_learn']]
    print('去重后认为模型可以学习的样本数量:', learn_df.shape)

    if args_.train_len > 0:
        train_df =  learn_df.sample(args_.train_len)
    else:
        train_df = learn_df
    if args_.valid_len > 0:
        valid_df = df[~df['id'].isin(train_df['id'])].sample(args_.valid_len)
    else:
        valid_df = df[~df['id'].isin(train_df['id'])]
    if not os.path.exists(args_.out_dir):
        os.makedirs(args_.out_dir, exist_ok = True) 
    print('train_len:',train_df.shape)
    print('valid_len:', valid_df.shape)
    json.dump(train_df[['id', 'problem', 'solution', 'answer', 'output']].to_dict(orient='records'), open(os.path.join(args_.out_dir, 'train-data.json'), 'w'), indent=2)
    json.dump(valid_df[['id', 'problem', 'solution', 'answer']].to_dict(orient='records'), open(os.path.join(args_.out_dir,'valid-data.json'), 'w'), indent=2)

    valid_ms_data = valid_df.to_dict(orient='records')
    with open(os.path.join(args_.out_dir,'valid-data-list.json'), 'w') as f:
        for line in valid_ms_data:
            json.dump(line, f)
            f.write('\n')

5 参考文献

1 篇


⭐️ ⭐️ 写于2024年12月31日 15:34 教研室工位

相关推荐
ningaiiii19 分钟前
NSGA-II(非支配排序遗传算法II)详解与实现
人工智能·深度学习·神经网络·数据挖掘
JINGWHALE143 分钟前
设计模式 结构型 装饰器模式(Decorator Pattern)与 常见技术框架应用 解析
前端·人工智能·后端·设计模式·性能优化·系统架构·装饰器模式
AI34561 小时前
壁纸样机神器,它支持复杂的图层操作吗?
人工智能
幻风_huanfeng1 小时前
线性变换在机器学习中的应用实例
人工智能·机器学习
沉木渡香1 小时前
【pytorch-lightning】架构一览
人工智能·pytorch·python
EnochChen_1 小时前
PyTorch快速入门教程【小土堆】之完整模型训练套路
人工智能·pytorch·python
赛逸展张胜2 小时前
5G+工业互联网”迎来新机遇,CES Asia 2025见证产业腾飞
大数据·人工智能·科技·5g·智慧城市
合方圆~小文2 小时前
高清监控视频的管理与展示:从摄像头到平台的联接过程
linux·网络·人工智能·云计算·智能家居
AidLux2 小时前
2024 高通边缘智能创新应用大赛智能边缘计算赛道冠军方案解读
人工智能·边缘计算
凡人的AI工具箱2 小时前
每天40分玩转Django:Django即时聊天应用实战
数据库·人工智能·后端·python·django·sqlite