【蒸馏用损失】NCEloss介绍,大规模分类任务的损失函数

NCE Loss(Noise Contrastive Estimation Loss,噪声对比估计损失)是一种用于大规模分类任务的损失函数,尤其适用于类别数量极多(如词表预测、推荐系统召回等场景)时替代传统Softmax的方法。其核心思想是将多分类问题转化为二分类问题,通过对比正样本与负样本的差异来优化模型。

第一节、核心原理

  1. 问题转化

    • 传统Softmax需要计算所有类别的概率归一化项(分母),当类别数 N N N极大时计算成本高昂。NCE Loss通过采样少量负样本(如 k k k个),将问题简化为 k + 1 k+1 k+1个二分类任务(1个正样本 + k k k个负样本)。正负样本是对这些输入向量的目标类别(标签)的划分 。例如,在Word2Vec中,inputs是中心词的嵌入向量,正样本是其上下文词(标签)负样本是从噪声分布中采样的其他词
  2. 损失函数形式

    • 对于正样本 ( x , y ) (x, y) (x,y)和负样本 ( x , y j ) (x, y_j) (x,yj),NCE Loss定义为:
      L = − log ⁡ σ ( s ( x , y ) ) − ∑ j = 1 k log ⁡ σ ( − s ( x , y j ) ) L = -\log \sigma(s(x,y)) - \sum_{j=1}^k \log \sigma(-s(x,y_j)) L=−logσ(s(x,y))−j=1∑klogσ(−s(x,yj))
      其中 σ \sigma σ为Sigmoid函数, s ( x , y ) s(x,y) s(x,y)为模型对样本 ( x , y ) (x,y) (x,y)的得分(如点积)。

在NCE Loss(Noise Contrastive Estimation Loss)的上下文中, s ( x , y ) s(x,y) s(x,y) 通常表示模型对输入对 ( x , y ) (x,y) (x,y) 的匹配度得分,具体含义如下:

x x x 的典型含义

  • 在推荐系统中, x x x 通常代表用户特征(如用户ID、历史行为等)或用户向量(如双塔模型中的用户塔输出)。
  • 在自然语言处理中, x x x 可能是上下文词向量(如Word2Vec中的中心词或上下文窗口)。

y y y 的典型含义

  • 在推荐系统中, y y y 代表物品(Item),例如用户点击过的正样本物品或采样的负样本物品。
  • 在NLP中, y y y 可能是目标词(如语言模型中的下一个词)或噪声词(负采样词)。

s ( x , y ) s(x,y) s(x,y) 的计算方式

  • 一般为 x x x 和 y y y 的点积(如双塔模型中的用户向量与物品向量点积)。

  • 在NCE中, s ( x , y ) s(x,y) s(x,y) 会经过修正: g ( x , y ) = s ( x , y ) − log ⁡ Q ( y ∣ x ) g(x,y) = s(x,y) - \log Q(y|x) g(x,y)=s(x,y)−logQ(y∣x),其中 Q ( y ∣ x ) Q(y|x) Q(y∣x) 是负采样分布的概率。

  • x x x:输入特征(用户/上下文)

  • y y y:目标类别(物品/词)

  • s ( x , y ) s(x,y) s(x,y) :模型对 ( x , y ) (x,y) (x,y) 匹配度的原始得分,用于后续损失计算。

  1. 负采样策略 : (见第三节)
    • 负样本通常从噪声分布 Q ( y ) Q(y) Q(y)中采样,常见策略包括:
      • 均匀采样:所有类别等概率。
      • 频次采样:按类别出现频率采样,热门类别更易被选中。
      • 调节采样 :通过参数 b b b控制热门类别的采样概率(如 Q ( y ) ∝ f ( y ) b Q(y) \propto f(y)^b Q(y)∝f(y)b, f ( y ) f(y) f(y)为频率)。

优势与特性

  • 计算高效:仅需计算少量负样本的得分,避免全类别Softmax的开销。
  • 理论保证 :当负样本数 k k k足够大时,NCE Loss的梯度趋近于Softmax的梯度。
  • 灵活性:可通过调整负采样分布适应不同任务需求(如打压热门物品)。

