xDeepFM 学习日记

背景与动机

DeepFM 的局限

复制代码
DeepFM = FM (二阶) + DNN (隐式高阶)

问题:

  • FM 只显式建模二阶交互
  • DNN 虽然能学高阶,但隐式、不可解释
  • 某些显式的高阶交互难以有效学习

xDeepFM 的创新

核心思想: 显式建模高阶特征交互

复制代码
xDeepFM = CIN (显式高阶) + Linear (线性) + DNN (隐式高阶)

三大组件:

  1. CIN - 压缩交互网络,显式建模高阶
  2. Linear - 线性部分
  3. DNN - 隐式建模高阶

演进历史

复制代码
FM (2010)
   ↓ 只能建模二阶交互
FFM (2016)
   ↓ 引入域感知
DeepFM (2017)
   ↓ 添加 DNN 捕捉高阶
xDeepFM (2018) ⭐
   ↓ CIN 显式建模高阶交互

核心创新:CIN 网络

CIN 是什么?

CIN = Compressed Interaction Network (压缩交互网络)

CIN 能够以显式的方式建模高阶特征交互,而且参数量可控。

CIN 的工作原理

输入
复制代码
特征 0: [v₀₁, v₀₂, ...]  # k 维向量
特征 1: [v₁₁, v₁₂, ...]
特征 2: [v₂₁, v₂₂, ...]
...
CIN 层
复制代码
第 1 层:
  所有特征两两交互 → 压缩 → 新特征

第 2 层:
  特征和上层的输出两两交互 → 压缩 → 新特征

...
输出
复制代码
汇总各层输出 → 拼接 → 最终特征

CIN 的计算过程

数学公式

第 l 层 CIN:

复制代码
Xₖ = Σᵢⱼ (wₖ,ᵢⱼ ⊙ (xᵢ ⊙ xⱼ))  (l 层, k 输出通道)

其中:
- xᵢ, xⱼ: 特征向量
- ⊙: 逐元素相乘
- wₖ,ᵢⱼ: 可学习的权重矩阵
- Σᵢⱼ: 对所有特征对求和

逐元素相乘 (Hadamard Product):

复制代码
xᵢ = [a₁, a₂, a₃]
xⱼ = [b₁, b₂, b₃]

xᵢ ⊙ xⱼ = [a₁×b₁, a₂×b₂, a₃×b₃]
CIN 的优势
特性 说明
显式交互 每层的交互可以解释
高阶能力 多层 CIN 可以建模任意高阶
参数可控 共享权重,参数量适中
向量级 在向量级别交互,而非特征级别

CIN vs FM 的二阶交互

维度 FM 二阶 CIN (1层)
计算方式 vᵢ · vⱼ (内积) vᵢ ⊙ vⱼ (逐元素相乘)
输出维度 标量 (1) 向量 (k)
信息保留 压缩为 1 个数 保留 k 维信息

示例:

复制代码
FM:  vᵢ · vⱼ = 0.5  # 只得到一个数

CIN: vᵢ ⊙ vⱼ = [0.1, 0.3, 0.5, 0.2]  # 保留整个向量

xDeepFM vs DeepFM 对比

核心区别

维度 DeepFM xDeepFM
二阶交互 FM 隐向量内积 CIN 显式交互
高阶建模 DNN 隐式 CIN 显式 + DNN 隐式
可解释性 强 (CIN 部分可解释)
参数量 中等 稍多

模型结构对比

DeepFM:

复制代码
Input → Embedding → [FM + DNN] → Output

xDeepFM:

复制代码
Input → Embedding → [CIN + Linear + DNN] → Output
                          ↑
                    多层显式交互

模型架构

整体结构

复制代码
                    输入特征 (离散索引)
                            ↓
                    Embedding 层 (稀疏→稠密)
                            ↓
        ┌───────────────────┼───────────────────┐
        ↓                   ↓                                     ↓
    CIN 网络          Linear 部分              DNN 部分
 (显式高阶)         (线性影响)           (隐式高阶)
        ↓                   ↓                   ↓
    CIN 输出          Linear 输出            DNN 输出
        └───────────────────┴───────────────────┘
                            ↓
                        最终输出层

