【蒸馏用损失】NCEloss介绍,大规模分类任务的损失函数

NCE Loss(Noise Contrastive Estimation Loss,噪声对比估计损失)是一种用于大规模分类任务的损失函数,尤其适用于类别数量极多(如词表预测、推荐系统召回等场景)时替代传统Softmax的方法。其核心思想是将多分类问题转化为二分类问题,通过对比正样本与负样本的差异来优化模型。

第一节、核心原理

  1. 问题转化

    • 传统Softmax需要计算所有类别的概率归一化项(分母),当类别数 N N N极大时计算成本高昂。NCE Loss通过采样少量负样本(如 k k k个),将问题简化为 k + 1 k+1 k+1个二分类任务(1个正样本 + k k k个负样本)。正负样本是对这些输入向量的目标类别(标签)的划分 。例如,在Word2Vec中,inputs是中心词的嵌入向量,正样本是其上下文词(标签)负样本是从噪声分布中采样的其他词
  2. 损失函数形式

    • 对于正样本 ( x , y ) (x, y) (x,y)和负样本 ( x , y j ) (x, y_j) (x,yj),NCE Loss定义为:
      L = − log ⁡ σ ( s ( x , y ) ) − ∑ j = 1 k log ⁡ σ ( − s ( x , y j ) ) L = -\log \sigma(s(x,y)) - \sum_{j=1}^k \log \sigma(-s(x,y_j)) L=−logσ(s(x,y))−j=1∑klogσ(−s(x,yj))
      其中 σ \sigma σ为Sigmoid函数, s ( x , y ) s(x,y) s(x,y)为模型对样本 ( x , y ) (x,y) (x,y)的得分(如点积)。

在NCE Loss(Noise Contrastive Estimation Loss)的上下文中, s ( x , y ) s(x,y) s(x,y) 通常表示模型对输入对 ( x , y ) (x,y) (x,y) 的匹配度得分,具体含义如下:

x x x 的典型含义

  • 在推荐系统中, x x x 通常代表用户特征(如用户ID、历史行为等)或用户向量(如双塔模型中的用户塔输出)。
  • 在自然语言处理中, x x x 可能是上下文词向量(如Word2Vec中的中心词或上下文窗口)。

y y y 的典型含义

  • 在推荐系统中, y y y 代表物品(Item),例如用户点击过的正样本物品或采样的负样本物品。
  • 在NLP中, y y y 可能是目标词(如语言模型中的下一个词)或噪声词(负采样词)。

s ( x , y ) s(x,y) s(x,y) 的计算方式

  • 一般为 x x x 和 y y y 的点积(如双塔模型中的用户向量与物品向量点积)。

  • 在NCE中, s ( x , y ) s(x,y) s(x,y) 会经过修正: g ( x , y ) = s ( x , y ) − log ⁡ Q ( y ∣ x ) g(x,y) = s(x,y) - \log Q(y|x) g(x,y)=s(x,y)−logQ(y∣x),其中 Q ( y ∣ x ) Q(y|x) Q(y∣x) 是负采样分布的概率。

  • x x x:输入特征(用户/上下文)

  • y y y:目标类别(物品/词)

  • s ( x , y ) s(x,y) s(x,y) :模型对 ( x , y ) (x,y) (x,y) 匹配度的原始得分,用于后续损失计算。

  1. 负采样策略 : (见第三节)
    • 负样本通常从噪声分布 Q ( y ) Q(y) Q(y)中采样,常见策略包括:
      • 均匀采样:所有类别等概率。
      • 频次采样:按类别出现频率采样,热门类别更易被选中。
      • 调节采样 :通过参数 b b b控制热门类别的采样概率(如 Q ( y ) ∝ f ( y ) b Q(y) \propto f(y)^b Q(y)∝f(y)b, f ( y ) f(y) f(y)为频率)。

优势与特性

  • 计算高效:仅需计算少量负样本的得分,避免全类别Softmax的开销。
  • 理论保证 :当负样本数 k k k足够大时,NCE Loss的梯度趋近于Softmax的梯度。
  • 灵活性:可通过调整负采样分布适应不同任务需求(如打压热门物品)。

应用场景

  1. 自然语言处理
    • Word2Vec中采用NCE或负采样(NEG,NCE的简化版)训练词向量。
  2. 推荐系统
    • 召回阶段用NCE Loss区分用户正反馈物品与随机负样本。
  3. 对比学习
    • 如CLIP模型使用InfoNCE(NCE的变体)对齐图文特征。

