Nash-MTL:在多任务梯度组合中引入纳什谈判解

Navon, Aviv, Aviv Shamsian, Idan Achituve, Haggai Maron, Kenji Kawaguchi, Gal Chechik, and Ethan Fetaya. Multi-task learning as a bargaining game. ICML 2022. arXiv preprint arXiv:2202.01017 (2022).

多任务学习的目标诱人:一个模型,同时搞定多个任务,节省计算成本并提升数据效率。然而,在现实中,同时训练多个任务,效果常常还不如为每个任务单独训练一个模型。

其核心矛盾在于:不同任务的梯度(指导模型更新的方向)经常"打架"。有的梯度幅值大,有的方向完全相反。简单地将梯度加起来更新(线性标量化),模型就会被大梯度或某个特定任务"带偏",导致其他任务学不好。这就像一个团队里声音最大的人总是主导决策,最终结果未必对整体最优。

现有的改进方法,如 PCGrad(梯度手术)、CAGrad(冲突规避梯度下降)等,可以被视为一系列精巧的"启发式"策略:它们通过投影、加权等方式来调和梯度冲突,也确实提升了性能。但它们大多基于直觉设计,回答的是"如何"调和冲突,却未从根本上回答:什么是一个对所有任务都"好"的更新方向。

发表于 ICML 2022 的论文《Multi-Task Learning as a Bargaining Game》提出一个优雅的框架,将梯度组合问题重塑为一个合作谈判游戏,并借用博弈论中成熟的纳什谈判解,为"公平有效的更新"提供了一个公理化的判断机制。由此,这篇论文提出了 Nash-MTL 算法,不仅在多个基准上达到最优,并且为我们理解多任务优化提供了一个新颖的理论视角。


文章目录


一、多任务学习中的梯度冲突

在多任务学习(MTL)中,假设我们有 K 个任务,对应 K 个损失函数 ℓ 1 , ⋯   , ℓ K \ell_1, \cdots , \ell_K ℓ1,⋯,ℓK。在每个训练步骤,我们会计算每个任务相对于共享参数的梯度 g 1 , ⋯   , g K g_1, \cdots , g_K g1,⋯,gK。关键的一步是:如何将这些梯度组合成一个单一的更新方向 Δθ

传统线性标量化(LS)直接求和( Δ θ = ∑ g i \Delta \theta = \sum g_i Δθ=∑gi),这相当于默认所有任务"投票权"相同。但在梯度幅值差异巨大时(例如任务 A 的损失值在 1000 量级,任务 B 在 0.01 量级),更新完全由大梯度任务主宰。

先前的研究通过设计各种算法来修改梯度(如投影冲突部分、动态调整权重)来缓解此问题。然而,这些方法缺乏一个更高层次的原则性定义,来说明什么样的更新方向才是真正"公平"且"有效"的。

二、核心思想:梯度组合是一场纳什谈判

论文的核心思想是将梯度组合步骤建模为一个合作谈判博弈

  • 玩家:每个任务就是一个玩家。
  • 谈判协议(A):大家共同选择一个参数更新向量 Δθ(在一个限定范围球内)。
  • 破裂点(D):如果谈判破裂,大家不更新,即 Δθ = 0,所有任务维持原状。
  • 效用( u i u_i ui) :每个玩家从协议 Δθ 中获得的"收益",定义为自身梯度在该方向上的投影,即 u i ( Δ θ ) = g i ⊤ Δ θ u_i(\Delta \theta) = g_i^\top \Delta \theta ui(Δθ)=gi⊤Δθ。这个值越大,意味着沿该方向更新,该任务的损失下降越快。

现在,问题转化为:如何找到一个让所有任务(玩家)都愿意接受的 Δθ?

纳什在 1953 年证明,如果一个谈判问题的解满足以下四条看似自然而基本的公理,那么这个解存在且唯一 ,这就是纳什谈判解

三、理解纳什四大公理:以"分蛋糕"为例

让我们用一个经典的"两人分蛋糕"例子来直观理解这些公理,并对应到 MTL 场景。

  1. 帕累托最优:分完蛋糕后,不应该还剩一些在盘子里。因为剩下的部分可以继续分,让至少一人更多而不损害他人。在 MTL 中,这意味着我们的更新方向应该是"高效"的,不应存在另一个方向能显著提升某个任务而不损害其他任务。

  2. 对称性:如果两个分蛋糕者身份、贡献完全一样,那么他们应分得一样多。在MTL中,如果两个任务完全等同(梯度性质相同),那它们在更新方向中的权重应该相等。

  3. 无关选项独立性(IIA):如果我们已经同意按"你 7 我 3"分蛋糕,此时突然有人提议"也可以按 8:2 分",只要我们原先的 7:3 方案仍然可行,我们就应坚持原方案。 IIA 保证了方案的稳定性。在 MTL 中,它意味着求解的更新方向不会因为考虑了一些无关的、次优的更新方向而改变。

  4. 仿射不变性(最关键):衡量蛋糕的单位(克、盎司、块数)不应影响最终分配的比例。在 MTL 中,这意味着对任意任务的损失函数进行线性缩放(乘以正数)或平移(加常数),所得的更新方向 Δθ 应保持不变。这直接免疫了不同任务间因损失量纲差异巨大,而导致的优化偏差。

