从MMoE到PLE:读懂多任务学习架构的渐进式演化

从MMoE到PLE:读懂多任务学习架构的渐进式演化

引言

在多任务学习(MTL)领域,MMoE(Multi-gate Mixture-of-Experts)无疑是一个里程碑式的模型,它通过巧妙的软参数共享机制,极大地提升了工业界推荐、广告等系统的多目标优化能力。然而,在面对任务间关系愈发复杂、甚至相互冲突的场景时,即便是 MMoE 也可能遭遇性能瓶颈,出现顾此失彼的"跷跷板现象"(Seesaw Phenomenon)。

为了攻克这一难题,研究人员在 MMoE 的思想基础上,提出了其演进版本------PLE(Progressive Layered Extraction)。PLE 通过一种更为精细和强大的网络结构,旨在从根本上解决任务冲突带来的负面影响。本文将深入探讨 PLE 模型如何通过结构创新来突破 MMoE 的局限,并对其核心原理、梯度更新机制及应用价值进行全面分析。

一、 MMoE 的局限性与"跷跷板现象"

1.1 MMoE 核心机制回顾

要理解 PLE 的创新,我们必须先回顾 MMoE 的工作机制。MMoE 的核心在于"共享专家网络"和"任务独立门控"。它让所有任务共享一组专家网络(Experts),同时为每个任务配备一个独立的门控(Gate),由门控来决定如何对这些共享专家的输出进行加权组合。

1.2 "跷跷板现象"的根源

"跷跷板现象"指的是在多任务学习中,提升一个任务性能的同时,却导致了其他任务性能下降的现象。

MMoE 之所以仍会面临这一问题,其根本原因在于它的所有专家网络本质上仍是全局共享的。当不同任务的目标存在内在冲突时(例如,在视频推荐中,"提升点击率"可能偏好标题党,而"提升完播率"则偏好高质量长视频),这些相互矛盾的优化目标会产生方向可能相反的梯度,并同时作用于所有共享专家。这使得专家网络的参数学习陷入"左右为难"的境地,难以找到一个能让所有任务都满意的优化方向。

二、 PLE 核心原理与架构:显式分离与渐进抽取

PLE 的设计正是为了解决上述问题,其解决方案可以概括为两大核心创新:显式专家分离和渐进式信息抽取。

2.1 核心思想概述

PLE 通过显式地区分共享专家与任务专属专家 ,并采用逐层递进的信息抽取网络,实现了任务间共性知识与个性特征的高效解耦与深度融合。

2.2 模型架构拆解

PLE 的架构在 MMoE 的基础上进行了精心的扩展,主要体现在以下几个方面:

  • 显式专家分离 (Explicit Expert Separation)

    这是 PLE 与 MMoE 最根本的区别。在 PLE 的每一个抽取层(Extraction Layer)中,专家网络被明确地划分为两类:

    • 共享专家 (Shared Experts):由所有任务共同使用,负责学习具有普适性的通用模式。
    • 任务专属专家 (Task-Specific Experts):每个任务都拥有一组自己独享的专家,负责学习该任务独特的、可能与其他任务不兼容的模式。
  • 渐进式分层抽取 (Progressive Layered Extraction)

    PLE 通常包含多层抽取网络。其"渐进式"体现在,更高层网络的输入,来自于其下一层所有专家(包括所有任务的专属专家和共享专家)的输出。这种设计使得模型能够逐层学习到更加抽象和高阶的特征表示,实现了信息的深度提炼。

  • 选择性路由机制 (Selective Routing)

    与 MMoE 类似,PLE 同样使用门控网络进行信息融合。但其关键区别在于,每个任务的输入,是由其自身的专属专家共享专家的信息通过门控网络加权组合而成。它不会直接使用其他任务的专属专家,从而在结构上避免了任务间的直接干扰。

三、 PLE 的参数更新机制:更精细的梯度路由

