DeepFM 学习日记

背景与动机

问题场景

在推荐系统(CTR预测)中,我们需要解决两个核心问题:

  1. 稀疏特征处理:用户ID、商品ID等特征取值空间巨大
  2. 特征交互建模:不同特征组合会产生不同效果

演进历史

复制代码
LR (2012)
   ↓ 只能学习线性关系,需要人工特征工程
FM (2010)
   ↓ 自动学习二阶特征交互,但无法捕捉高阶
Wide&Deep (2016)
   ↓ 结合线性记忆能力和深度泛化能力
DeepFM (2017) ⭐
   ↓ FM显式建模二阶 + DNN隐式建模高阶
xDeepFM (2018)、DCN (2019) ...

DeepFM 的创新点

创新点 说明
无需人工特征工程 不同于 Wide&Deep,DeepFM 不需要人工设计交叉特征
端到端训练 FM 和 DNN 共享 Embedding,统一优化
兼顾高低阶交互 FM 负责二阶,DNN 负责高阶

模型架构

整体结构

复制代码
                              ┌─────────────┐
                              │  输入特征   │
                              │ (离散索引)  │
                              └──────┬──────┘
                                     │
                              ┌──────▼──────┐
                              │ Embedding   │
                              │   Layer     │
                              └──────┬──────┘
                                     │
                        ┌────────────┴────────────┐
                        │                         │
                 ┌──────▼──────┐           ┌──────▼──────┐
                 │  FM 部分    │           │  DNN 部分   │
                 │             │           │             │
                 │ • 一阶线性  │           │ • Flatten   │
                 │ • 二阶交互  │           │ • 全连接层  │
                 │             │           │ • 非线性激活│
                 └──────┬──────┘           └──────┬──────┘
                        │                         │
                        └────────────┬────────────┘
                                     │
                              ┌──────▼──────┐
                              │   输出层     │
                              │ (Sigmoid)   │
                              └─────────────┘

组件详解

1. Embedding 层详解

什么是 Embedding?

Embedding 是将高维稀疏特征映射到低维稠密向量的技术。在推荐系统中,它是最关键的第一步。


问题背景

原始输入(高维稀疏):

复制代码
用户ID=12345
广告ID=678
设备=iOS
位置=北京
时间=下午

One-Hot 编码后:

复制代码
用户特征: [0,0,0,...,1,0,...,0]  # 1000维,只有1个是1
广告特征: [0,0,...,1,0,...,0]     # 500维,只有1个是1

问题:

  • 维度爆炸(用户数可能上千万)
  • 极度稀疏(99.99% 都是 0)
  • 计算效率低
  • 难以捕捉相似性

Embedding 解决的三个核心问题
1. 降维:减少计算量
复制代码
One-Hot:   [0,0,...,1,0,...,0]  # 1000维
             ↓ Embedding
Embedding:  [0.23, -0.15, 0.87, ...]  # 8维
对比维度 One-Hot Embedding
维度 1,000,000 8-32
存储空间 极大
计算量 巨大
2. 稠密化:捕捉语义相似性

One-Hot 的局限:

复制代码
用户A ID=12345 → [0,...,1,...,0]  # 向量中1的位置不同
用户B ID=12346 → [0,...,1,...,0]  # 完全独立,无法比较

即使用户A和用户B行为相似,One-Hot 也无法体现。

Embedding 的优势:

复制代码
用户A → [0.23, -0.15, 0.87, ...]  # 向量空间中位置接近
用户B → [0.25, -0.13, 0.85, ...]

相似的用户/物品在 Embedding 空间中距离更近!

3. 隐式学习:发现潜在关联
复制代码
用户行为: 经常点击 iPhone 广告
    ↓ Embedding 学习
隐式关联: [iOS设备] + [科技兴趣] + [高消费能力]

无需人工设计,模型自动学习特征间的隐式关系。


Embedding 的本质

Embedding = 可学习的查找表

复制代码
ID → Embedding 表 → 向量
12345 → 查表第12345行 → [0.23, -0.15, 0.87, ...]

