第五章:计算机视觉-项目实战之推荐/广告系统
第二部分:粗排算法
第二节:理解粗排模型之离线部分:双塔模型结构精讲及实现
一、粗排在推荐系统中的位置:为什么它需要"双塔模型"?
回顾完整推荐系统流程:
召回 → 粗排 → 精排 → 重排 → 策略 → 展示
其中:
| 阶段 | 典型数量级 | 目标 | 典型模型 |
|---|---|---|---|
| 召回 | 1万~10万候选 | 保证召回覆盖面 | Faiss / Milvus / i2i / u2i |
| 粗排 | 几百~几千候选 | 快速筛掉明显不相关内容 | 双塔(Two-Tower / DSSM) |
| 精排 | 50~200候选 | 细粒度评分 | DNN / Wide&Deep / DIN / Transformer |
| 重排 | 20~50候选 | 多目标+多样性 | GBDT / Rank / 强化学习 |
粗排模型的核心目标只有一句话:
在极低延迟的前提下,让真正相关的内容尽可能排在前面。
因此粗排必须具备 3 个特点:
| 粗排需求 | 能力 |
|---|---|
| 高 QPS(百万级) | 必须轻量推理 |
| 向量化检索 | 需要可 ANN 检索 |
| 用户实时性 & 物料稳定性解耦 | 用户变、物料不变 |
而双塔模型刚好满足全部要求,因此成为工业界粗排事实标准方案。
二、双塔模型核心结构拆解(Two-Tower Architecture)
结构非常简单,可用一句话概括:
将用户和物料分别编码为向量,并在同一向量空间对齐,通过向量相似度衡量匹配分数。
如下图结构(示意图):
┌──────────────────────┐ ┌──────────────────────┐
│ User Tower │ │ Item Tower │
│ Embedding + DNN │ │ Embedding + DNN │
└───────┬──────────────┘ └─────────┬────────────┘
│ │
User Vector u Item Vector v
│ │
└─────────── Cosine / Dot ──────────┘
Matching Score
特点如下:
| 组件 | 作用 |
|---|---|
| User Tower | 建模用户兴趣(行为序列、性别、年龄、兴趣 Embedding) |
| Item Tower | 建模物料语义(标题、分类、作者、Embedding) |
| Matching Space | 将两塔向量映射到同一个语义空间 |
| Similarity | cos(u, v) / u·v 作为分数 |
粗排不追求强表达能力,而是追求快 & 稳定 & 易 ANN 检索,因此双塔非常适配。
三、训练样本构造与损失函数
双塔训练本质是对比学习(Contrastive Learning) ,最主流 Loss 为 InfoNCE,思想:
正样本相似度要高,负样本相似度要低。
训练样本格式:
| user | positive item | negative items |
|---|---|---|
| U1 | I_pos | I_neg1, I_neg2, I_neg3... |
Loss 公式(Batch 内共享负样本):
工业界中 90% 粗排都这么训练,原因是:
不需要额外负样本生成
Batch 内天然提供大量负样本
收敛快,鲁棒性高
四、可运行的 PyTorch 双塔粗排核心代码(可直接训练)
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class TwoTowerModel(nn.Module):
def __init__(self, user_num, item_num, emb_dim=64):
super().__init__()
self.user_emb = nn.Embedding(user_num, emb_dim)
self.item_emb = nn.Embedding(item_num, emb_dim)
def forward(self, user_ids, pos_item_ids):
u = self.user_emb(user_ids) # [B, D]
v = self.item_emb(pos_item_ids) # [B, D]
u = F.normalize(u, dim=-1)
v = F.normalize(v, dim=-1)
logits = torch.matmul(u, v.t()) # [B, B] 共享负样本
labels = torch.arange(len(user_ids)).to(user_ids.device)
loss = F.cross_entropy(logits, labels)
return loss, u, v
训练循环:
python
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for batch in train_loader:
loss, u, v = model(batch['user'], batch['item'])
optimizer.zero_grad()
loss.backward()
optimizer.step()
训练完成后产出:
python
user_emb.pt
item_emb.pt
这两份向量即可对接 ANN(Faiss / Milvus / HNSW)用于粗排检索。
五、粗排 ANN 召回衔接流程(离线 + 在线)
| 阶段 | 操作 |
|---|---|
| 离线 | 将所有 item embedding 建索引(Faiss/HNSW/Milvus) |
| 在线 | User embedding → TopK nearest item → 进入精排 |
伪代码:
python
import faiss
index = faiss.IndexFlatIP(emb_dim)
index.add(item_emb_np)
scores, ids = index.search(user_emb_np, topk)
粗排返回 100~500 item,进入精排 → CTR/DIN 计算更精的点击相关度
这是工业界最稳定的两阶段结构。
六、双塔 vs DSSM vs 精排模型对比总结
| 模型 | 用途 | 特点 |
|---|---|---|
| 双塔(粗排主力) | Matching | 快、可 ANN、结构简单 |
| DSSM | 召回 | 多用于语义匹配,结构更深 |
| 精排 DNN/DIN | Scoring | 单路模型,表达强但慢 |
一句话总结:
双塔不是最强的模型,但在粗排阶段它一定是最合适的模型。
七、本节总结
| 你现在已经理解了 | 状态 |
|---|---|
| 粗排为什么存在 | ✅ |
| 为什么双塔是粗排最佳方案 | ✅ |
| 双塔结构、Loss、训练逻辑 | ✅ |
| PyTorch实现 | ✅ |
| 如何与 ANN 连上,进入工业流水线 | ✅ |