关于CLS与mean_pooling的一些笔记

在自然语言处理领域,使用 [CLS]Token 还是 mean_pooling 的选择完全取决于模型的架构和具体训练目标。

1. [CLS] Token (Pooler Output ,output[1])

在许多Transformer模型中,每个序列开头都会插入一个特殊的 [CLS] Token(分类)。

  • 逻辑: 该模型被训练为将整个句子的语义摘要集中到这个单一向量中。
  • 数学: 最终表示通常是最后一个隐藏状态(output[0])的第一个向量,通常通过线性层和 激活函数生成"output[1]"。
  • 最佳使用场景: 它在意图分类或命令匹配等需要高度区分的任务中非常有效。
  • 可辨别性: 它专门设计用来区分听起来相似但语义不同的输入。

2. Mean Pooling (平均池化)

Mean Pooling是一种对Token Emnedding 序列(Output [0])进行的手动操作。

  • 逻辑: 它计算序列中所有符号向量的数学平均值(不包括填充符号)。
  • 数学: 对于长度为 的序列,嵌入 计算为
  • 最佳使用场景: 它是专门训练语义文本相似性(STS)模型中的标准,如Sentence-BERT(SBERT)。
  • 冗余: 它之所以有效,是因为注意力机制确保每个标记都包含句子其他部分的一些上下文。

3. 当前偏好的比较

特点 [CLS] 代币策略 即池策略
主要资料来源 Output[1](或 output[0][0] output[0](平均值)
培训要求 需要基于CLS的损耗函数 需要基于池化的损耗函数
结果向量 高度"总结" 数学上的"平均值"
无关标记影响 低(忽略无关紧要的标记) 中等(平均所有代币)

实施摘要

如果模型配置明确启用 pooling_mode_cls_token 和禁用 pooling_mode_mean_tokens,使用mean pooling 在数学上与模型内部权重不一致。在这种情况下,[CLS]Token(Output[1])将提供更优越的准确性和可辨别性。

ONNX模型使用不同Embedding的代码:

复制代码
import numpy as np
import torch
from transformers import XLMRobertaModel, AutoModel, AutoTokenizer,XLMRobertaTokenizer, AlbertForMaskedLM
from sentence_transformers import SentenceTransformer
import onnx
import onnxruntime as ort

def l2_normalize(x, axis=-1, eps=1e-12):
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def mean_pooling(token_embeddings, attention_mask):
    """
    token_embeddings: np.ndarray [1, T, H]
    attention_mask:   np.ndarray [1, T] (0/1)
    """

    # Expand mask to [1, T, 1]
    mask = attention_mask[:, :, None].astype(np.float32)

    # Masked sum
    summed = np.sum(token_embeddings * mask, axis=1)

    # Count of valid tokens
    counts = np.sum(mask, axis=1)

    # Avoid division by zero
    pooled = summed / np.maximum(counts, 1e-9)

class OnnxModel:

    def __init__(self, onnx_model_path, tokenizer_path):
        print('model', onnx_model_path)
        print('tokenizer', tokenizer_path)
        self.onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(self.onnx_model)
        self.ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        self.input_name = self.ort_session.get_inputs()[0].name
        self.attention_mask_name = self.ort_session.get_inputs()[1].name

        self.encoder_session = self.ort_session
        self.enc_input_ids = self.encoder_session.get_inputs()[0].name
        self.enc_attention_mask = self.encoder_session.get_inputs()[1].name

    def encode(self, text: str, dim: int = 384) -> np.ndarray:
        # --- 1. Tokenize (NO padding) ---
        tokens = self.tokenizer(
            text,
            return_tensors="np",
            padding=False,
            truncation=False
        )

        input_ids = tokens["input_ids"].astype(np.int64)          # [1, T]
        attention_mask = tokens["attention_mask"].astype(np.int64)  # [1, T]

        # --- 2. Run encoder ---
        encoder_outputs = self.encoder_session.run(
            None,
            {
                self.enc_input_ids: input_ids,
                self.enc_attention_mask: attention_mask,
            }
        )

        
        # 3.1 sentence_embedding 
        sentence_embedding = encoder_outputs[1][0]
        sentence_embedding_norm = l2_normalize(sentence_embedding)

        # 3.2 cls_embedding # last_hidden_state: [1, T, H]
        token_embeddings = encoder_outputs[0]
        cls_embedding = token_embeddings[:, 0]
        cls_embedding_norm = l2_normalize(cls_embedding)

        # 3.3 mean_pooling_embedding 
        token_embeddings = encoder_outputs[0]
        # --- Pool to create 'sentence' embedding
        mean_pooling_embedding = mean_pooling(token_embeddings, attention_mask)
        mean_pooling_embedding_norm = l2_normalize(pooled)[0]

        # --- 4. Truncate to dim if needed ---
        # if embedding.shape[0] > dim:
        #     embedding = embedding[:dim]

        return sentence_embedding_norm # 根据需要返回不同的Embedding
相关推荐
星越华夏20 小时前
计算机视觉:YOLOv12安装环境
人工智能·yolo·计算机视觉
二哈赛车手21 小时前
新人笔记---ApiFox的一些常见使用出错
java·笔记·spring
Yolanda9421 小时前
【人工智能】《从零搭建AI问答助手项目(九):Prompt优化》
人工智能·prompt
wj30558537821 小时前
课程 9:模型测试记录与 Prompt 策略
linux·人工智能·python·comfyui
小和尚同志21 小时前
深入使用 skill-creator:结合真实生产级实践
人工智能·aigc
DevSecOps选型指南21 小时前
安全419专访悬镜安全 | 穿越周期在 AI 浪潮中定义数字供应链安全新范式
人工智能
沪漂阿龙1 天前
面试题详解:GraphRAG 全面解析——知识图谱增强 RAG、Local Search、Global Search、社区摘要、工程落地与评估指标一次讲透
人工智能·知识图谱
WangN21 天前
Unitree RL Lab 学习笔记【通识】
人工智能·机器学习
haina20191 天前
海纳AI亮相《科创中国》,解码招聘“智”变之路
人工智能·ai面试·ai招聘
阿星AI工作室1 天前
刘润年中大课笔记:一句话说清AI落地之战的本质
大数据·人工智能·创业创新·商业