ESMM学习笔记:如何解决CVR预估中的样本选择偏差与数据稀疏难题

ESMM模型精解:如何解决CVR预估中的样本选择偏差与数据稀疏难题

引言

在现代推荐系统与计算广告中,对点击后转化率(Post-Click Conversion Rate, CVR)的精准预估是优化平台收益与用户体验的核心环节。然而,传统的 CVR 预估模型在工业实践中普遍面临两大技术瓶颈:样本选择偏差(Sample Selection Bias, SSB)与数据稀疏性(Data Sparsity, DS)。这两个问题严重制约了模型的泛化能力与学习效率。

为了应对这些挑战,阿里巴巴的研究团队提出了"全体空间多任务模型"(Entire Space Multi-Task Model, ESMM),通过一种巧妙的建模思路,从根本上解决了上述难题。本文将从模型背景、核心原理、梯度更新机制、优缺点及适用场景等角度,对 ESMM 模型进行系统性剖析。

一、 CVR 预估的核心挑战:问题的根源

在深入了解 ESMM 之前,我们必须先理解它所要解决的问题根源。

  • 1.1 样本选择偏差 (SSB)

    SSB 问题源于模型训练与预测阶段的数据分布不一致 。具体来说,CVR 模型的目标是预估用户在"点击"行为发生后产生"转化"的概率。因此,模型的训练样本天然地被限定在已发生点击行为的样本空间 内。然而,在实际的线上推理阶段,模型需要对所有曝光的物品进行 CVR 预估,以便与点击率(CTR)等指标结合进行排序。

    如下图所示,训练空间(Click Space)仅仅是推理空间(Impression Space)的一个很小的子集 。在这种有偏的样本上训练出的模型,直接应用于全体样本时,其泛化能力会受到严重损害。

  • 1.2 数据稀疏性 (DS)

    数据稀疏性是另一个严峻的挑战。在真实的业务场景中,用户从曝光到点击、再到转化的行为是一个发生率逐级递减的"漏斗"。用于 CVR 任务训练的"点击且转化"的样本,数量远远少于 CTR 任务的"点击"样本,更不用说海量的"曝光"样本了。例如,在淘宝的生产数据集中,用于 CVR 任务的样本量仅为 CTR 任务的 4% 左右。这种极端的正样本稀疏性,使得模型,尤其是需要海量数据进行学习的深度网络,难以充分学习到有效的特征表示,拟合过程非常困难。

二、 ESMM 模型核心原理:巧妙的全局建模

面对上述挑战,ESMM 并没有在模型结构上进行复杂的改造,而是从问题的定义本身出发,提出了一种全新的建模范式。

  • 2.1 核心思想概述

    ESMM 的核心思想是:不再直接建模有偏的 pCVR,而是通过利用用户行为的"曝光 → 点击 → 转化"这一序列依赖关系,在完整的、无偏的曝光样本空间上进行多任务联合建模

  • 2.2 关键公式与模型架构

    • 公式拆解

      ESMM 框架的理论基石是以下概率链式法则公式:

      \[p(y=1, z=1 | x) = p(y=1 | x) \times p(z=1 | y=1, x) \]

      其中,x 代表曝光的特征,y=1 代表点击事件,z=1 代表转化事件。这个公式可以被解读为:

      \[pCTCVR = pCTR \times pCVR \]

      • pCTR (Post-View Click-Through Rate): 曝光后点击率,即 p(y=1|x)
      • pCVR (Post-Click Conversion Rate): 点击后转化率,即 p(z=1|y=1,x)。这是我们最终想要求解的目标。
      • pCTCVR (Post-View Click-Through & Conversion Rate): 曝光后点击且转化率,即 p(y=1,z=1|x)
    • 架构解析

      基于上述公式,ESMM 设计了一个由两个子网络(任务塔)构成的多任务学习架构。

      • CTR Tower : 一个独立的子网络,用于根据输入特征 x 预测 pCTR
      • CVR Tower : 另一个独立的子网络,用于根据相同的输入特征 x 预测 pCVR
      • 共享 Embedding 层: 两个任务塔共享底层的特征嵌入层(Embedding Layer)。这是缓解数据稀疏性问题的关键。
      • *最终输出**: CTR 塔和 CVR 塔的输出 pCTRpCVR 相乘,得到 pCTCVR 的预测值。
  • 2.3 手段总结

    ESMM 解决两大难题的手段可以清晰地归纳为:

    • 针对 SSB :模型不直接使用有偏的点击样本来监督 CVR 塔。而是通过监督在全体曝光样本 上都具有明确标签的 pCTRpCTCVR 两个任务,让 CVR 塔作为中间变量被隐式地学习。由于 pCTRpCTCVR 都是在无偏的全空间上建模的,因此间接得到的 pCVR 也自然地适用于全空间,从根本上解决了样本选择偏差问题。
    • 针对 DS :通过让 CVR 任务塔与拥有海量训练样本的 CTR 任务塔共享底层的 Embedding 表示,实现了参数的迁移学习。CVR 任务可以从 CTR 任务中"借力",学习到更鲁棒的特征表示,从而有效缓解了自身训练数据稀疏的问题。