每个 ID 对应一个可学习的向量,训练过程中不断优化。


类比理解
场景 高维稀疏 Embedding
文字 单词本身 词向量 (word2vec)
用户 用户ID 用户画像向量
商品 商品ID 商品特征向量
类比 名字 性格/特点

实际例子
python 复制代码
# 假设有 10000 个用户,Embedding 维度 8
user_embedding = nn.Embedding(10000, 8)

# 训练后,可能学到:
user_1 = [0.8, 0.2, 0.1, -0.3, 0.9, 0.5, 0.0, 0.1]  # 高消费 + 科技兴趣
user_2 = [0.1, 0.9, 0.8, -0.2, 0.3, 0.4, 0.0, 0.2]  # 低价 + 日常用品
user_3 = [0.7, 0.3, 0.2, -0.2, 0.8, 0.6, 0.0, 0.1]  # 和 user_1 相似!

# 计算相似度
cos_sim(user_1, user_3) = 0.96  # 很相似!
cos_sim(user_1, user_2) = 0.23  # 不相似

为什么这步对 DeepFM 很重要?
组件 为什么需要 Embedding
FM 需要向量计算 v i T v j v_i^T v_j viTvj(隐向量内积)
DNN 需要稠密输入进行矩阵运算
共享 FM 和 DNN 共享同一份 Embedding,参数复用,减少参数量

如果没有 Embedding:

  • FM 无法计算隐向量内积
  • DNN 输入维度太大,无法训练

总结

Embedding 的三个核心作用:

  1. 降维 - 从千万维降到 8-32 维
  2. 稠密化 - 从稀疏 0/1 到稠密实数
  3. 语义化 - 捕捉隐式相似性和关联

一句话: Embedding 把"名字"变成"特点",让模型理解特征的真实含义。

2. FM 部分

一阶:

复制代码
y₁ = w₀ + Σᵢ wᵢxᵢ

二阶(核心创新):

复制代码
y₂ = Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ   (i < j)

高效计算:

复制代码
# 原始复杂度: O(n²k)
# 优化后复杂度: O(nk)

# 优化公式:
Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ   (i < j)
= 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]
3. DNN 部分
复制代码
y₃ = MLP(Embedding)

多层全连接网络,捕捉高阶非线性关系

4. 最终输出
复制代码
ŷ = σ(y₁ + y₂ + y₃)

直接相加,线性组合


数学原理

FM 部分

一阶:

复制代码
y₁ = w₀ + Σᵢ wᵢxᵢ

二阶:

复制代码
y₂ = Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ   (i < j)

二阶优化计算(关键):

复制代码
Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ   (i < j)
= 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]

推导过程:

  1. 先求和再平方:(Σᵢ vᵢxᵢ)² = Σᵢ (vᵢxᵢ)² + 2 Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ (i < j)

  2. 移项得到:Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ = 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]

为什么这样优化?

方法 复杂度 说明
直接计算 O(n²k) 双重循环
优化计算 O(nk) 单次循环
加速倍数 ~n 倍 n 越大越明显

代码实现:

python 复制代码
sum_of_vectors = torch.sum(embeddings, dim=1)      # Σᵢ vᵢxᵢ
sum_square = torch.sum(embeddings ** 2, dim=1)    # Σᵢ (vᵢxᵢ)²
second_order = 0.5 * (sum_of_vectors ** 2 - sum_square)

DNN 部分

复制代码
y₃ = Wₗ · σ(Wₙ₋₁ · σ(... σ(W₁ · e + b₁) ...))

其中:

  • e:扁平化的 embedding 向量
  • σ:激活函数(ReLU)
  • W, b:权重和偏置

代码实现

核心代码流程

python 复制代码
# Step 1: Embedding
embedded = [emb(x[:, i]) for i, emb in enumerate(self.embeddings)]
all_emb = torch.cat(embedded, dim=1)

# Step 2: FM 一阶
fm_linear = self.fm_linear(x)