应用场景

  1. 自然语言处理
    • Word2Vec中采用NCE或负采样(NEG,NCE的简化版)训练词向量。
  2. 推荐系统
    • 召回阶段用NCE Loss区分用户正反馈物品与随机负样本。
  3. 对比学习
    • 如CLIP模型使用InfoNCE(NCE的变体)对齐图文特征。

变体与改进

  • NEG(Negative Sampling)
    忽略NCE中的分布修正项,进一步简化计算,但无理论收敛保证。
  • InfoNCE
    引入温度系数 τ \tau τ调节分布尖锐程度,常用于对比学习。
  • 硬负样本挖掘
    选择模型当前易混淆的负样本(Hard Negative)提升训练效果。

第二节、代码实现示例

python 复制代码
loss = nce_loss(
    weights=embedding_matrix,  # 类别嵌入矩阵
    biases=biases,             # 类别偏置
    inputs=input_embeddings,   # 输入特征
    labels=positive_labels,    # 正样本标签
    num_sampled=num_negatives, # 负样本数
    num_classes=vocab_size     # 总类别数
)

在NCE Loss的实现中,参数inputslabels分别对应数学公式中的 x x x和 y y y,具体含义如下:

x x x(输入特征)

  • 对应代码中的inputs=input_embeddings,表示模型输入的向量(如用户特征、上下文词向量等)。
  • 例如,在推荐系统中, x x x可能是用户嵌入向量;在NLP中, x x x可能是中心词或上下文窗口的向量表示。

y y y(目标类别)

  • 对应代码中的labels=positive_labels,表示正样本的类别标签(如物品ID、目标词等)。
  • 例如,推荐系统中 y y y是用户点击的物品,NLP中 y y y是目标词或噪声词。

补充说明:

  • s ( x , y ) s(x,y) s(x,y)的计算
    通常为 x x x和 y y y的嵌入向量点积(如matmul(inputs, weights)),即模型对 ( x , y ) (x,y) (x,y)匹配度的原始得分。
  • 负样本处理
    NCE Loss通过num_sampled参数采样负样本(噪声分布 Q ( y ) Q(y) Q(y)中的 y j y_j yj),与正样本 y y y共同构成二分类任务。

第三节、负样本处理

在NCE Loss中,通过num_sampled参数采样负样本的过程如下:

1. 输入与输出

  • 输入

    • 正样本标签labels(形状为(batch_size, num_true)),表示当前输入x对应的真实类别y
    • 噪声分布Q(y),即负样本的采样分布(如均匀分布、频次采样等)。
    • 负样本数量num_sampled,指定每个正样本采样的负样本数k
  • 输出

    • 采样的负样本 :从噪声分布Q(y)中采样k个类别y_jj=1,2,...,k),与正样本y共同构成k+1个二分类任务。

2. 负样本采样机制

NCE Loss通过以下步骤采样负样本:

  1. 噪声分布选择
    • 均匀采样 :所有类别等概率被选中(Q(y) = 1/NN为总类别数)。
    • 频次采样 :类别y的采样概率与其在数据集中出现的频率成正比(Q(y) ∝ f(y))。
    • 调节采样 :通过调节因子b控制热门类别的采样概率(Q(y) ∝ f(y)^bb=0时为均匀采样,b=1时为频次采样)。
    • log-uniform采样 : ( P ( k ) P(k) P(k)是 Q ( y ) Q(y) Q(y)的一种实现方式:) P ( k ) = log ⁡ ( k + 2 ) − log ⁡ ( k + 1 ) log ⁡ ( range_max + 1 ) P(k) = \frac{\log(k+2) - \log(k+1)}{\log(\text{range\_max} + 1)} P(k)=log(range_max+1)log(k+2)−log(k+1)
      其中k为类别ID,range_max=num_classes。该分布使得高频类别(大k)被采样的概率更低。
采样类型 概率分布 特点 适用场景
均匀采样 P ( k ) = 1 N P(k) = \frac{1}{N} P(k)=N1 所有类别等概率 类别分布未知或均匀时
频次采样 P ( k ) ∝ f ( k ) P(k) \propto f(k) P(k)∝f(k) 直接按类别出现频率 f ( k ) f(k) f(k)采样 需精确匹配数据分布时
log-uniform采样 P ( k ) ∝ 1 k P(k) \propto \frac{1}{k} P(k)∝k1 对高频类别降权,平衡采样效率与覆盖 词频、推荐物品ID等长尾数据

2. 数学形式