3.1 损失函数

与 MMoE 类似,PLE 的总损失函数通常是各个任务损失的加权和。

\[L_{total} = \sum_{k=1}^{K} w_k L_k(y_k^{true}, y_k^{pred}) \]

3.2 梯度流向分析

PLE 更精细的结构带来了更复杂的梯度路由机制,这也是其成功的关键:

  • 任务专属模块 :任务 A 的损失 L_A 在反向传播时,其梯度主要更新其专属的塔(Tower)、各层门控(Gate)和各层专属专家(Specific Experts)。这些模块的参数学习受到了很好的保护,不会被其他任务的冲突梯度直接影响。
  • 共享模块 :共享专家(Shared Experts)会接收来自所有任务的梯度,这一点与 MMoE 类似。
  • "渐进式"体现在梯度上 :这是 PLE 最巧妙的地方。在顶层网络,任务间的专属信息是高度分离的。但越往底层,梯度传播的路径就越交融。例如,任务 B 的损失 L_B,其梯度可以通过高层的共享专家 反向传播,并最终影响到低层中任务 A 的专属模块。这种机制允许模型在底层学习被所有任务共同塑造的、更鲁棒的通用知识,同时在顶层逐渐"提纯"出每个任务的独特决策逻辑。

四、 优缺点与适用场景

4.1 主要优点

  • 有效缓解"跷跷板现象":通过为冲突梯度提供"专属通道"(任务专属专家),PLE 从网络结构上极大地缓解了任务间的负迁移和性能冲突。
  • 更强的模型表达能力:多层结构和渐进式信息抽取机制,使其能够捕捉到任务间更为复杂和高阶的非线性关系。
  • 通用性强:无论任务是复杂相关、普通相关还是弱相关,PLE 都能通过其灵活的结构取得稳定的性能提升。

4.2 潜在局限性

  • 模型复杂度高:相比 MMoE,PLE 的网络结构更复杂,参数量和计算量也相应更大,对训练和推理资源的要求更高。
  • 超参数调优更复杂:引入了网络层数、每层共享/专属专家数量等更多需要细致调优的超参数,增加了模型落地的难度。

4.3 适用场景

  • 任务间存在已知冲突或竞争关系的场景:这是 PLE 最理想的应用场景。例如,在视频推荐中,需要同时优化"播放时长"(偏好长视频)和"点击率"(可能偏好短平快的封面标题),PLE 能很好地平衡这类竞争性目标。
  • 追求极致性能的大规模精排模型:在计算资源允许的情况下,PLE 是替代 MMoE 以获得更好、更稳定的多目标优化效果的理想选择。