各组件详解

1. Embedding 层

将离散特征索引映射到稠密向量

复制代码
用户 123 → [0.23, 0.15, ...]  # k 维
广告 45  → [0.67, 0.32, ...]
2. CIN 网络

多层 CIN 堆叠:

复制代码
CIN Layer 1:
  输入: 原始特征
  操作: 所有特征两两逐元素相乘 → 压缩
  输出: 第 1 层特征

CIN Layer 2:
  输入: 原始特征 + 第 1 层特征
  操作: 所有特征两两逐元素相乘 → 压缩
  输出: 第 2 层特征

...

汇总输出:

复制代码
CIN 输出 = [Layer1 输出, Layer2 输出, ...] 拼接
3. Linear 部分

学习特征的线性影响

复制代码
Linear 输出 = Σᵢ wᵢxᵢ
4. DNN 部分

隐式学习复杂的高阶非线性关系

复制代码
DNN 输出 = MLP(Embedding)

参数量计算

假设:

  • 特征数 n = 5
  • 隐向量维度 k = 8
  • CIN 层数 L = 3
  • CIN 隐维度 h = 64

CIN 参数量:

复制代码
每层 CIN: k × h = 8 × 64 = 512
总 CIN 参数: L × k × h = 3 × 512 = 1536

总参数量对比:

模型 参数量
DeepFM ~5K
xDeepFM ~10K

代码实现

CIN 层实现

python 复制代码
class CIN(nn.Module):
    """
    Compressed Interaction Network

    功能: 显式建模高阶特征交互
    """
    def __init__(self, num_features, hidden_dim, k=8):
        super().__init__()

        # 权重矩阵: (k, num_features, num_features)
        # k 是输出通道,可以学习不同角度的交互
        self.w = nn.Parameter(torch.randn(k, num_features, num_features) * 0.01)

        # 偏置
        self.b = nn.Parameter(torch.zeros(k, 1))

    def forward(self, x):
        """
        Args:
            x: (batch_size, num_features, k) 特征 embedding

        Returns:
            output: (batch_size, k) 压缩后的输出
        """
        batch_size = x.shape[0]
        num_features = x.shape[1]
        k = x.shape[2]

        # 逐元素相乘所有特征对
        # 结果: (batch_size, num_features, num_features, k)
        pairwise = torch.einsum('bik,bjk->bijk', x, x)

        # 加权求和
        # 结果: (batch_size, k, 1)
        output = torch.einsum('kij,bijk->bk', self.w, pairwise) + self.b

        return output.squeeze(1)  # (batch_size, k)

xDeepFM 实现

python 复制代码
import torch
import torch.nn as nn