# Step 3: FM 二阶
sum_emb = torch.sum(all_emb, dim=1)
sum_square = torch.sum(all_emb ** 2, dim=1)
fm_interaction = 0.5 * torch.sum(sum_emb ** 2 - sum_square, dim=1)

# Step 4: DNN
dnn_output = self.dnn(all_emb.view(batch_size, -1))

# Step 5: 合并
output = fm_linear + fm_interaction + dnn_output

完整实现见 DeepFM.py


与主流模型对比

Wide&Deep vs DeepFM

对比维度 Wide&Deep DeepFM
特征工程 需要人工设计交叉特征 无需人工设计
Embedding Wide部分不用 FM和DNN共享
二阶交互 需要指定 自动学习
训练效率 中等
适用场景 有领域知识 通用推荐

模型选择建议

复制代码
数据量小 → FM
数据量大 + 简单场景 → Wide&Deep
数据量大 + 复杂交互 → DeepFM
超大规模 → xDeepFM / DCN

实战技巧

1. 超参数调优

参数 推荐值 影响
embedding_dim 8-32 过大过拟合,过小欠拟合
hidden_dims [64, 32] 或 [128, 64, 32] 根据数据量调整
learning_rate 0.001 (Adam) 影响收敛速度
dropout 0.1-0.3 防止过拟合

2. 特征处理

python 复制代码
# 离散特征:直接使用索引
user_id = 12345
# 连续特征:分桶后离散化
age_bucket = discretize(age, bins=10)

3. 训练技巧

python 复制代码
# 学习率衰减
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

# 早停
if val_loss > best_val_loss:
    patience -= 1
    if patience == 0:
        break

# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

4. 常见问题排查

问题 可能原因 解决方案
Loss 不下降 学习率太小 增大 learning rate
过拟合 模型太复杂 减少 hidden_dims 或加 dropout
欠拟合 模型太简单 增加 hidden_dims 或 embedding_dim
训练太慢 batch size 太小 增大 batch size

面试常见问题

Q1: DeepFM 的核心优势是什么?

A:

  1. 自动特征交互:FM 自动学习二阶交互,无需人工设计
  2. 高低阶兼顾:FM 显式建模二阶,DNN 隐式建模高阶
  3. 共享 Embedding:减少参数,加速收敛
  4. 端到端训练:所有组件联合优化

Q2: 为什么需要 FM 部分,DNN 不够吗?

A:

  • DNN 隐式学习特征交互,解释性差
  • FM 显式建模二阶,可解释性强
  • 某些强二阶交互,DNN 难以有效学习
  • FM 能提供更稳定的梯度

Q3: DeepFM 二阶交互的计算优化是如何实现的?

A:

利用恒等式:

复制代码
Σᵢⱼ (vᵢ · vⱼ) xᵢxⱼ   (i < j)
= 1/2 · [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]

将 O(n²) 降为 O(n)

Q4: Wide&Deep 和 DeepFM 的本质区别?

A:

  • 特征工程:Wide&Deep 需要人工设计交叉,DeepFM 不需要
  • Embedding:Wide&Deep 的 Wide 部分不用 Embedding
  • 灵活性:DeepFM 更通用,Wide&Deep 在有领域知识时更高效

Q5: DeepFM 如何处理稀疏特征?

A:

  • 使用 Embedding 将稀疏离散特征映射到稠密低维空间
  • 不同特征独立 Embedding,捕捉各自语义
  • FM 部分显式建模特征间的交互

Q6: 模型上线后如何评估?

A:

  • 离线指标:AUC, LogLoss, NDCG
  • 在线指标:CTR, GMV, 用户停留时长
  • A/B测试:对比新旧版本效果
  • 长期指标:用户留存、活跃度

Q7: DeepFM 的不足是什么?

A:

  1. 参数量较大,训练和推理成本高
  2. 对显存要求较高
  3. 某些场景下,Wide&Deep 更简单高效
  4. 超高阶交互仍依赖 DNN,可能不如专门模型

Q8: 实际应用中如何改进 DeepFM?

