背景与动机
问题场景
在推荐系统(CTR预测)中,我们需要解决两个核心问题:
- 稀疏特征处理:用户ID、商品ID等特征取值空间巨大
- 特征交互建模:不同特征组合会产生不同效果
演进历史
LR (2012)
↓ 只能学习线性关系,需要人工特征工程
FM (2010)
↓ 自动学习二阶特征交互,但无法捕捉高阶
Wide&Deep (2016)
↓ 结合线性记忆能力和深度泛化能力
DeepFM (2017) ⭐
↓ FM显式建模二阶 + DNN隐式建模高阶
xDeepFM (2018)、DCN (2019) ...
DeepFM 的创新点
| 创新点 | 说明 |
|---|---|
| 无需人工特征工程 | 不同于 Wide&Deep,DeepFM 不需要人工设计交叉特征 |
| 端到端训练 | FM 和 DNN 共享 Embedding,统一优化 |
| 兼顾高低阶交互 | FM 负责二阶,DNN 负责高阶 |
模型架构
整体结构
┌─────────────┐
│ 输入特征 │
│ (离散索引) │
└──────┬──────┘
│
┌──────▼──────┐
│ Embedding │
│ Layer │
└──────┬──────┘
│
┌────────────┴────────────┐
│ │
┌──────▼──────┐ ┌──────▼──────┐
│ FM 部分 │ │ DNN 部分 │
│ │ │ │
│ • 一阶线性 │ │ • Flatten │
│ • 二阶交互 │ │ • 全连接层 │
│ │ │ • 非线性激活│
└──────┬──────┘ └──────┬──────┘
│ │
└────────────┬────────────┘
│
┌──────▼──────┐
│ 输出层 │
│ (Sigmoid) │
└─────────────┘
组件详解
1. Embedding 层详解
什么是 Embedding?
Embedding 是将高维稀疏特征映射到低维稠密向量的技术。在推荐系统中,它是最关键的第一步。
问题背景
原始输入(高维稀疏):
用户ID=12345
广告ID=678
设备=iOS
位置=北京
时间=下午
One-Hot 编码后:
用户特征: [0,0,0,...,1,0,...,0] # 1000维,只有1个是1
广告特征: [0,0,...,1,0,...,0] # 500维,只有1个是1
问题:
- 维度爆炸(用户数可能上千万)
- 极度稀疏(99.99% 都是 0)
- 计算效率低
- 难以捕捉相似性
Embedding 解决的三个核心问题
1. 降维:减少计算量
One-Hot: [0,0,...,1,0,...,0] # 1000维
↓ Embedding
Embedding: [0.23, -0.15, 0.87, ...] # 8维
| 对比维度 | One-Hot | Embedding |
|---|---|---|
| 维度 | 1,000,000 | 8-32 |
| 存储空间 | 极大 | 小 |
| 计算量 | 巨大 | 小 |
2. 稠密化:捕捉语义相似性
One-Hot 的局限:
用户A ID=12345 → [0,...,1,...,0] # 向量中1的位置不同
用户B ID=12346 → [0,...,1,...,0] # 完全独立,无法比较
即使用户A和用户B行为相似,One-Hot 也无法体现。
Embedding 的优势:
用户A → [0.23, -0.15, 0.87, ...] # 向量空间中位置接近
用户B → [0.25, -0.13, 0.85, ...]
相似的用户/物品在 Embedding 空间中距离更近!
3. 隐式学习:发现潜在关联
用户行为: 经常点击 iPhone 广告
↓ Embedding 学习
隐式关联: [iOS设备] + [科技兴趣] + [高消费能力]
无需人工设计,模型自动学习特征间的隐式关系。
Embedding 的本质
Embedding = 可学习的查找表
ID → Embedding 表 → 向量
12345 → 查表第12345行 → [0.23, -0.15, 0.87, ...]
每个 ID 对应一个可学习的向量,训练过程中不断优化。
类比理解
| 场景 | 高维稀疏 | Embedding |
|---|---|---|
| 文字 | 单词本身 | 词向量 (word2vec) |
| 用户 | 用户ID | 用户画像向量 |
| 商品 | 商品ID | 商品特征向量 |
| 类比 | 名字 | 性格/特点 |
实际例子
python
# 假设有 10000 个用户,Embedding 维度 8
user_embedding = nn.Embedding(10000, 8)
# 训练后,可能学到:
user_1 = [0.8, 0.2, 0.1, -0.3, 0.9, 0.5, 0.0, 0.1] # 高消费 + 科技兴趣
user_2 = [0.1, 0.9, 0.8, -0.2, 0.3, 0.4, 0.0, 0.2] # 低价 + 日常用品
user_3 = [0.7, 0.3, 0.2, -0.2, 0.8, 0.6, 0.0, 0.1] # 和 user_1 相似!
# 计算相似度
cos_sim(user_1, user_3) = 0.96 # 很相似!
cos_sim(user_1, user_2) = 0.23 # 不相似
为什么这步对 DeepFM 很重要?
| 组件 | 为什么需要 Embedding |
|---|---|
| FM | 需要向量计算 v i T v j v_i^T v_j viTvj(隐向量内积) |
| DNN | 需要稠密输入进行矩阵运算 |
| 共享 | FM 和 DNN 共享同一份 Embedding,参数复用,减少参数量 |
如果没有 Embedding:
- FM 无法计算隐向量内积
- DNN 输入维度太大,无法训练
总结
Embedding 的三个核心作用:
- 降维 - 从千万维降到 8-32 维
- 稠密化 - 从稀疏 0/1 到稠密实数
- 语义化 - 捕捉隐式相似性和关联
一句话: Embedding 把"名字"变成"特点",让模型理解特征的真实含义。
2. FM 部分
一阶:
y₁ = w₀ + Σᵢ wᵢxᵢ
二阶(核心创新):
y₂ = Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ (i < j)
高效计算:
# 原始复杂度: O(n²k)
# 优化后复杂度: O(nk)
# 优化公式:
Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ (i < j)
= 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]
3. DNN 部分
y₃ = MLP(Embedding)
多层全连接网络,捕捉高阶非线性关系
4. 最终输出
ŷ = σ(y₁ + y₂ + y₃)
直接相加,线性组合
数学原理
FM 部分
一阶:
y₁ = w₀ + Σᵢ wᵢxᵢ
二阶:
y₂ = Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ (i < j)
二阶优化计算(关键):
Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ (i < j)
= 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]
推导过程:
-
先求和再平方:(Σᵢ vᵢxᵢ)² = Σᵢ (vᵢxᵢ)² + 2 Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ (i < j)
-
移项得到:Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ = 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]
为什么这样优化?
| 方法 | 复杂度 | 说明 |
|---|---|---|
| 直接计算 | O(n²k) | 双重循环 |
| 优化计算 | O(nk) | 单次循环 |
| 加速倍数 | ~n 倍 | n 越大越明显 |
代码实现:
python
sum_of_vectors = torch.sum(embeddings, dim=1) # Σᵢ vᵢxᵢ
sum_square = torch.sum(embeddings ** 2, dim=1) # Σᵢ (vᵢxᵢ)²
second_order = 0.5 * (sum_of_vectors ** 2 - sum_square)
DNN 部分
y₃ = Wₗ · σ(Wₙ₋₁ · σ(... σ(W₁ · e + b₁) ...))
其中:
e:扁平化的 embedding 向量σ:激活函数(ReLU)W, b:权重和偏置
代码实现
核心代码流程
python
# Step 1: Embedding
embedded = [emb(x[:, i]) for i, emb in enumerate(self.embeddings)]
all_emb = torch.cat(embedded, dim=1)
# Step 2: FM 一阶
fm_linear = self.fm_linear(x)
# Step 3: FM 二阶
sum_emb = torch.sum(all_emb, dim=1)
sum_square = torch.sum(all_emb ** 2, dim=1)
fm_interaction = 0.5 * torch.sum(sum_emb ** 2 - sum_square, dim=1)
# Step 4: DNN
dnn_output = self.dnn(all_emb.view(batch_size, -1))
# Step 5: 合并
output = fm_linear + fm_interaction + dnn_output
完整实现见 DeepFM.py
与主流模型对比
Wide&Deep vs DeepFM
| 对比维度 | Wide&Deep | DeepFM |
|---|---|---|
| 特征工程 | 需要人工设计交叉特征 | 无需人工设计 |
| Embedding | Wide部分不用 | FM和DNN共享 |
| 二阶交互 | 需要指定 | 自动学习 |
| 训练效率 | 高 | 中等 |
| 适用场景 | 有领域知识 | 通用推荐 |
模型选择建议
数据量小 → FM
数据量大 + 简单场景 → Wide&Deep
数据量大 + 复杂交互 → DeepFM
超大规模 → xDeepFM / DCN
实战技巧
1. 超参数调优
| 参数 | 推荐值 | 影响 |
|---|---|---|
| embedding_dim | 8-32 | 过大过拟合,过小欠拟合 |
| hidden_dims | [64, 32] 或 [128, 64, 32] | 根据数据量调整 |
| learning_rate | 0.001 (Adam) | 影响收敛速度 |
| dropout | 0.1-0.3 | 防止过拟合 |
2. 特征处理
python
# 离散特征:直接使用索引
user_id = 12345
# 连续特征:分桶后离散化
age_bucket = discretize(age, bins=10)
3. 训练技巧
python
# 学习率衰减
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
# 早停
if val_loss > best_val_loss:
patience -= 1
if patience == 0:
break
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
4. 常见问题排查
| 问题 | 可能原因 | 解决方案 |
|---|---|---|
| Loss 不下降 | 学习率太小 | 增大 learning rate |
| 过拟合 | 模型太复杂 | 减少 hidden_dims 或加 dropout |
| 欠拟合 | 模型太简单 | 增加 hidden_dims 或 embedding_dim |
| 训练太慢 | batch size 太小 | 增大 batch size |
面试常见问题
Q1: DeepFM 的核心优势是什么?
A:
- 自动特征交互:FM 自动学习二阶交互,无需人工设计
- 高低阶兼顾:FM 显式建模二阶,DNN 隐式建模高阶
- 共享 Embedding:减少参数,加速收敛
- 端到端训练:所有组件联合优化
Q2: 为什么需要 FM 部分,DNN 不够吗?
A:
- DNN 隐式学习特征交互,解释性差
- FM 显式建模二阶,可解释性强
- 某些强二阶交互,DNN 难以有效学习
- FM 能提供更稳定的梯度
Q3: DeepFM 二阶交互的计算优化是如何实现的?
A:
利用恒等式:
Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ (i < j)
= 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]
将 O(n²) 降为 O(n)
Q4: Wide&Deep 和 DeepFM 的本质区别?
A:
- 特征工程:Wide&Deep 需要人工设计交叉,DeepFM 不需要
- Embedding:Wide&Deep 的 Wide 部分不用 Embedding
- 灵活性:DeepFM 更通用,Wide&Deep 在有领域知识时更高效
Q5: DeepFM 如何处理稀疏特征?
A:
- 使用 Embedding 将稀疏离散特征映射到稠密低维空间
- 不同特征独立 Embedding,捕捉各自语义
- FM 部分显式建模特征间的交互
Q6: 模型上线后如何评估?
A:
- 离线指标:AUC, LogLoss, NDCG
- 在线指标:CTR, GMV, 用户停留时长
- A/B测试:对比新旧版本效果
- 长期指标:用户留存、活跃度
Q7: DeepFM 的不足是什么?
A:
- 参数量较大,训练和推理成本高
- 对显存要求较高
- 某些场景下,Wide&Deep 更简单高效
- 超高阶交互仍依赖 DNN,可能不如专门模型
Q8: 实际应用中如何改进 DeepFM?
A:
- 特征工程:增加时间特征、序列特征等
- 模型结构:加入注意力机制、多层 FM
- 训练优化:混合精度训练、分布式训练
- 推理优化:模型蒸馏、量化加速
扩展阅读
相关论文
- DeepFM (2017) - 原始论文
- Wide&Deep (2016) - 奠基之作
- xDeepFM (2018) - CIN 网络
- DCN (2019) - 交叉网络
- DIN (2018) - 兴趣网络
- DIEN (2019) - 兴趣演化网络
代码资源
学习路径
基础 → FM → Wide&Deep → DeepFM → xDeepFM/DCN → DIEN/DIEN
↓
实战项目
快速检查清单
理解 DeepFM,你应该能回答:
- 解释 DeepFM 的整体架构
- 说明 FM 部分和 DNN 部分的作用
- 推导二阶交互的优化计算
- 对比 Wide&Deep 和 DeepFM
- 解释为什么共享 Embedding
- 知道如何调整超参数
- 了解常见问题和解决方案
- 能从零实现 DeepFM
实现案例(CTR预测):
python
import torch
import torch.nn as nn
class DeepFM(nn.Module):
"""
DeepFM: 结合 FM 的二阶交互能力和 DNN 的高阶非线性拟合能力
核心思想:
FM 部分: 显式建模二阶特征交互 (Σᵢⱼ vᵢ·vⱼ xᵢxⱼ)
DNN 部分: 隐式学习高阶非线性关系
架构:
Input (离散特征索引)
↓
Embedding 层 (稀疏→稠密)
↓
┌──────┴──────┐
↓ ↓
FM 部分 DNN 部分
(二阶交互) (高阶非线性)
└──────┬──────┘
↓
输出层 (相加)
"""
def __init__(self, feature_dims, embedding_dim=8, hidden_dims=[64, 32]):
"""
Args:
feature_dims: 每个特征的可能取值数列表
例如: [1000, 500, 5, 4, 10]
表示: 1000个用户, 500个广告, 5种设备, 4个时间段, 10个位置
embedding_dim: embedding 向量的维度 (默认 8)
hidden_dims: DNN 隐藏层的维度列表 (默认 [64, 32])
"""
super().__init__()
# 保存配置
self.feature_dims = feature_dims
self.num_features = len(feature_dims) # 特征总数
self.embedding_dim = embedding_dim
# ==================== Embedding 层 ====================
# 为每个离散特征创建独立的 embedding 表
# 例如: 用户特征有自己的表, 广告特征有自己的表
# 这样每个特征可以学习到不同的语义表示
self.embeddings = nn.ModuleList([
nn.Embedding(dim, embedding_dim) for dim in feature_dims
])
# ==================== FM 部分 ====================
# FM 一阶权重: 学习每个特征的线性影响
# 例如: w[0] 表示用户特征的权重, w[1] 表示广告特征的权重
self.fm_linear = nn.Linear(self.num_features, 1)
# ==================== DNN 部分 ====================
# DNN 的输入维度 = 所有特征的 embedding 拼接后的总维度
# 例如: 5个特征 × 8维 = 40维
dnn_input_dim = self.num_features * embedding_dim
# 构建 DNN 层: Linear → ReLU → BatchNorm
dnn_layers = []
for hidden_dim in hidden_dims:
dnn_layers.append(nn.Linear(dnn_input_dim, hidden_dim))
dnn_layers.append(nn.ReLU()) # 非线性激活
dnn_layers.append(nn.BatchNorm1d(hidden_dim)) # 加速收敛
dnn_input_dim = hidden_dim # 下一层的输入
# DNN 的输出层: 输出一个标量
dnn_layers.append(nn.Linear(dnn_input_dim, 1))
# 将所有层串联成一个模块
self.dnn = nn.Sequential(*dnn_layers)
def forward(self, x):
"""
前向传播
Args:
x: (batch_size, num_features) 离散特征索引
例如: [[123, 45, 2, 1, 3], ...]
表示: 用户123, 广告45, iOS, 上午, 位置3
Returns:
logits: (batch_size, 1) 预测分数 (未经过 sigmoid)
"""
batch_size = x.shape[0] # 获取 batch size
# ==================== Step 1: Embedding ====================
# 将离散索引转换为稠密向量
embedded_features = []
for i, emb in enumerate(self.embeddings):
# 取出第 i 个特征的索引, 通过 embedding 查找表得到向量
# 输入: x[:, i] 形状 (batch_size,)
# 输出: emb_i 形状 (batch_size, embedding_dim)
emb_i = emb(x[:, i])
embedded_features.append(emb_i)
# 将所有特征的 embedding 在特征维度上拼接
# 结果形状: (batch_size, num_features * embedding_dim)
# 例如: (32, 5*8) = (32, 40)
all_embeddings = torch.cat(embedded_features, dim=1)
# 重塑为 3D 张量, 方便后续计算
# 形状: (batch_size, num_features, embedding_dim)
# 例如: (32, 5, 8)
all_embeddings = all_embeddings.view(batch_size, self.num_features, self.embedding_dim)
# ==================== Step 2: FM 部分 ====================
# --- FM 一阶项 ---
# 学习特征本身的线性影响
# 输入需要转为 float, 因为 x 是 long 类型 (索引)
fm_linear_part = self.fm_linear(x.float()) # 形状: (batch_size, 1)
# --- FM 二阶项 (核心优化) ---
# 计算所有特征对的内积: Σᵢⱼ (vᵢ·vⱼ) xᵢxⱼ
# 优化公式: 1/2 * [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]
# 第1部分: 先求和再平方
# 形状: (batch_size, embedding_dim)
sum_of_vectors = torch.sum(all_embeddings, dim=1)
# 第2部分: 先平方再求和
# 形状: (batch_size, embedding_dim)
sum_of_square = torch.sum(all_embeddings ** 2, dim=1)
# 应用优化公式
# 形状: (batch_size, 1)
fm_interaction_part = 0.5 * torch.sum(sum_of_vectors ** 2 - sum_of_square, dim=1, keepdim=True)
# FM 总输出 = 一阶项 + 二阶项
fm_output = fm_linear_part + fm_interaction_part
# ==================== Step 3: DNN 部分 ====================
# 展平所有 embeddings 作为 DNN 的输入
# 形状: (batch_size, num_features * embedding_dim)
dnn_input = all_embeddings.view(batch_size, -1)
# 通过多层全连接网络学习高阶非线性
dnn_output = self.dnn(dnn_input) # 形状: (batch_size, 1)
# ==================== Step 4: 合并输出 ====================
# FM 和 DNN 的输出直接相加
# FM 负责低阶交互, DNN 负责高阶非线性
output = fm_output + dnn_output
return output
# ==================== 使用示例 ====================
if __name__ == '__main__':
# 模拟 CTR (点击率预测) 数据
# 特征定义: [用户ID, 广告ID, 设备类型, 时间段, 位置]
# 数字表示每个特征的可能取值数量
feature_dims = [1000, 500, 5, 4, 10]
# 创建 DeepFM 模型
model = DeepFM(
feature_dims=feature_dims, # 特征维度
embedding_dim=8, # embedding 向量维度
hidden_dims=[64, 32] # DNN 隐藏层配置
)
print('=== DeepFM 模型结构 ===')
print(model)
# ==================== 生成模拟训练数据 ====================
batch_size = 32
# 生成随机训练样本
# 每个样本是一个特征向量: [用户ID, 广告ID, 设备, 时间, 位置]
x = torch.tensor([
[torch.randint(0, dims, size=(1,)).item() for dims in feature_dims]
for _ in range(batch_size)
])
# 生成随机标签 (0: 不点击, 1: 点击)
y = torch.randint(0, 2, (batch_size, 1), dtype=torch.float32)
# ==================== 训练配置 ====================
# 使用 BCEWithLogitsLoss, 内部自动应用 sigmoid
criterion = nn.BCEWithLogitsLoss()
# Adam 优化器, 自适应学习率
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print('\n=== 开始训练 ===')
print(f'训练样本数: {batch_size}')
print(f'特征维度: {feature_dims}')
# ==================== 训练循环 ====================
for epoch in range(1000):
# 前向传播: 模型预测
pred = model(x) # 输出 logits
# 计算损失
loss = criterion(pred, y)
# 反向传播: 计算梯度
optimizer.zero_grad() # 清空上一次的梯度
loss.backward() # 计算当前梯度
optimizer.step() # 更新参数
# 每 20 个 epoch 打印一次 loss
if (epoch + 1) % 20 == 0:
print(f'Epoch {epoch + 1:4d}, Loss: {loss.item():.6f}')
# ==================== 预测示例 ====================
# 切换到评估模式 (关闭 dropout 等)
model.eval()
# 不需要计算梯度, 节省内存和加速
with torch.no_grad():
# 生成一个随机测试样本
test_x = torch.tensor([[
torch.randint(0, dim, size=(1,)).item() for dim in feature_dims
]])
# 获取模型输出 (logits)
logits = model(test_x)
# 应用 sigmoid 得到概率
click_prob = torch.sigmoid(logits)
print(f'\n=== 预测结果 ===')
print(f'模型输出 (logits): {logits.item():.4f}')
print(f'点击概率 (sigmoid): {click_prob.item():.4f}')
# ==================== 模型信息 ====================
# 计算总参数量
total_params = sum(p.numel() for p in model.parameters())
print(f'\n模型参数量: {total_params:,}')
# 计算各部分参数量
embedding_params = sum(p.numel() for p in model.embeddings.parameters())
fm_params = sum(p.numel() for p in model.fm_linear.parameters())
dnn_params = sum(p.numel() for p in model.dnn.parameters())
print(f' - Embedding: {embedding_params:,}')
print(f' - FM 部分: {fm_params:,}')
print(f' - DNN 部分: {dnn_params:,}')
参考资料
- DeepFM 论文: https://arxiv.org/abs/1703.04247
- Wide&Deep 论文: https://arxiv.org/abs/1606.07792
- 推荐系统综述: https://arxiv.org/abs/1906.02966