class xDeepFM(nn.Module):
    """
    eXtreme Deep Factorization Machine

    核心创新:
        CIN (Compressed Interaction Network) 显式建模高阶交互

    模型结构:
        CIN + Linear + DNN
    """

    def __init__(self, feature_dims, embedding_dim=8,
                 cin_layers=[64, 64], hidden_dims=[64, 32]):
        super().__init__()

        self.feature_dims = feature_dims
        self.num_features = len(feature_dims)
        self.embedding_dim = embedding_dim

        # ==================== Embedding 层 ====================
        self.embeddings = nn.ModuleList([
            nn.Embedding(dim, embedding_dim) for dim in feature_dims
        ])

        # ==================== CIN 网络 ====================
        self.cin_layers = nn.ModuleList()

        # CIN 层数,每层输出不同的特征
        for hidden_dim in cin_layers:
            self.cin_layers.append(
                CIN(self.num_features, hidden_dim, embedding_dim)
            )

        # ==================== Linear 部分 ====================
        self.linear = nn.Linear(self.num_features, 1)

        # ==================== DNN 部分 ====================
        dnn_input_dim = self.num_features * embedding_dim
        dnn_layers = []
        for hidden_dim in hidden_dims:
            dnn_layers.append(nn.Linear(dnn_input_dim, hidden_dim))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.BatchNorm1d(hidden_dim))
            dnn_input_dim = hidden_dim
        dnn_layers.append(nn.Linear(dnn_input_dim, 1))

        self.dnn = nn.Sequential(*dnn_layers)

    def forward(self, x):
        batch_size = x.shape[0]

        # ==================== Embedding ====================
        embedded_features = []
        for i, emb in enumerate(self.embeddings):
            emb_i = emb(x[:, i])
            embedded_features.append(emb_i)

        all_embeddings = torch.cat(embedded_features, dim=1)
        all_embeddings = all_embeddings.view(
            batch_size, self.num_features, self.embedding_dim
        )

        # ==================== CIN 部分 ====================
        cin_outputs = [all_embeddings]  # 保存每层输出
        for cin in self.cin_layers:
            # 下一层的输入 = 原始 embedding + 之前所有层的输出
            layer_input = torch.cat(cin_outputs, dim=1)

            # 通过 CIN 层
            layer_output = cin(layer_input)

            cin_outputs.append(layer_output)

        # 汇总 CIN 输出
        cin_output = torch.cat(cin_outputs[1:], dim=1)  # 跳过原始 embedding
        cin_output = cin_output.view(batch_size, -1)

        # ==================== Linear 部分 ====================
        linear_output = self.linear(x.float())

        # ==================== DNN 部分 ====================
        dnn_input = all_embeddings.view(batch_size, -1)
        dnn_output = self.dnn(dnn_input)

        # ==================== 合并输出 ====================
        output = cin_output + linear_output + dnn_output

        return output

面试常见问题

Q1: xDeepFM 相比 DeepFM 的核心优势是什么?

A:

  1. 显式高阶交互:CIN 明确建模高阶,比 DNN 更可解释
  2. 逐元素交互:保留更多信息(向量级而非标量级)
  3. 多层堆叠:可以建模任意阶数的交互
  4. 参数可控:通过压缩机制控制参数增长

Q2: CIN 和 FM 的交互有什么区别?

A:

对比维度 FM 二阶 CIN
计算方式 内积 vᵢ · vⱼ 逐元素相乘 vᵢ ⊙xⱼ
输出维度 标量 (1) 向量 (k)
信息保留 压缩为 1 个数 保留 k 维信息
学习能力 只能学二阶 多层可学高阶

Q3: 什么是逐元素相乘 (Hadamard Product)?

A:

python 复制代码
v₁ = [a₁, a₂, a₃]
v₂ = [b₁, b₂, b₃]

# 逐元素相乘
v₁ ⊙ v₂ = [a₁×b₁, a₂×b₂, a₃×b₃]

# 对比:内积
v₁ · v₂ = a₁×b₁ + a₂×b₂ + a₃×b₃

逐元素相乘保留了每个维度的信息,而内积合并为一个标量。

Q4: CIN 为什么叫"压缩"交互网络?

A:

因为 CIN 通过权重矩阵将高维的交互特征压缩到低维:

复制代码
原始交互: n × n × k  (所有特征对)
压缩后: k              (低维)

这样既建模了复杂交互,又控制了参数量。

Q5: xDeepFM 的三大组件各自的作用?

A:

组件 作用 贡献
CIN 显式高阶交互 可解释的特征交互
Linear 线性影响 特征的基础重要性
DNN 隐式高阶 难以显式建模的复杂关系

三者互补,提供更全面的表达能力。

Q6: 为什么需要 Linear 部分?CIN 已经建模交互了。

A:

CIN 只建模特征间的交互 ,Linear 建模特征本身的影响

类比:

  • CIN: 用户×广告交互
  • Linear: 用户本身、广告本身的影响