对于每个正样本(x, y),NCE Loss构造的损失函数为:
L = − log ⁡ σ ( s ( x , y ) − log ⁡ Q ( y ) ) − ∑ j = 1 k log ⁡ σ ( − s ( x , y j ) + log ⁡ Q ( y j ) ) L = -\log \sigma(s(x,y) - \log Q(y)) - \sum_{j=1}^k \log \sigma(-s(x,y_j) + \log Q(y_j)) L=−logσ(s(x,y)−logQ(y))−j=1∑klogσ(−s(x,yj)+logQ(yj))

其中:

  • s(x,y)是模型对(x,y)的匹配得分(如点积)。
  • σ为Sigmoid函数。
  • Q(y)是噪声分布的概率。

正样本 y = 1 y=1 y=1 和负样本 y j = 1 y_j=1 yj=1

  • 若正样本 y = 1 y=1 y=1 和负样本 y j = 1 y_j=1 yj=1 理论上 的噪声概率相同(即 Q ( y = 1 ) = Q ( y j = 1 ) Q(y=1) = Q(y_j=1) Q(y=1)=Q(yj=1)),但实际采样时 y j y_j yj 不会与 y y y 重复。
  • 若任务中允许 y j = y y_j=y yj=y(罕见情况),则两者的噪声概率确实相同,但这种情况会破坏对比学习的有效性。

假设类别总数 N = 100 N=100 N=100,采用均匀分布:

  • 正样本 y = 1 y=1 y=1 的 Q ( y = 1 ) = 1 100 Q(y=1) = \frac{1}{100} Q(y=1)=1001。
  • 负样本 y j y_j yj 的 Q ( y j ) Q(y_j) Q(yj) 也是 1 100 \frac{1}{100} 1001,但实际采样时 y j ∈ { 2 , 3 , . . . , 100 } y_j \in \{2, 3, ..., 100\} yj∈{2,3,...,100},因此 Q ( y j = 1 ) Q(y_j=1) Q(yj=1) 不会被计算。

若采用 Log-Uniform 分布, Q ( y = 1 ) Q(y=1) Q(y=1) 和 Q ( y j = 1 ) Q(y_j=1) Q(yj=1) 的公式值相同,但负样本中 y j = 1 y_j=1 yj=1 会被排除。

3. 实际应用示例

在推荐系统中:

  • 输入 :用户特征x(如用户ID的嵌入向量),正样本y(用户点击的物品ID)。
  • 输出 :从物品库中采样k个未点击的物品作为负样本y_j,计算对比损失。

总结

NCE Loss通过num_sampled参数控制负样本数量,从噪声分布Q(y)中采样y_j,与正样本y共同训练模型区分正负样本的能力。其核心优势是避免了全类别Softmax的计算开销。

相关推荐
看到我,请让我去学习3 分钟前
OpenCV开发-初始概念
人工智能·opencv·计算机视觉
汀沿河3 分钟前
8.1 prefix Tunning与Prompt Tunning模型微调方法
linux·运维·服务器·人工智能
陈敬雷-充电了么-CEO兼CTO13 分钟前
大模型技术原理 - 基于Transformer的预训练语言模型
人工智能·深度学习·语言模型·自然语言处理·chatgpt·aigc·transformer
学术 学术 Fun19 分钟前
✨ OpenAudio S1:影视级文本转语音与语音克隆Mac整合包
人工智能·语音识别
用户Taobaoapi20141 小时前
母婴用品社媒种草效果量化:淘宝详情API+私域转化追踪案例
大数据·数据挖掘·数据分析
用户Taobaoapi20141 小时前
Taobao agent USA丨美国淘宝代购1688代采集运系统搭建指南
数据挖掘·php
风铃喵游1 小时前
让大模型调用MCP服务变得超级简单
前端·人工智能
booooooty1 小时前
基于Spring AI Alibaba的多智能体RAG应用
java·人工智能·spring·多智能体·rag·spring ai·ai alibaba
PyAIExplorer2 小时前
基于 OpenCV 的图像 ROI 切割实现
人工智能·opencv·计算机视觉
风口猪炒股指标2 小时前
技术分析、超短线打板模式与情绪周期理论,在市场共识的形成、分歧、瓦解过程中缘起性空的理解
人工智能·博弈论·群体博弈·人生哲学·自我引导觉醒