纳什谈判解,就是最大化所有玩家"收益增量"乘积 的解。在分蛋糕中,是最大化 (你的份额 - 你的破裂点) * (我的份额 - 我的破裂点)。在我们的 MTL 谈判中,破裂点即为不进行梯度更新,效用为 0,因此就是:

最大化 ∑ log ⁡ ( g i ⊤ Δ θ ) \sum \log(g_i^\top \Delta \theta) ∑log(gi⊤Δθ)

这个目标函数不是人为设计的,而是满足上述四大公平公理的必然数学结果。这为 Nash-MTL 方法提供了优雅的原则性动机。

四、Nash-MTL 方法

将上述理论应用到 MTL,就得到了 Nash-MTL 算法。其目标是找到更新方向 Δ θ ∗ \Delta \theta^* Δθ∗,它由各任务梯度的加权和组成: Δ θ ∗ = ∑ α i g i \Delta \theta^* = \sum \alpha_i g_i Δθ∗=∑αigi,其中权重 α i > 0 \alpha_i > 0 αi>0。

核心推导与解

通过求解纳什谈判解,论文推导出最优权重 α 必须满足一个简洁的方程:

G ⊤ G α = 1 / α G^\top G \alpha = 1 / \alpha G⊤Gα=1/α

(其中 G 是以梯度 g i g_i gi 为列的矩阵,1/α 表示对 α 各元素取倒数)

这个方程有清晰的几何直观:

  • 当所有梯度相互正交时,解是 α i = 1 / ∥ g i ∥ α_i = 1 / \|g_i\| αi=1/∥gi∥,即 Δ θ ∗ = ∑ ( g i / ∥ g i ∥ ) \Delta \theta^* = \sum (g_i / \|g_i\|) Δθ∗=∑(gi/∥gi∥),相当于将每个梯度归一化后相加,这是一个直观的尺度不变解。
  • 当梯度不正交时,方程 (2) α i ∥ g i ∥ 2 + ∑ j ≠ i α j g j ⊤ g i = 1 / α i α_i \|g_i\|^2 + \sum_{j≠i} α_j g_j^\top g_i = 1 / α_i αi∥gi∥2+∑j=iαjgj⊤gi=1/αi 表明,权重 α i α_i αi 会动态调整:如果任务 i 的梯度与其他任务梯度冲突(内积为负),则 α i α_i αi 会自动增大以作补偿;如果协同(内积为正),则 α i α_i αi 会减小。这实现了基于梯度间交互的自动、内生权衡。

高效求解算法

直接求解 G ⊤ G α = 1 / α G^\top G \alpha = 1 / \alpha G⊤Gα=1/α 是一个非凸问题。论文设计了一个高效的凹凸过程(CCP) 迭代算法来逼近解。实践中,作者发现即使仅进行 1 次 CCP 迭代,算法性能也接近最优,这大大降低了计算开销。完整的 Nash-MTL 单步流程如下:

  1. 前向传播,计算所有任务损失。
  2. 反向传播,计算每个任务的梯度 g i g_i gi。
  3. (谈判环节)将梯度矩阵 G 输入求解器,快速解出满足 G ⊤ G α ≈ 1 / α G^\top G \alpha \approx 1 / \alpha G⊤Gα≈1/α 的权重 α。
  4. 组合更新方向: Δ θ = ∑ α i g i \Delta \theta = \sum α_i g_i Δθ=∑αigi。
  5. 使用 Δθ 更新模型参数。

应对计算开销的巧思

与许多先进 MTL 方法一样,Nash-MTL 每步都需要所有任务的梯度,这在任务数 K 很大时计算负担较重。论文探索了一个简单有效的加速策略:不每步都更新权重 α,而是每 T 步(例如 T=10, 50, 100)更新一次,中间步骤复用旧的 α。实验表明,这能在性能下降极小的情况下,带来数倍至近十倍的训练加速,使其实用性大大增强。

五、实验验证:全面领先的性能