Q7: xDeepFM 的适用场景?

A:

场景 是否推荐 原因
复杂特征交互 ✅ 推荐 CIN 显式建模
需要可解释性 ✅ 推荐 CIN 部分可解释
数据量大 ✅ 推荐 参数较多
数据量小 ❌ 不推荐 容易过拟合
延迟敏感 ❌ 不推荐 计算复杂
简单场景 ❌ 不推荐 DeepFM 足够

Q8: 如何调优 CIN 的层数?

A:

层数 效果 参数量
1 层 二阶交互
2 层 三阶交互
3+ 层 更高阶

建议:

  • 先用 2 层测试
  • 根据验证集效果增减
  • 层数过多容易过拟合

模型对比总结

复制代码
FM (2010)
   ↓ 特征交互,但只能二阶
FFM (2016)
   ↓ 引入域感知
DeepFM (2017)
   ↓ 加入 DNN 隐式高阶
xDeepFM (2018)
   ↓ CIN 显式高阶交互
AutoInt (xDeepFM后续)
   ↓ 自动确定交互阶数

快速检查清单

理解 xDeepFM,你应该能回答:

  • 解释 xDeepFM 和 DeepFM 的区别
  • 说明 CIN 的作用和原理
  • 解释逐元素相乘 vs 内积
  • 了解 xDeepFM 的三大组件
  • 计算 CIN 的参数量
  • 知道如何调优 CIN 层数
  • 理解为什么需要显式高阶交互

实现案例(CTR预测)

python 复制代码
import torch
import torch.nn as nn


class CIN(nn.Module):
    """
    Compressed Interaction Network (CIN)

    核心功能: 显式建模高阶特征交互

    简化版本: 只基于原始特征进行交互,不累积之前的层
    """

    def __init__(self, num_features, embedding_dim=8):
        """
        Args:
            num_features: 特征数量
            embedding_dim: embedding 向量维度
        """
        super().__init__()

        self.num_features = num_features
        self.embedding_dim = embedding_dim

        # ==================== CIN 权重矩阵 ====================
        # 学习每对特征交互的权重
        # 形状: (num_features, num_features, embedding_dim)
        self.w = nn.Parameter(torch.randn(num_features, num_features, embedding_dim) * 0.01)

    def forward(self, x):
        """
        Args:
            x: (batch_size, num_features, embedding_dim) 特征 embedding

        Returns:
            output: (batch_size, num_features, embedding_dim) CIN 输出
        """
        # ==================== 逐元素相乘 (Hadamard Product) ====================
        # 计算所有特征对的逐元素相乘
        # xᵢ ⊙ xⱼ: (batch_size, num_features, num_features, embedding_dim)
        pairwise = torch.einsum('bik,bjk->bijk', x, x)

        # ==================== 加权求和 ====================
        # 对每个特征 i,加权求和所有特征对 (i, j) 的交互
        # w: (num_features, num_features, embedding_dim)
        # pairwise: (batch_size, num_features, num_features, embedding_dim)
        # 结果: (batch_size, num_features, embedding_dim)
        output = torch.einsum('ijk,bijk->bik', self.w, pairwise)

        return output