A:

  1. 特征工程:增加时间特征、序列特征等
  2. 模型结构:加入注意力机制、多层 FM
  3. 训练优化:混合精度训练、分布式训练
  4. 推理优化:模型蒸馏、量化加速

扩展阅读

相关论文

  1. DeepFM (2017) - 原始论文
  2. Wide&Deep (2016) - 奠基之作
  3. xDeepFM (2018) - CIN 网络
  4. DCN (2019) - 交叉网络
  5. DIN (2018) - 兴趣网络
  6. DIEN (2019) - 兴趣演化网络

代码资源

学习路径

复制代码
基础 → FM → Wide&Deep → DeepFM → xDeepFM/DCN → DIEN/DIEN
                      ↓
                 实战项目

快速检查清单

理解 DeepFM,你应该能回答:

  • 解释 DeepFM 的整体架构
  • 说明 FM 部分和 DNN 部分的作用
  • 推导二阶交互的优化计算
  • 对比 Wide&Deep 和 DeepFM
  • 解释为什么共享 Embedding
  • 知道如何调整超参数
  • 了解常见问题和解决方案
  • 能从零实现 DeepFM

实现案例(CTR预测):

python 复制代码
import torch
import torch.nn as nn


class DeepFM(nn.Module):
    """
    DeepFM: 结合 FM 的二阶交互能力和 DNN 的高阶非线性拟合能力

    核心思想:
        FM 部分: 显式建模二阶特征交互 (Σᵢⱼ vᵢ·vⱼ xᵢxⱼ)
        DNN 部分: 隐式学习高阶非线性关系

    架构:
        Input (离散特征索引)
            ↓
        Embedding 层 (稀疏→稠密)
            ↓
        ┌──────┴──────┐
        ↓             ↓
    FM 部分        DNN 部分
    (二阶交互)    (高阶非线性)
        └──────┬──────┘
               ↓
          输出层 (相加)
    """

    def __init__(self, feature_dims, embedding_dim=8, hidden_dims=[64, 32]):
        """
        Args:
            feature_dims: 每个特征的可能取值数列表
                        例如: [1000, 500, 5, 4, 10]
                        表示: 1000个用户, 500个广告, 5种设备, 4个时间段, 10个位置
            embedding_dim: embedding 向量的维度 (默认 8)
            hidden_dims: DNN 隐藏层的维度列表 (默认 [64, 32])
        """
        super().__init__()

        # 保存配置
        self.feature_dims = feature_dims
        self.num_features = len(feature_dims)  # 特征总数
        self.embedding_dim = embedding_dim

        # ==================== Embedding 层 ====================
        # 为每个离散特征创建独立的 embedding 表
        # 例如: 用户特征有自己的表, 广告特征有自己的表
        # 这样每个特征可以学习到不同的语义表示
        self.embeddings = nn.ModuleList([
            nn.Embedding(dim, embedding_dim) for dim in feature_dims
        ])

        # ==================== FM 部分 ====================
        # FM 一阶权重: 学习每个特征的线性影响
        # 例如: w[0] 表示用户特征的权重, w[1] 表示广告特征的权重
        self.fm_linear = nn.Linear(self.num_features, 1)

        # ==================== DNN 部分 ====================
        # DNN 的输入维度 = 所有特征的 embedding 拼接后的总维度
        # 例如: 5个特征 × 8维 = 40维
        dnn_input_dim = self.num_features * embedding_dim

        # 构建 DNN 层: Linear → ReLU → BatchNorm
        dnn_layers = []
        for hidden_dim in hidden_dims:
            dnn_layers.append(nn.Linear(dnn_input_dim, hidden_dim))
            dnn_layers.append(nn.ReLU())                    # 非线性激活
            dnn_layers.append(nn.BatchNorm1d(hidden_dim))     # 加速收敛
            dnn_input_dim = hidden_dim  # 下一层的输入

        # DNN 的输出层: 输出一个标量
        dnn_layers.append(nn.Linear(dnn_input_dim, 1))

        # 将所有层串联成一个模块
        self.dnn = nn.Sequential(*dnn_layers)

    def forward(self, x):
        """
        前向传播

        Args:
            x: (batch_size, num_features) 离散特征索引
               例如: [[123, 45, 2, 1, 3], ...]
               表示: 用户123, 广告45, iOS, 上午, 位置3

        Returns:
            logits: (batch_size, 1) 预测分数 (未经过 sigmoid)
        """
        batch_size = x.shape[0]  # 获取 batch size

        # ==================== Step 1: Embedding ====================
        # 将离散索引转换为稠密向量
        embedded_features = []
        for i, emb in enumerate(self.embeddings):
            # 取出第 i 个特征的索引, 通过 embedding 查找表得到向量
            # 输入: x[:, i] 形状 (batch_size,)
            # 输出: emb_i 形状 (batch_size, embedding_dim)
            emb_i = emb(x[:, i])
            embedded_features.append(emb_i)

        # 将所有特征的 embedding 在特征维度上拼接
        # 结果形状: (batch_size, num_features * embedding_dim)
        # 例如: (32, 5*8) = (32, 40)
        all_embeddings = torch.cat(embedded_features, dim=1)

        # 重塑为 3D 张量, 方便后续计算
        # 形状: (batch_size, num_features, embedding_dim)
        # 例如: (32, 5, 8)
        all_embeddings = all_embeddings.view(batch_size, self.num_features, self.embedding_dim)

        # ==================== Step 2: FM 部分 ====================

        # --- FM 一阶项 ---
        # 学习特征本身的线性影响
        # 输入需要转为 float, 因为 x 是 long 类型 (索引)
        fm_linear_part = self.fm_linear(x.float())  # 形状: (batch_size, 1)

        # --- FM 二阶项 (核心优化) ---
        # 计算所有特征对的内积: Σᵢⱼ (vᵢ·vⱼ) xᵢxⱼ

        # 优化公式: 1/2 * [(Σᵢ vᵢxᵢ)² - Σᵢ (vᵢxᵢ)²]

        # 第1部分: 先求和再平方
        # 形状: (batch_size, embedding_dim)
        sum_of_vectors = torch.sum(all_embeddings, dim=1)

        # 第2部分: 先平方再求和
        # 形状: (batch_size, embedding_dim)
        sum_of_square = torch.sum(all_embeddings ** 2, dim=1)

        # 应用优化公式
        # 形状: (batch_size, 1)
        fm_interaction_part = 0.5 * torch.sum(sum_of_vectors ** 2 - sum_of_square, dim=1, keepdim=True)

        # FM 总输出 = 一阶项 + 二阶项
        fm_output = fm_linear_part + fm_interaction_part

        # ==================== Step 3: DNN 部分 ====================
        # 展平所有 embeddings 作为 DNN 的输入
        # 形状: (batch_size, num_features * embedding_dim)
        dnn_input = all_embeddings.view(batch_size, -1)

        # 通过多层全连接网络学习高阶非线性
        dnn_output = self.dnn(dnn_input)  # 形状: (batch_size, 1)

        # ==================== Step 4: 合并输出 ====================
        # FM 和 DNN 的输出直接相加
        # FM 负责低阶交互, DNN 负责高阶非线性
        output = fm_output + dnn_output

        return output


