ESMM模型精解:如何解决CVR预估中的样本选择偏差与数据稀疏难题
引言
在现代推荐系统与计算广告中,对点击后转化率(Post-Click Conversion Rate, CVR)的精准预估是优化平台收益与用户体验的核心环节。然而,传统的 CVR 预估模型在工业实践中普遍面临两大技术瓶颈:样本选择偏差(Sample Selection Bias, SSB)与数据稀疏性(Data Sparsity, DS)。这两个问题严重制约了模型的泛化能力与学习效率。
为了应对这些挑战,阿里巴巴的研究团队提出了"全体空间多任务模型"(Entire Space Multi-Task Model, ESMM),通过一种巧妙的建模思路,从根本上解决了上述难题。本文将从模型背景、核心原理、梯度更新机制、优缺点及适用场景等角度,对 ESMM 模型进行系统性剖析。
一、 CVR 预估的核心挑战:问题的根源
在深入了解 ESMM 之前,我们必须先理解它所要解决的问题根源。
-
1.1 样本选择偏差 (SSB)
SSB 问题源于模型训练与预测阶段的数据分布不一致 。具体来说,CVR 模型的目标是预估用户在"点击"行为发生后产生"转化"的概率。因此,模型的训练样本天然地被限定在已发生点击行为的样本空间 内。然而,在实际的线上推理阶段,模型需要对所有曝光的物品进行 CVR 预估,以便与点击率(CTR)等指标结合进行排序。
如下图所示,训练空间(Click Space)仅仅是推理空间(Impression Space)的一个很小的子集 。在这种有偏的样本上训练出的模型,直接应用于全体样本时,其泛化能力会受到严重损害。
-
1.2 数据稀疏性 (DS)
数据稀疏性是另一个严峻的挑战。在真实的业务场景中,用户从曝光到点击、再到转化的行为是一个发生率逐级递减的"漏斗"。用于 CVR 任务训练的"点击且转化"的样本,数量远远少于 CTR 任务的"点击"样本,更不用说海量的"曝光"样本了。例如,在淘宝的生产数据集中,用于 CVR 任务的样本量仅为 CTR 任务的 4% 左右。这种极端的正样本稀疏性,使得模型,尤其是需要海量数据进行学习的深度网络,难以充分学习到有效的特征表示,拟合过程非常困难。
二、 ESMM 模型核心原理:巧妙的全局建模
面对上述挑战,ESMM 并没有在模型结构上进行复杂的改造,而是从问题的定义本身出发,提出了一种全新的建模范式。
-
2.1 核心思想概述
ESMM 的核心思想是:不再直接建模有偏的 pCVR,而是通过利用用户行为的"曝光 → 点击 → 转化"这一序列依赖关系,在完整的、无偏的曝光样本空间上进行多任务联合建模。
-
2.2 关键公式与模型架构
-
公式拆解
ESMM 框架的理论基石是以下概率链式法则公式:
\[p(y=1, z=1 | x) = p(y=1 | x) \times p(z=1 | y=1, x) \]
其中,
x
代表曝光的特征,y=1
代表点击事件,z=1
代表转化事件。这个公式可以被解读为:\[pCTCVR = pCTR \times pCVR \]
pCTR
(Post-View Click-Through Rate): 曝光后点击率,即p(y=1|x)
。pCVR
(Post-Click Conversion Rate): 点击后转化率,即p(z=1|y=1,x)
。这是我们最终想要求解的目标。pCTCVR
(Post-View Click-Through & Conversion Rate): 曝光后点击且转化率,即p(y=1,z=1|x)
。
-
架构解析
基于上述公式,ESMM 设计了一个由两个子网络(任务塔)构成的多任务学习架构。
- CTR Tower : 一个独立的子网络,用于根据输入特征
x
预测pCTR
。 - CVR Tower : 另一个独立的子网络,用于根据相同的输入特征
x
预测pCVR
。 - 共享 Embedding 层: 两个任务塔共享底层的特征嵌入层(Embedding Layer)。这是缓解数据稀疏性问题的关键。
- *最终输出**: CTR 塔和 CVR 塔的输出
pCTR
和pCVR
相乘,得到pCTCVR
的预测值。
- CTR Tower : 一个独立的子网络,用于根据输入特征
-
-
2.3 手段总结
ESMM 解决两大难题的手段可以清晰地归纳为:
- 针对 SSB :模型不直接使用有偏的点击样本来监督 CVR 塔。而是通过监督在全体曝光样本 上都具有明确标签的
pCTR
和pCTCVR
两个任务,让 CVR 塔作为中间变量被隐式地学习。由于pCTR
和pCTCVR
都是在无偏的全空间上建模的,因此间接得到的pCVR
也自然地适用于全空间,从根本上解决了样本选择偏差问题。 - 针对 DS :通过让 CVR 任务塔与拥有海量训练样本的 CTR 任务塔共享底层的 Embedding 表示,实现了参数的迁移学习。CVR 任务可以从 CTR 任务中"借力",学习到更鲁棒的特征表示,从而有效缓解了自身训练数据稀疏的问题。
- 针对 SSB :模型不直接使用有偏的点击样本来监督 CVR 塔。而是通过监督在全体曝光样本 上都具有明确标签的
三、 ESMM 的参数更新机制:梯度的流动
理解 ESMM 的参数更新机制,有助于我们深入其联合训练的本质。
-
3.1 联合损失函数
ESMM 的总损失函数由 CTR 任务和 CTCVR 任务的损失加权构成,不包含 CVR 任务的直接损失项。
\[L(\theta_{ctr}, \theta_{cvr}) = \sum_{i=1}^{N} L_{ctr}(y_i, f(x_i; \theta_{ctr})) + \sum_{i=1}^{N} L_{ctcvr}(y_i\&z_i, f(x_i; \theta_{ctr}) \times f(x_i; \theta_{cvr})) \]
其中
θ_ctr
和θ_cvr
分别是 CTR 和 CVR 网络的参数。 -
3.2 梯度流向分析
在反向传播过程中,梯度从两个损失源头出发:
- 来自
L_ctr
的梯度路径 :这部分梯度只与pCTR
相关。因此,它会反向传播更新 CTR 塔 的参数,并继续向下更新 共享 Embedding 层 的参数。 - 来自
L_ctcvr
的梯度路径 :这部分梯度是联合训练的关键。根据链式求导法则,对于乘法操作pCTCVR = pCTR * pCVR
,梯度会"兵分两路":- 一路流向
pCVR
,更新 CVR 塔 的参数,并继续向下更新 共享 Embedding 层。 - 另一路流向
pCTR
,更新 CTR 塔 的参数,并继续向下更新 共享 Embedding 层。
- 一路流向
- 来自
-
3.3 参数更新归属总结
综合两条路径,各个模块参数的最终梯度来源如下:
- CVR 塔参数 :仅由
L_ctcvr
更新。 - CTR 塔参数 :由
L_ctr
和L_ctcvr
共同更新。 - 共享 Embedding 参数 :由
L_ctr
和L_ctcvr
共同更新。
- CVR 塔参数 :仅由
四、 优缺点及适用场景分析
-
4.1 主要优点
- 根本解决 SSB:通过在全空间上建模,从理论上彻底消除了传统 CVR 预估的样本选择偏差。
- 有效缓解 DS:利用迁移学习的思想,共享 Embedding,极大地缓解了 CVR 任务数据稀疏的问题。
- 结构优雅 :模型设计巧妙,乘法形式不仅避免了除法可能带来的数值不稳定问题,还能保证
pCVR
的值在[0,1]
范围内。
-
4.2 潜在局限性
- 强依赖于行为序列:模型的设计强依赖于"曝光 → 点击 → 转化"这样清晰、固定的序列依赖关系。如果任务之间是平行的,或者不遵循这种漏斗模式,ESMM 则不适用。
- 任务关系固定 :模型假设了
pCTR
和pCVR
之间是简单的乘法关系,这可能无法捕捉现实世界中更复杂的任务关联。
-
4.3 适用场景
ESMM 极其适用于具有明显序列依赖关系的多阶段用户行为预测场景。
- 电商平台:经典的"曝光 → 点击 → 购买"转化漏斗。
- 信息流推荐:"曝光 → 点击 → 关注/收藏/分享"等多步转化行为。
- 在线广告:"广告展示 → 点击 → 表单提交/App下载"等转化链路。
代码实现
python
import torch
import torch.nn as nn
class ESMM(nn.Module):
"""
ESMM: Entire Space Multi-Task Model PyTorch Implementation.
This class implements the ESMM model as described in the paper:
"Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate"
核心思想 Core Ideas:
1. **全空间建模 (Entire Space Modeling)**: 通过 `pCTCVR = pCTR * pCVR` 的关系,在全体曝光样本上联合训练CTR和CTCVR任务,从而无偏地学习CVR
2. **特征表示共享 (Shared Feature Representation)**: CVR任务和CTR任务共享底层的Embedding层,有效缓解CVR任务因样本稀疏导致的学习不充分问题
Args:
feature_columns (dict): 一个描述输入稀疏特征的字典,key为特征名,value为该特征的词表大小 (vocabulary size)。
Example: {'user_id': 10000, 'item_id': 5000}
embedding_dim (int): Embedding向量的维度。
tower_dims (list): 一个定义两个任务塔 (Tower) 隐藏层维度和结构的列表。
Example: [256, 128]
"""
def __init__(self, feature_columns, embedding_dim, tower_dims):
super(ESMM, self).__init__()
self.feature_columns = feature_columns
self.embedding_dim = embedding_dim
# --- 模块一: 共享的 Embedding 层 (Shared Embedding Layer) ---
# 这是ESMM解决数据稀疏性(Data Sparsity)问题的关键。
# CVR和CTR任务共享同一套Embedding参数,使得CVR任务可以从更丰富的CTR样本中学习特征表示
self.embedding_layer = nn.ModuleDict({
feat: nn.Embedding(vocab_size, self.embedding_dim)
for feat, vocab_size in self.feature_columns.items()
})
# 计算两个任务塔(MLP)的输入维度
# 输入维度 = 特征数量 * 每个特征的Embedding维度
self.tower_input_dim = len(self.feature_columns) * self.embedding_dim
# --- 模块二: CTR 任务塔 (CTR Tower) ---
# 这个子网络负责预测点击率 pCTR。
self.ctr_tower = self._build_mlp_tower(self.tower_input_dim, tower_dims)
# --- 模块三: CVR 任务塔 (CVR Tower) ---
# 这个子网络负责预测点击后转化率 pCVR。
# 注意它的输入与CTR Tower完全相同,都是来自共享的Embedding层。
self.cvr_tower = self._build_mlp_tower(self.tower_input_dim, tower_dims)
def _build_mlp_tower(self, input_dim, tower_dims):
"""一个辅助函数,用于构建MLP网络(即任务塔)。"""
layers = []
for hidden_dim in tower_dims:
layers.append(nn.Linear(input_dim, hidden_dim))
layers.append(nn.ReLU())
input_dim = hidden_dim
# 输出层,得到一个logit
layers.append(nn.Linear(input_dim, 1))
return nn.Sequential(*layers)
def forward(self, x):
"""
ESMM的前向传播逻辑。
Args:
x (dict): 输入的特征字典,key为特征名,value为特征值的tensor。
Example: {'user_id': tensor([[1], [2]]), 'item_id': tensor([[10], [12]])}
Returns:
tuple: 包含三个预测值的元组 (pCTR, pCVR, pCTCVR)。
- pCTR: 预测的点击率 (Predicted CTR)
- pCVR: 预测的点击后转化率 (Predicted CVR)
- pCTCVR: 预测的点击且转化率 (Predicted CTCVR)
"""
# --- 流程 1: 特征 Embedding (共享) ---
# 将输入的稀疏特征ID通过共享的Embedding层转换为稠密向量
embedded_features = [
self.embedding_layer[feat](x[feat]) for feat in self.feature_columns
]
# 将所有特征的Embedding向量拼接成一个长向量,作为两个任务塔的共同输入
concatenated_embeddings = torch.flatten(torch.cat(embedded_features, dim=1), start_dim=1)
# --- 流程 2: CTR 任务预测 ---
# 将拼接后的向量输入CTR塔,得到CTR的logit
ctr_logit = self.ctr_tower(concatenated_embeddings)
# 通过Sigmoid激活函数得到预测的点击率 pCTR
pCTR = torch.sigmoid(ctr_logit)
# --- 流程 3: CVR 任务预测 ---
# 将【相同】的拼接向量输入CVR塔,得到CVR的logit
cvr_logit = self.cvr_tower(concatenated_embeddings)
# 通过Sigmoid激活函数得到预测的点击后转化率 pCVR
pCVR = torch.sigmoid(cvr_logit)
# --- 流程 4: 计算 pCTCVR ---
# 这是ESMM解决样本选择偏差(SSB)问题的核心步骤。
# 通过pCTR和pCVR的乘积,将CVR的预测从"点击空间"隐式地转换到"全曝光空间"
pCTCVR = pCTR * pCVR
return pCTR, pCVR, pCTCVR
总结
ESMM 模型为 CVR 预估领域提供了一个里程碑式的解决方案。它最大的贡献在于,没有将思路局限在设计更复杂的网络结构上,而是通过重构学习目标和建模空间,从根本上规避了 SSB 和 DS 这两个长期存在的业界难题。其优雅的设计、显著的效果以及在淘宝等大规模工业场景中的成功落地,使其成为推荐和广告领域从业者必学的经典模型之一。