三、 ESMM 的参数更新机制:梯度的流动

理解 ESMM 的参数更新机制,有助于我们深入其联合训练的本质。

  • 3.1 联合损失函数

    ESMM 的总损失函数由 CTR 任务和 CTCVR 任务的损失加权构成,不包含 CVR 任务的直接损失项。

    \[L(\theta_{ctr}, \theta_{cvr}) = \sum_{i=1}^{N} L_{ctr}(y_i, f(x_i; \theta_{ctr})) + \sum_{i=1}^{N} L_{ctcvr}(y_i\&z_i, f(x_i; \theta_{ctr}) \times f(x_i; \theta_{cvr})) \]

    其中 θ_ctrθ_cvr 分别是 CTR 和 CVR 网络的参数。

  • 3.2 梯度流向分析

    在反向传播过程中,梯度从两个损失源头出发:

    • 来自 L_ctr 的梯度路径 :这部分梯度只与 pCTR 相关。因此,它会反向传播更新 CTR 塔 的参数,并继续向下更新 共享 Embedding 层 的参数。
    • 来自 L_ctcvr 的梯度路径 :这部分梯度是联合训练的关键。根据链式求导法则,对于乘法操作 pCTCVR = pCTR * pCVR,梯度会"兵分两路":
      • 一路流向 pCVR,更新 CVR 塔 的参数,并继续向下更新 共享 Embedding 层
      • 另一路流向 pCTR,更新 CTR 塔 的参数,并继续向下更新 共享 Embedding 层
  • 3.3 参数更新归属总结

    综合两条路径,各个模块参数的最终梯度来源如下:

    • CVR 塔参数仅由 L_ctcvr 更新。
    • CTR 塔参数 :由 L_ctrL_ctcvr 共同更新。
    • 共享 Embedding 参数 :由 L_ctrL_ctcvr 共同更新。

四、 优缺点及适用场景分析

  • 4.1 主要优点

    • 根本解决 SSB:通过在全空间上建模,从理论上彻底消除了传统 CVR 预估的样本选择偏差。
    • 有效缓解 DS:利用迁移学习的思想,共享 Embedding,极大地缓解了 CVR 任务数据稀疏的问题。
    • 结构优雅 :模型设计巧妙,乘法形式不仅避免了除法可能带来的数值不稳定问题,还能保证 pCVR 的值在 [0,1] 范围内。
  • 4.2 潜在局限性

    • 强依赖于行为序列:模型的设计强依赖于"曝光 → 点击 → 转化"这样清晰、固定的序列依赖关系。如果任务之间是平行的,或者不遵循这种漏斗模式,ESMM 则不适用。
    • 任务关系固定 :模型假设了 pCTRpCVR 之间是简单的乘法关系,这可能无法捕捉现实世界中更复杂的任务关联。
  • 4.3 适用场景

    ESMM 极其适用于具有明显序列依赖关系的多阶段用户行为预测场景。

    • 电商平台:经典的"曝光 → 点击 → 购买"转化漏斗。
    • 信息流推荐:"曝光 → 点击 → 关注/收藏/分享"等多步转化行为。
    • 在线广告:"广告展示 → 点击 → 表单提交/App下载"等转化链路。

代码实现

python 复制代码
import torch
import torch.nn as nn