五、 代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class PLE(nn.Module):
    def __init__(self, input_dim, num_tasks, num_levels, num_shared_experts, num_specific_experts, 
                 expert_hidden_dims, expert_output_dim, tower_hidden_dims):
        super(PLE, self).__init__()

        self.input_dim = input_dim
        self.num_tasks = num_tasks
        self.num_levels = num_levels
        self.num_shared_experts = num_shared_experts
        self.num_specific_experts = num_specific_experts

        # --- 模块一: 专家网络 (Experts) ---
        # 专属专家
        self.specific_experts = nn.ModuleList([
            nn.ModuleList([
                nn.ModuleList([self._build_mlp(input_dim if level == 0 else expert_output_dim, 
                                               expert_hidden_dims) 
                               for _ in range(num_specific_experts)])
                for _ in range(num_tasks)
            ]) for level in range(num_levels)
        ])
        
        # 共享专家 
        self.shared_experts = nn.ModuleList([
            nn.ModuleList([
                self._build_mlp(
                    input_dim * num_tasks if level == 0 else expert_output_dim * num_tasks, 
                    expert_hidden_dims, 
                ) for _ in range(num_shared_experts)])
            for level in range(num_levels)
        ])

        # --- 模块二: 门控网络 (Gates)
        self.gates = nn.ModuleList([
            nn.ModuleList([
                nn.Linear(
                    input_dim * num_tasks if level == 0 else expert_output_dim * num_tasks, 
                    num_specific_experts + num_shared_experts
                ) for _ in range(num_tasks)
            ]) for level in range(num_levels)
        ])

        # --- 模块三: 任务塔 (Towers) ---
        self.towers = nn.ModuleList([
            self._build_mlp(expert_output_dim, tower_hidden_dims) for _ in range(num_tasks)
        ])

    def _build_mlp(self, input_dim, hidden_dims):
        layers = []
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        return nn.Sequential(*layers)

    def forward(self, x):
        level_specific_inputs = [x] * self.num_tasks
        
        for level in range(self.num_levels):
            # 专属专家输出
            specific_expert_outputs = [
                [expert(level_specific_inputs[i]) for expert in self.specific_experts[level][i]]
                for i in range(self.num_tasks)
            ]
            
            # 共享专家输入
            shared_expert_input = torch.cat(level_specific_inputs, dim=1)
            shared_expert_outputs = [expert(shared_expert_input) for expert in self.shared_experts[level]]
            
            # 门控计算与加权求和
            next_level_specific_inputs = []
            for i in range(self.num_tasks):
                current_experts_outputs = specific_expert_outputs[i] + shared_expert_outputs
                current_experts_stacked = torch.stack(current_experts_outputs, dim=1)
                
                # 门控网络的输入与共享专家一致
                gate_input = shared_expert_input 
                gate_weights = F.softmax(self.gates[level][i](gate_input), dim=1)
                
                weighted_sum = torch.sum(current_experts_stacked * gate_weights.unsqueeze(-1), dim=1)
                next_level_specific_inputs.append(weighted_sum)
            
            level_specific_inputs = next_level_specific_inputs

        # 任务塔计算
        task_outputs = [self.towers[i](level_specific_inputs[i]) for i in range(self.num_tasks)]
        
        return task_outputs

总结

PLE 模型可以被视为 MMoE 的一个强大且成熟的演进版本。它通过"显式分离专家"和"渐进式抽取"两大核心创新,为解决复杂多任务学习问题中的"跷跷板现象"提供了强有力的架构支持。尽管模型复杂度有所提升,但其在处理相互竞争的业务目标时所展现出的卓越性能和稳定性,使其成为当今工业界推荐系统多目标优化领域一个不可或缺的重要工具。

相关推荐
GRITJW2 天前
ESMM学习笔记:如何解决CVR预估中的样本选择偏差与数据稀疏难题
推荐算法
GRITJW3 天前
深度剖析RQ-VAE:从向量量化到生成式推荐的语义ID技术
推荐算法
IT学长编程4 天前
计算机毕业设计 基于Hadoop的健康饮食推荐系统的设计与实现 Java 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试】
java·大数据·hadoop·毕业设计·课程设计·推荐算法·毕业论文
科兴第一吴彦祖5 天前
在线会议系统是一个基于Vue3 + Spring Boot的现代化在线会议管理平台,集成了视频会议、实时聊天、AI智能助手等多项先进技术。
java·vue.js·人工智能·spring boot·推荐算法
GRITJW5 天前
推荐系统中负采样策略及采样偏差的校正方法
推荐算法
lifallen6 天前
淘宝RecGPT:通过LLM增强推荐
人工智能·深度学习·ai·推荐算法
麦麦大数据6 天前
J002 Vue+SpringBoot电影推荐可视化系统|双协同过滤推荐算法评论情感分析spark数据分析|配套文档1.34万字
vue.js·spring boot·数据分析·spark·可视化·推荐算法
一只鱼^_11 天前
牛客周赛 Round 108
数据结构·c++·算法·动态规划·图论·广度优先·推荐算法
moonsheeper15 天前
推荐算法发展历史
算法·机器学习·推荐算法