变体与改进

  • NEG(Negative Sampling)
    忽略NCE中的分布修正项,进一步简化计算,但无理论收敛保证。
  • InfoNCE
    引入温度系数 τ \tau τ调节分布尖锐程度,常用于对比学习。
  • 硬负样本挖掘
    选择模型当前易混淆的负样本(Hard Negative)提升训练效果。

第二节、代码实现示例

python 复制代码
loss = nce_loss(
    weights=embedding_matrix,  # 类别嵌入矩阵
    biases=biases,             # 类别偏置
    inputs=input_embeddings,   # 输入特征
    labels=positive_labels,    # 正样本标签
    num_sampled=num_negatives, # 负样本数
    num_classes=vocab_size     # 总类别数
)

在NCE Loss的实现中,参数inputslabels分别对应数学公式中的 x x x和 y y y,具体含义如下:

x x x(输入特征)

  • 对应代码中的inputs=input_embeddings,表示模型输入的向量(如用户特征、上下文词向量等)。
  • 例如,在推荐系统中, x x x可能是用户嵌入向量;在NLP中, x x x可能是中心词或上下文窗口的向量表示。

y y y(目标类别)

  • 对应代码中的labels=positive_labels,表示正样本的类别标签(如物品ID、目标词等)。
  • 例如,推荐系统中 y y y是用户点击的物品,NLP中 y y y是目标词或噪声词。

补充说明:

  • s ( x , y ) s(x,y) s(x,y)的计算
    通常为 x x x和 y y y的嵌入向量点积(如matmul(inputs, weights)),即模型对 ( x , y ) (x,y) (x,y)匹配度的原始得分。
  • 负样本处理
    NCE Loss通过num_sampled参数采样负样本(噪声分布 Q ( y ) Q(y) Q(y)中的 y j y_j yj),与正样本 y y y共同构成二分类任务。

第三节、负样本处理

在NCE Loss中,通过num_sampled参数采样负样本的过程如下:

1. 输入与输出

  • 输入

    • 正样本标签labels(形状为(batch_size, num_true)),表示当前输入x对应的真实类别y
    • 噪声分布Q(y),即负样本的采样分布(如均匀分布、频次采样等)。
    • 负样本数量num_sampled,指定每个正样本采样的负样本数k
  • 输出

    • 采样的负样本 :从噪声分布Q(y)中采样k个类别y_jj=1,2,...,k),与正样本y共同构成k+1个二分类任务。

2. 负样本采样机制

NCE Loss通过以下步骤采样负样本:

  1. 噪声分布选择
    • 均匀采样 :所有类别等概率被选中(Q(y) = 1/NN为总类别数)。
    • 频次采样 :类别y的采样概率与其在数据集中出现的频率成正比(Q(y) ∝ f(y))。
    • 调节采样 :通过调节因子b控制热门类别的采样概率(Q(y) ∝ f(y)^bb=0时为均匀采样,b=1时为频次采样)。
    • log-uniform采样 : ( P ( k ) P(k) P(k)是 Q ( y ) Q(y) Q(y)的一种实现方式:) P ( k ) = log ⁡ ( k + 2 ) − log ⁡ ( k + 1 ) log ⁡ ( range_max + 1 ) P(k) = \frac{\log(k+2) - \log(k+1)}{\log(\text{range\_max} + 1)} P(k)=log(range_max+1)log(k+2)−log(k+1)
      其中k为类别ID,range_max=num_classes。该分布使得高频类别(大k)被采样的概率更低。
采样类型 概率分布 特点 适用场景
均匀采样 P ( k ) = 1 N P(k) = \frac{1}{N} P(k)=N1 所有类别等概率 类别分布未知或均匀时
频次采样 P ( k ) ∝ f ( k ) P(k) \propto f(k) P(k)∝f(k) 直接按类别出现频率 f ( k ) f(k) f(k)采样 需精确匹配数据分布时
log-uniform采样 P ( k ) ∝ 1 k P(k) \propto \frac{1}{k} P(k)∝k1 对高频类别降权,平衡采样效率与覆盖 词频、推荐物品ID等长尾数据

2. 数学形式