# ==================== 使用示例 ====================
if __name__ == '__main__':
    # 模拟 CTR (点击率预测) 数据
    # 特征定义: [用户ID, 广告ID, 设备类型, 时间段, 位置]
    # 数字表示每个特征的可能取值数量
    feature_dims = [1000, 500, 5, 4, 10]

    # 创建 DeepFM 模型
    model = DeepFM(
        feature_dims=feature_dims,  # 特征维度
        embedding_dim=8,            # embedding 向量维度
        hidden_dims=[64, 32]       # DNN 隐藏层配置
    )

    print('=== DeepFM 模型结构 ===')
    print(model)

    # ==================== 生成模拟训练数据 ====================
    batch_size = 32

    # 生成随机训练样本
    # 每个样本是一个特征向量: [用户ID, 广告ID, 设备, 时间, 位置]
    x = torch.tensor([
        [torch.randint(0, dims, size=(1,)).item() for dims in feature_dims]
        for _ in range(batch_size)
    ])

    # 生成随机标签 (0: 不点击, 1: 点击)
    y = torch.randint(0, 2, (batch_size, 1), dtype=torch.float32)

    # ==================== 训练配置 ====================
    # 使用 BCEWithLogitsLoss, 内部自动应用 sigmoid
    criterion = nn.BCEWithLogitsLoss()

    # Adam 优化器, 自适应学习率
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print('\n=== 开始训练 ===')
    print(f'训练样本数: {batch_size}')
    print(f'特征维度: {feature_dims}')

    # ==================== 训练循环 ====================
    for epoch in range(1000):
        # 前向传播: 模型预测
        pred = model(x)  # 输出 logits

        # 计算损失
        loss = criterion(pred, y)

        # 反向传播: 计算梯度
        optimizer.zero_grad()  # 清空上一次的梯度
        loss.backward()       # 计算当前梯度
        optimizer.step()      # 更新参数

        # 每 20 个 epoch 打印一次 loss
        if (epoch + 1) % 20 == 0:
            print(f'Epoch {epoch + 1:4d}, Loss: {loss.item():.6f}')

    # ==================== 预测示例 ====================
    # 切换到评估模式 (关闭 dropout 等)
    model.eval()

    # 不需要计算梯度, 节省内存和加速
    with torch.no_grad():
        # 生成一个随机测试样本
        test_x = torch.tensor([[
            torch.randint(0, dim, size=(1,)).item() for dim in feature_dims
        ]])

        # 获取模型输出 (logits)
        logits = model(test_x)

        # 应用 sigmoid 得到概率
        click_prob = torch.sigmoid(logits)

        print(f'\n=== 预测结果 ===')
        print(f'模型输出 (logits): {logits.item():.4f}')
        print(f'点击概率 (sigmoid): {click_prob.item():.4f}')

    # ==================== 模型信息 ====================
    # 计算总参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f'\n模型参数量: {total_params:,}')

    # 计算各部分参数量
    embedding_params = sum(p.numel() for p in model.embeddings.parameters())
    fm_params = sum(p.numel() for p in model.fm_linear.parameters())
    dnn_params = sum(p.numel() for p in model.dnn.parameters())

    print(f'  - Embedding: {embedding_params:,}')
    print(f'  - FM 部分: {fm_params:,}')
    print(f'  - DNN 部分: {dnn_params:,}')

