【论文阅读】ARM: Adaptive Reasoning Model

ARM: Adaptive Reasoning Model

    • 方法
      • [第一阶段:SFT for Reasoning Formats Understanding](#第一阶段:SFT for Reasoning Formats Understanding)
      • [第二阶段:RL for Encouraging Efficient Format Selection](#第二阶段:RL for Encouraging Efficient Format Selection)

ARM: Adaptive Reasoning Model 这篇文章介绍了自适应推理模型(Adaptive Reasoning Model, ARM),该模型能够根据任务难度自适应地选择推理格式,从而在保持性能的同时提高计算效率。ARM支持四种推理格式:三种高效的格式------直接回答(Direct Answer)、短链思考(Short CoT)和代码(Code),以及一种详细的格式------长链思考(Long CoT)

本篇博客仅介绍方法上的创新, 项目地址:https://team-arm.github.io/arm/

方法

第一阶段:SFT for Reasoning Formats Understanding

在这一阶段,文章利用SFT作为冷启动,将模型引入可以用于解决问题的各种推理格式。文章使用特殊标记(例如,<Code></Code>)来包含思考逻辑:

  1. 直接回答(Direct Answer):这种格式直接给出答案,不包含推理过程。
  2. 短链推理(Short CoT):这种格式提供简短的推理过程,通常用于简单的任务。
  3. 代码(Code):这种格式使用代码来解决问题,适用于需要编程的场景。
  4. 长链推理(Long CoT):这种格式提供详细的推理过程,适用于复杂的任务。

为了确保生成的推理逻辑的质量,文章过滤掉那些导致错误答案的推理逻辑,最终生成的训练集包含3.0K个多项选择题和7.8K个开放形式问题,每个问题都有四种推理格式。文章在这一阶段使用AQuA-Rat 数据集因为它可以自然地转化为四种不同的推理形式。除了数据集中提供的Direct AnswerShort CoT 推理外,文章还利用GPT-4o 和DeepSeek-R1 分别补充了CodeLong CoT推理。

第二阶段:RL for Encouraging Efficient Format Selection

经过SFT后,模型学会了使用各种推理格式进行响应,但缺乏根据任务自适应切换格式的能力。为了解决这一问题,文章在第二阶段使用RL来鼓励模型选择更高效的推理格式,同时保持准确性。在这一阶段,文章使用了三个额外的数据集,这些数据集涵盖了从相对简单的常识推理任务到更复杂的数学问题的范围。这些数据集包括:

  • CSQA:常识推理任务
  • GSM8K:数学问题
  • MATH:数学问题

文章主要技术上的创新在强化学习训练的Reward设置上:

首先,作者定义了一组重塑后的奖励 r ′ = { r 1 ′ , r 2 ′ , ⋯   , r G ′ } r' = \{r'_1, r'_2, \cdots, r'_G\} r′={r1′,r2′,⋯,rG′},这些奖励用于评估模型生成的响应。具体来说,每个响应 o i o_i oi 的奖励 r i ′ r'_i ri′ 通过以下公式计算:

r i ′ = α i ( t ) ⋅ r i r'_i = \alpha_i(t) \cdot r_i ri′=αi(t)⋅ri

其中, r i r_i ri 是原始奖励, α i ( t ) \alpha_i(t) αi(t) 是一个格式多样性缩放因子,用于放大较少采样的推理格式的奖励,防止这些格式在训练过程中消失。格式多样性缩放因子 α i ( t ) \alpha_i(t) αi(t) 的计算公式如下:

α i ( t ) = G F ( o i ) ⋅ decay i ( t ) \alpha_i(t) = \frac{G}{F(o_i)} \cdot \text{decay}_i(t) αi(t)=F(oi)G⋅decayi(t)

其中, F ( o i ) F(o_i) F(oi) 表示在组 O O O 中与 o i o_i oi 对应的推理格式出现的次数, t t t 表示训练步数。衰减因子 decay i ( t ) \text{decay}_i(t) decayi(t) 的计算公式为:

decay i ( t ) = F ( o i ) G + 0.5 ⋅ ( 1 − F ( o i ) G ) ⋅ ( 1 + cos ⁡ ( π ⋅ t T ) ) \text{decay}_i(t) = \frac{F(o_i)}{G} + 0.5 \cdot \left(1 - \frac{F(o_i)}{G}\right) \cdot \left(1 + \cos\left(\frac{\pi \cdot t}{T}\right)\right) decayi(t)=GF(oi)+0.5⋅(1−GF(oi))⋅(1+cos(Tπ⋅t))

为了将GRPO扩展为Ada-GRPO,文章引入了格式多样性缩放因子 α i ( t ) \alpha_i(t) αi(t),使模型能够自适应地选择推理格式。具体来 α i ( t ) \alpha_i(t) αi(t) 由两个组件组成:

  1. Format Diversity Scaling Factor G F ( o i ) \frac{G}{F(o_i)} F(oi)G

    • 为了防止模型过早收敛到最高准确率的格式(即格式坍缩到长链推理 Long CoT),文章通过增加较少采样格式的奖励来鼓励探索。具体来说,如果某个推理格式出现次数较少,其奖励会被放大,从而促使模型更多地尝试这些格式。
  2. Decay Factor decay i ( t ) \text{decay}_i(t) decayi(t)

    • 为了避免因过度奖励稀有格式而导致的长期不一致,这一项逐渐减少多样性的影响。例如,Format Diversity Scaling Factor G F ( o i ) \frac{G}{F(o_i)} F(oi)G 可能会使模型在训练初期更倾向于选择低准确率的格式(如短链推理 Short CoT),仅仅因为这些格式出现次数较少,从而获得更高的奖励。虽然这种探索在训练初期是有益的,但后期可能会阻碍模型的收敛。衰减机制通过在训练初期促进多样性,然后随着训练的进行逐渐将重点转移到准确性上来,从而缓解这一问题。
相关推荐
肾透侧视攻城狮21 小时前
《PyTorch神经网络从开发到调试:实战技巧、可视化与兼容性问题解决方案》
神经网络·语言模型·二分类任务·实现前馈神经网络·可视化执行梯度下降算法·matplotlib版本兼容性·pytorch实现二分类任务
莽撞的大地瓜1 天前
连获国内多行业认可、入选全球AI全景图谱 彰显蜜度智能校对的硬核实力
人工智能·ai·语言模型·新媒体运营·知识图谱
人工智能培训1 天前
具身智能如何在保证安全的前提下高效探索学习?
语言模型·llm·数据采集·模型量化·多模态学习·具身智能·环境感知
阿杰学AI1 天前
AI核心知识82——大语言模型之AI Value Alignment(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·机械学习·ai价值观对齐
学而要时习1 天前
深度神经网络到AI大语言模型:一场被“误认为突然发生”的技术演进
人工智能·语言模型·dnn
有Li1 天前
SafeRPlan: 用于椎弓根螺钉置入术中规划的安全深度强化学习/文献速递-基于人工智能的医学影像技术
论文阅读·人工智能·深度学习·文献·医学生
小明_GLC1 天前
Is Mamba Effective for Time Series Forecasting?论文阅读
论文阅读
阿杰学AI1 天前
AI核心知识81——大语言模型之MaaS(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·maas·模型即服务
蓝海星梦1 天前
GRPO 算法演进——偏差修正/鲁棒优化/架构扩展篇
论文阅读·人工智能·深度学习·算法·自然语言处理·强化学习
xx_xxxxx_1 天前
多模态动态融合模型Predictive Dynamic Fusion论文阅读与代码分析2-对比模型与底层模型的基本结构
论文阅读·多模态