class ESMM(nn.Module):
    """
    ESMM: Entire Space Multi-Task Model PyTorch Implementation.

    This class implements the ESMM model as described in the paper:
    "Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate"

    核心思想 Core Ideas:
    1.  **全空间建模 (Entire Space Modeling)**: 通过 `pCTCVR = pCTR * pCVR` 的关系,在全体曝光样本上联合训练CTR和CTCVR任务,从而无偏地学习CVR
    2.  **特征表示共享 (Shared Feature Representation)**: CVR任务和CTR任务共享底层的Embedding层,有效缓解CVR任务因样本稀疏导致的学习不充分问题

    Args:
        feature_columns (dict): 一个描述输入稀疏特征的字典,key为特征名,value为该特征的词表大小 (vocabulary size)。
                                Example: {'user_id': 10000, 'item_id': 5000}
        embedding_dim (int): Embedding向量的维度。
        tower_dims (list): 一个定义两个任务塔 (Tower) 隐藏层维度和结构的列表。
                           Example: [256, 128]
    """
    def __init__(self, feature_columns, embedding_dim, tower_dims):
        super(ESMM, self).__init__()
        self.feature_columns = feature_columns
        self.embedding_dim = embedding_dim

        # --- 模块一: 共享的 Embedding 层 (Shared Embedding Layer) ---
        # 这是ESMM解决数据稀疏性(Data Sparsity)问题的关键。
        # CVR和CTR任务共享同一套Embedding参数,使得CVR任务可以从更丰富的CTR样本中学习特征表示
        self.embedding_layer = nn.ModuleDict({
            feat: nn.Embedding(vocab_size, self.embedding_dim)
            for feat, vocab_size in self.feature_columns.items()
        })

        # 计算两个任务塔(MLP)的输入维度
        # 输入维度 = 特征数量 * 每个特征的Embedding维度
        self.tower_input_dim = len(self.feature_columns) * self.embedding_dim

        # --- 模块二: CTR 任务塔 (CTR Tower) ---
        # 这个子网络负责预测点击率 pCTR。
        self.ctr_tower = self._build_mlp_tower(self.tower_input_dim, tower_dims)

        # --- 模块三: CVR 任务塔 (CVR Tower) ---
        # 这个子网络负责预测点击后转化率 pCVR。
        # 注意它的输入与CTR Tower完全相同,都是来自共享的Embedding层。
        self.cvr_tower = self._build_mlp_tower(self.tower_input_dim, tower_dims)

    def _build_mlp_tower(self, input_dim, tower_dims):
        """一个辅助函数,用于构建MLP网络(即任务塔)。"""
        layers = []
        for hidden_dim in tower_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        # 输出层,得到一个logit
        layers.append(nn.Linear(input_dim, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        ESMM的前向传播逻辑。

        Args:
            x (dict): 输入的特征字典,key为特征名,value为特征值的tensor。
                      Example: {'user_id': tensor([[1], [2]]), 'item_id': tensor([[10], [12]])}

        Returns:
            tuple: 包含三个预测值的元组 (pCTR, pCVR, pCTCVR)。
                   - pCTR: 预测的点击率 (Predicted CTR)
                   - pCVR: 预测的点击后转化率 (Predicted CVR)
                   - pCTCVR: 预测的点击且转化率 (Predicted CTCVR)
        """
        # --- 流程 1: 特征 Embedding (共享) ---
        # 将输入的稀疏特征ID通过共享的Embedding层转换为稠密向量
        embedded_features = [
            self.embedding_layer[feat](x[feat]) for feat in self.feature_columns
        ]
        # 将所有特征的Embedding向量拼接成一个长向量,作为两个任务塔的共同输入
        concatenated_embeddings = torch.flatten(torch.cat(embedded_features, dim=1), start_dim=1)

        # --- 流程 2: CTR 任务预测 ---
        # 将拼接后的向量输入CTR塔,得到CTR的logit
        ctr_logit = self.ctr_tower(concatenated_embeddings)
        # 通过Sigmoid激活函数得到预测的点击率 pCTR
        pCTR = torch.sigmoid(ctr_logit)

        # --- 流程 3: CVR 任务预测 ---
        # 将【相同】的拼接向量输入CVR塔,得到CVR的logit
        cvr_logit = self.cvr_tower(concatenated_embeddings)
        # 通过Sigmoid激活函数得到预测的点击后转化率 pCVR
        pCVR = torch.sigmoid(cvr_logit)

        # --- 流程 4: 计算 pCTCVR ---
        # 这是ESMM解决样本选择偏差(SSB)问题的核心步骤。
        # 通过pCTR和pCVR的乘积,将CVR的预测从"点击空间"隐式地转换到"全曝光空间" 
        pCTCVR = pCTR * pCVR

        return pCTR, pCVR, pCTCVR

总结

ESMM 模型为 CVR 预估领域提供了一个里程碑式的解决方案。它最大的贡献在于,没有将思路局限在设计更复杂的网络结构上,而是通过重构学习目标和建模空间,从根本上规避了 SSB 和 DS 这两个长期存在的业界难题。其优雅的设计、显著的效果以及在淘宝等大规模工业场景中的成功落地,使其成为推荐和广告领域从业者必学的经典模型之一。

相关推荐
GRITJW1 天前
深度剖析RQ-VAE:从向量量化到生成式推荐的语义ID技术
推荐算法
IT学长编程2 天前
计算机毕业设计 基于Hadoop的健康饮食推荐系统的设计与实现 Java 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试】
java·大数据·hadoop·毕业设计·课程设计·推荐算法·毕业论文
科兴第一吴彦祖3 天前
在线会议系统是一个基于Vue3 + Spring Boot的现代化在线会议管理平台,集成了视频会议、实时聊天、AI智能助手等多项先进技术。
java·vue.js·人工智能·spring boot·推荐算法
GRITJW3 天前
推荐系统中负采样策略及采样偏差的校正方法
推荐算法
lifallen4 天前
淘宝RecGPT:通过LLM增强推荐
人工智能·深度学习·ai·推荐算法
麦麦大数据4 天前
J002 Vue+SpringBoot电影推荐可视化系统|双协同过滤推荐算法评论情感分析spark数据分析|配套文档1.34万字
vue.js·spring boot·数据分析·spark·可视化·推荐算法
一只鱼^_9 天前
牛客周赛 Round 108
数据结构·c++·算法·动态规划·图论·广度优先·推荐算法
moonsheeper14 天前
推荐算法发展历史
算法·机器学习·推荐算法
乐迪信息14 天前
乐迪信息:智慧煤矿视觉检测平台:从皮带、人员到矿车
大数据·人工智能·算法·安全·视觉检测·推荐算法