class xDeepFM(nn.Module):
    """
    eXtreme Deep Factorization Machine (xDeepFM)

    核心创新:
        CIN (Compressed Interaction Network) 显式建模高阶交互

    模型结构:
        CIN (显式高阶) + Linear (线性) + DNN (隐式高阶)
    """

    def __init__(self, feature_dims, embedding_dim_cin=8,
                 embedding_dim=8, cin_layers=[64, 64],
                 hidden_dims=[64, 32]):
        """
        Args:
            feature_dims: 每个特征的可能可能取值数列表
            embedding_dim_cin: CIN 使用的 embedding 维度
            embedding_dim: DNN 使用的 embedding 维度
            cin_layers: CIN 各层的输出维度列表
            hidden_dims: DNN 隐藏层维度列表
        """
        super().__init__()

        self.feature_dims = feature_dims
        self.num_features = len(feature_dims)
        self.embedding_dim = embedding_dim

        # ==================== Embedding 层 ====================
        # 为每个特征创建独立的 embedding 表
        self.embeddings = nn.ModuleList([
            nn.Embedding(dim, embedding_dim_cin) for dim in feature_dims
        ])

        # ==================== CIN 网络 ====================
        self.cin_layers = nn.ModuleList()

        # 创建多层 CIN
        for _ in cin_layers:
            self.cin_layers.append(
                CIN(self.num_features, embedding_dim_cin)
            )

        # ==================== Linear 部分 ====================
        # 学习特征的线性影响
        self.linear = nn.Linear(self.num_features, 1)

        # ==================== DNN 部分 ====================
        # DNN 的输入维度
        dnn_input_dim = self.num_features * embedding_dim

        # 构建 DNN 层
        dnn_layers = []
        for hidden_dim in hidden_dims:
            dnn_layers.append(nn.Linear(dnn_input_dim, hidden_dim))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.BatchNorm1d(hidden_dim))
            dnn_input_dim = hidden_dim

        # DNN 输出层
        dnn_layers.append(nn.Linear(dnn_input_dim, 1))

        self.dnn = nn.Sequential(*dnn_layers)

    def forward(self, x):
        """
        Args:
            x: (batch_size, num_features) 离散特征索引

        Returns:
            logits: (batch_size, 1) 预测分数
        """
        batch_size = x.shape[0]

        # ==================== Embedding ====================
        # 将离散特征转为稠密向量
        embedded_features = []
        for i, emb in enumerate(self.embeddings):
            emb_i = emb(x[:, i])
            embedded_features.append(emb_i)

        # 拼接所有 embedding
        all_embeddings = torch.cat(embedded_features, dim=1)

        # 重塑为 (batch_size, num_features, embedding_dim)
        all_embeddings = all_embeddings.view(
            batch_size, self.num_features, -1
        )

        # ==================== CIN 部分 ====================
        # 使用第一个 CIN 层处理原始 embedding
        # CIN 输出: (batch_size, num_features, embedding_dim)
        # 需要聚合为 (batch_size, 1)
        cin_output = self.cin_layers[0](all_embeddings)
        # 先对特征维度求和: (batch_size, embedding_dim)
        cin_output = torch.sum(cin_output, dim=1)
        # 再对 embedding 维度求和: (batch_size, 1)
        cin_output = torch.sum(cin_output, dim=1, keepdim=True)

        # ==================== Linear 部分 ====================
        linear_output = self.linear(x.float())

        # ==================== DNN 部分 ====================
        # 展平 embedding 作为 DNN 输入
        dnn_input = all_embeddings.view(batch_size, -1)
        dnn_output = self.dnn(dnn_input)

        # ==================== 合并输出 ====================
        # xDeepFM = CIN + Linear + DNN
        output = cin_output + linear_output + dnn_output

        return output