参考资料


相关推荐
Narrastory1 小时前
Note:强化学习(六)
人工智能·深度学习·强化学习
数据智能老司机2 小时前
学习 AutoML——理解 AutoML 流水线
机器学习
Luca_kill2 小时前
GPT Image 2 深度评测:当 AI 图像生成跨越“图灵测试”,它如何重塑开发者工作流?
人工智能·深度学习·openai·ai图像生成·gpt image 2
小糖学代码2 小时前
LLM系列:1.python入门:16.正则表达式与文本处理 (re)
人工智能·pytorch·python·深度学习·神经网络·正则表达式
Ai173163915793 小时前
10大算力芯片某某XXU全解析:CPU/GPU/TPU/NPU/LPU/FPGA/RPU/BPU/DPU/GPGPU
大数据·图像处理·人工智能·深度学习·计算机视觉·自动驾驶·知识图谱
我是大聪明.3 小时前
大模型Tokenizer原理:深入理解BPE与WordPiece子词编码技术
人工智能·深度学习·机器学习
人工智能培训3 小时前
工程科研中的AI应用:结构力学分析技巧
人工智能·深度学习·机器学习·docker·容器
Mr数据杨3 小时前
飞船乘客状态预测与金融风控建模启发
大数据·机器学习·数据分析·kaggle
AGV算法笔记3 小时前
CVPR 2024顶级SLAM论文精读:SplaTAM如何用3D高斯实现稠密RGB-D SLAM?
深度学习·3d·机器人视觉·slam·三维重建