论文在三个截然不同的领域进行了全面测试,对比了包括 LS、MGDA、PCGrad、CAGrad、GradDrop 等 12 种主流方法。

任务 1:量子化学回归(QM9)

  • 任务:预测 13 万个小分子的 11 种物理化学性质(回归任务)。
  • 挑战:各性质数值尺度差异极大。
  • 结果:Nash-MTL 在综合指标 Δm%(相对单任务模型的平均性能变化)和 平均排名(MR) 上均位列第一。许多方法甚至不如简单的尺度不变(SI)基线,这突显了仿射不变性在解决尺度差异问题上的根本性优势。

任务 2:计算机视觉(NYUv2 & Cityscapes)

  • 任务:在 NYUv2 上同时进行语义分割、深度估计、表面法线预测;在 Cityscapes 上进行语义分割和深度估计。
  • 结果
    • NYUv2:Nash-MTL取得最优的 MR 和 Δm%(-4.04%),是唯一一个整体性能显著超越单任务模型的 MTL 方法。
    • Cityscapes:Nash-MTL 取得最优的 MR,Δm% 排名第二。论文指出,对于两任务情况,Nash-MTL 的解简化为"归一化梯度后等权相加",这种简单形式依然表现出色。

任务 3:多任务强化学习(Meta-World MT10)

  • 任务:让一个机器人学会10种不同的操作技能。
  • 结果:Nash-MTL 取得了最高的平均成功率(0.91),并且是所有 MTL 方法中唯一达到与单任务学习基线(STL SAC)同等性能的方法,证明了其在序列决策问题中的强大协调能力。

综合来看,Nash-MTL 不仅在传统监督学习领域表现出色,在复杂的 RL 领域同样证明了其优越性。其尺度不变性和基于梯度交互的动态权重机制,使其在面对差异巨大的多任务时,能找到一个更平衡、更公平的优化路径。

六、实现细节与复现要点

  • 模型架构:
    • QM9:使用标准的消息传递图神经网络(GNN)。
    • 视觉任务:使用基于 SegNet 的多任务注意力网络(MTAN)。
    • MT10:使用 Soft Actor-Critic(SAC) 作为基础RL算法。
  • 训练设置:
    • 优化器:Adam。
    • 学习率:视觉任务常设为 1e-4;QM9需网格搜索(1e-3, 5e-4, 1e-4)。
    • 关键超参数:CAGrad 的冲突控制参数 c=0.4;DWA的温度参数 T=2
    • 数据增强:视觉任务均使用数据增强。
    • 评估:报告最后多个 epoch 的平均测试性能,RL任务则报告整个训练周期中的最佳平均成功率。
  • Nash-MTL 特定设置:
    • CCP 迭代步数:默认 20,但 1步也足够,对性能影响很小。
    • 权重更新频率 T:可根据任务数和计算资源调整,T 越大速度越快,性能略有折衷。

七、总结与启示

Nash-MTL 的杰出之处,在于它为多任务学习的梯度调和问题引入了一个优雅的理论框架。它不再局限于设计更复杂的启发式梯度操作,而是从博弈论的经典成果中汲取智慧,将优化方向的选择定义为一个寻求公平合作解的过程。

该方法也存在局限性。例如,其假设任务梯度在非帕累托平稳点处线性独立,这在某些极端情况下可能不成立;每步计算所有梯度的要求在大规模任务场景下仍是挑战,尽管可通过稀疏更新缓解。

相关推荐
CServer_0117 小时前
汽车零部件生产:从“管理软件”到“数据驱动”的智能中枢
人工智能·汽车
说私域17 小时前
小程序电商运营中“开源AI智能名片链动2+1模式S2B2C商城小程序”对培养“老铁”用户的重要性研究
人工智能·小程序·开源
小烤箱17 小时前
Autoware Universe 感知模块详解 | 第十节:工程角度的自动驾驶检测管线方法论
人工智能·机器学习·自动驾驶·autoware·感知算法
玖日大大17 小时前
Wan2.1视频生成模型本地部署完整指南
人工智能·音视频
葡萄城技术团队17 小时前
生成式人工智能(AI):智能技术,能够创造而不仅仅是计算
人工智能
摸鱼仙人~17 小时前
BERT分类的上下文限制及解决方案
人工智能·分类·bert
神一样的老师17 小时前
微型机器学习(TinyML):研究趋势与未来应用机遇
人工智能·机器学习
木头程序员17 小时前
机器学习概述:核心范式、关键技术与应用展望
大数据·人工智能·机器学习·回归·聚类
悟道心17 小时前
5. 自然语言处理NLP - Transformer
人工智能·自然语言处理·transformer