# ==================== 使用示例 ====================
if __name__ == '__main__':
    # 特征定义
    # 特征: [用户ID, 广告ID, 设备类型, 时间段, 位置]
    feature_dims = [1000, 500, 5, 4, 10]

    # 创建 xDeepFM 模型
    model = xDeepFM(
        feature_dims=feature_dims,
        embedding_dim_cin=8,
        embedding_dim=8,
        cin_layers=[64, 64],    # 2 层 CIN
        hidden_dims=[64, 32]     # DNN 隐藏层
    )

    print('=== xDeepFM 模型结构 ===')
    print(model)

    # ==================== 参数量分析 ====================
    total_params = sum(p.numel() for p in model.parameters())

    # 计算各部分参数量
    embedding_params = sum(p.numel() for p in model.embeddings.parameters())
    cin_params = sum(p.numel() for p in model.cin_layers.parameters())
    linear_params = sum(p.numel() for p in model.linear.parameters())
    dnn_params = sum(p.numel() for p in model.dnn.parameters())

    print(f'\n参数量分析:')
    print(f'  总参数量: {total_params:,}')
    print(f'  Embedding: {embedding_params:,}')
    print(f'  CIN 网络: {cin_params:,}')
    print(f'  Linear 部分: {linear_params:,}')
    print(f'  DNN 部分: {dnn_params:,}')

    # ==================== 生成训练数据 ====================
    batch_size = 32

    # 生成随机训练样本
    # 每个样本: [用户ID, 广告ID, 设备, 时间,时间]
    x = torch.tensor([
        [torch.randint(0, dim, size=(1,)).item() for dim in feature_dims]
        for _ in range(batch_size)
    ])

    # 生成随机标签
    y = torch.randint(0, 2, (batch_size, 1), dtype=torch.float32)

    # ==================== 训练配置 ====================
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(f'\n=== 开始训练 ===')
    print(f'batch_size: {batch_size}')
    print(f'特征数: {model.num_features}')
    print(f'CIN 层数: {len(model.cin_layers)}')

    # ==================== 训练循环 ====================
    for epoch in range(2000):
        # 前向传播
        pred = model(x)

        # 计算损失
        loss = criterion(pred, y)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            print(f'Epoch {epoch + 1:3d}, Loss: {loss.item():.6f}')

    # ==================== 预测示例 ====================
    model.eval()

    with torch.no_grad():
        # 生成测试样本
        test_x = torch.tensor([[
            torch.randint(0, dim, size=(1,)).item() for dim in feature_dims
        ]])

        # 预测
        logits = model(test_x)
        click_prob = torch.sigmoid(logits)

        print(f'\n=== 预测结果 ===')
        print(f'模型输出 (logits): {logits.item():.4f}')
        print(f'点击概率 (sigmoid): {click_prob.item():.4f}')

    # ==================== CIN 工作原理示例 ====================
    print(f'\n=== CIN 工作原理示例 ===')
    print('假设两个特征的逐元素相乘:')
    print('v₁ = [0.1, 0.2, 0.3]')
    print('v₂ = [0.4, 0.5, 0.6]')
    print('v₁ ⊙ v₂ = [0.04, 0.10, 0.18]  (逐元素相乘)')
    print('\n对比 FM 的内积:')
    print('v₁ · v₂ = 0.1×0.4 + 0.2×0.5 + 0.3×0.6 = 0.32  (合并为一个数)')
    print('\nCIN 优势: 保留完整向量信息,可以学习更复杂的交互')

参考资料


相关推荐
啦啦啦_99991 小时前
1. 逻辑回归
算法·机器学习·逻辑回归
南宫萧幕2 小时前
Python与Simulink联合仿真:基于DQN的HEV能量管理策略建模与全链路排雷实战
开发语言·人工智能·python·算法·机器学习·matlab·控制
小糖学代码2 小时前
LLM系列:2.pytorch入门:9.神经网络的学习
人工智能·python·深度学习·神经网络·学习·机器学习
liuyunshengsir2 小时前
手写最基础的大模型推理并使用Profile监控GPU性能消耗情况
人工智能·深度学习·机器学习
硅谷秋水4 小时前
《自动驾驶系统开发》英文版《Autonomous Driving Hanbook》推荐
人工智能·深度学习·机器学习·计算机视觉·语言模型·自动驾驶
啦啦啦_99994 小时前
案例之 逻辑回归_癌症预测
算法·机器学习·逻辑回归
惊鸿一博4 小时前
自动驾驶_一段式端到端_三条技术路线_UniAD_SparseDrive_概述
人工智能·机器学习·自动驾驶
我是大聪明.5 小时前
大模型Tokenizer原理:BPE、WordPiece与子词编码的核心机制深度解析
人工智能·线性代数·算法·机器学习·矩阵
威尔逊·柏斯科·希伯理5 小时前
机器学习-特征工程
人工智能·机器学习