对于每个正样本(x, y),NCE Loss构造的损失函数为:
L = − log ⁡ σ ( s ( x , y ) − log ⁡ Q ( y ) ) − ∑ j = 1 k log ⁡ σ ( − s ( x , y j ) + log ⁡ Q ( y j ) ) L = -\log \sigma(s(x,y) - \log Q(y)) - \sum_{j=1}^k \log \sigma(-s(x,y_j) + \log Q(y_j)) L=−logσ(s(x,y)−logQ(y))−j=1∑klogσ(−s(x,yj)+logQ(yj))

其中:

  • s(x,y)是模型对(x,y)的匹配得分(如点积)。
  • σ为Sigmoid函数。
  • Q(y)是噪声分布的概率。

正样本 y = 1 y=1 y=1 和负样本 y j = 1 y_j=1 yj=1

  • 若正样本 y = 1 y=1 y=1 和负样本 y j = 1 y_j=1 yj=1 理论上 的噪声概率相同(即 Q ( y = 1 ) = Q ( y j = 1 ) Q(y=1) = Q(y_j=1) Q(y=1)=Q(yj=1)),但实际采样时 y j y_j yj 不会与 y y y 重复。
  • 若任务中允许 y j = y y_j=y yj=y(罕见情况),则两者的噪声概率确实相同,但这种情况会破坏对比学习的有效性。

假设类别总数 N = 100 N=100 N=100,采用均匀分布:

  • 正样本 y = 1 y=1 y=1 的 Q ( y = 1 ) = 1 100 Q(y=1) = \frac{1}{100} Q(y=1)=1001。
  • 负样本 y j y_j yj 的 Q ( y j ) Q(y_j) Q(yj) 也是 1 100 \frac{1}{100} 1001,但实际采样时 y j ∈ { 2 , 3 , . . . , 100 } y_j \in \{2, 3, ..., 100\} yj∈{2,3,...,100},因此 Q ( y j = 1 ) Q(y_j=1) Q(yj=1) 不会被计算。

若采用 Log-Uniform 分布, Q ( y = 1 ) Q(y=1) Q(y=1) 和 Q ( y j = 1 ) Q(y_j=1) Q(yj=1) 的公式值相同,但负样本中 y j = 1 y_j=1 yj=1 会被排除。

3. 实际应用示例

在推荐系统中:

  • 输入 :用户特征x(如用户ID的嵌入向量),正样本y(用户点击的物品ID)。
  • 输出 :从物品库中采样k个未点击的物品作为负样本y_j,计算对比损失。

总结

NCE Loss通过num_sampled参数控制负样本数量,从噪声分布Q(y)中采样y_j,与正样本y共同训练模型区分正负样本的能力。其核心优势是避免了全类别Softmax的计算开销。

相关推荐
nuise_几秒前
李宏毅机器学习笔记06 | 鱼和熊掌可以兼得的机器学习 - 内容接宝可梦
人工智能·笔记·机器学习
声网14 分钟前
MiniMax 发布新 TTS 模型 Speech-02,轻松制作长篇有声内容;Meta 高端眼镜年底推出:售价上千美元丨日报
人工智能
HeteroCat18 分钟前
OpenAI 官方学院 -- 提示词课程要点
人工智能·chatgpt
每天做一点改变20 分钟前
AI Agent成为行业竞争新焦点:技术革新与商业重构的双重浪潮
人工智能·重构
大美B端工场-B端系统美颜师23 分钟前
定制化管理系统与通用管理系统,谁更胜一筹?
人工智能·信息可视化·数据挖掘·数据分析
生信小鹏23 分钟前
Nature旗下 | npj Digital Medicine | 图像+转录组+临床变量三合一,多模态AI预测化疗反应,值得复现学习的完整框架
人工智能·学习·免疫治疗·scrna-seq·scrna
OpenLoong 开源社区37 分钟前
技术视界 | 从哲学到技术:人形机器人感知导航的探索(下篇)
人工智能·机器人·开源社区·人形机器人·openloong
csssnxy1 小时前
叁仟数智指路机器人的主要功能有哪些?
人工智能
蝎蟹居1 小时前
GB/T 4706.1-2024 家用和类似用途电器的安全 第1部分:通用要求 与2005版差异(1)
人工智能·单片机·嵌入式硬件·物联网·安全
浊酒南街1 小时前
TensorFlow实现逻辑回归
人工智能·tensorflow·逻辑回归