关于CLS与mean_pooling的一些笔记

在自然语言处理领域,使用 [CLS]Token 还是 mean_pooling 的选择完全取决于模型的架构和具体训练目标。

1. [CLS] Token (Pooler Output ,output[1])

在许多Transformer模型中,每个序列开头都会插入一个特殊的 [CLS] Token(分类)。

  • 逻辑: 该模型被训练为将整个句子的语义摘要集中到这个单一向量中。
  • 数学: 最终表示通常是最后一个隐藏状态(output[0])的第一个向量,通常通过线性层和 激活函数生成"output[1]"。
  • 最佳使用场景: 它在意图分类或命令匹配等需要高度区分的任务中非常有效。
  • 可辨别性: 它专门设计用来区分听起来相似但语义不同的输入。

2. Mean Pooling (平均池化)

Mean Pooling是一种对Token Emnedding 序列(Output [0])进行的手动操作。

  • 逻辑: 它计算序列中所有符号向量的数学平均值(不包括填充符号)。
  • 数学: 对于长度为 的序列,嵌入 计算为
  • 最佳使用场景: 它是专门训练语义文本相似性(STS)模型中的标准,如Sentence-BERT(SBERT)。
  • 冗余: 它之所以有效,是因为注意力机制确保每个标记都包含句子其他部分的一些上下文。

3. 当前偏好的比较

特点 [CLS] 代币策略 即池策略
主要资料来源 Output[1](或 output[0][0] output[0](平均值)
培训要求 需要基于CLS的损耗函数 需要基于池化的损耗函数
结果向量 高度"总结" 数学上的"平均值"
无关标记影响 低(忽略无关紧要的标记) 中等(平均所有代币)

实施摘要

如果模型配置明确启用 pooling_mode_cls_token 和禁用 pooling_mode_mean_tokens,使用mean pooling 在数学上与模型内部权重不一致。在这种情况下,[CLS]Token(Output[1])将提供更优越的准确性和可辨别性。

ONNX模型使用不同Embedding的代码:

复制代码
import numpy as np
import torch
from transformers import XLMRobertaModel, AutoModel, AutoTokenizer,XLMRobertaTokenizer, AlbertForMaskedLM
from sentence_transformers import SentenceTransformer
import onnx
import onnxruntime as ort

def l2_normalize(x, axis=-1, eps=1e-12):
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def mean_pooling(token_embeddings, attention_mask):
    """
    token_embeddings: np.ndarray [1, T, H]
    attention_mask:   np.ndarray [1, T] (0/1)
    """

    # Expand mask to [1, T, 1]
    mask = attention_mask[:, :, None].astype(np.float32)

    # Masked sum
    summed = np.sum(token_embeddings * mask, axis=1)

    # Count of valid tokens
    counts = np.sum(mask, axis=1)

    # Avoid division by zero
    pooled = summed / np.maximum(counts, 1e-9)

class OnnxModel:

    def __init__(self, onnx_model_path, tokenizer_path):
        print('model', onnx_model_path)
        print('tokenizer', tokenizer_path)
        self.onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(self.onnx_model)
        self.ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        self.input_name = self.ort_session.get_inputs()[0].name
        self.attention_mask_name = self.ort_session.get_inputs()[1].name

        self.encoder_session = self.ort_session
        self.enc_input_ids = self.encoder_session.get_inputs()[0].name
        self.enc_attention_mask = self.encoder_session.get_inputs()[1].name

    def encode(self, text: str, dim: int = 384) -> np.ndarray:
        # --- 1. Tokenize (NO padding) ---
        tokens = self.tokenizer(
            text,
            return_tensors="np",
            padding=False,
            truncation=False
        )

        input_ids = tokens["input_ids"].astype(np.int64)          # [1, T]
        attention_mask = tokens["attention_mask"].astype(np.int64)  # [1, T]

        # --- 2. Run encoder ---
        encoder_outputs = self.encoder_session.run(
            None,
            {
                self.enc_input_ids: input_ids,
                self.enc_attention_mask: attention_mask,
            }
        )

        
        # 3.1 sentence_embedding 
        sentence_embedding = encoder_outputs[1][0]
        sentence_embedding_norm = l2_normalize(sentence_embedding)

        # 3.2 cls_embedding # last_hidden_state: [1, T, H]
        token_embeddings = encoder_outputs[0]
        cls_embedding = token_embeddings[:, 0]
        cls_embedding_norm = l2_normalize(cls_embedding)

        # 3.3 mean_pooling_embedding 
        token_embeddings = encoder_outputs[0]
        # --- Pool to create 'sentence' embedding
        mean_pooling_embedding = mean_pooling(token_embeddings, attention_mask)
        mean_pooling_embedding_norm = l2_normalize(pooled)[0]

        # --- 4. Truncate to dim if needed ---
        # if embedding.shape[0] > dim:
        #     embedding = embedding[:dim]

        return sentence_embedding_norm # 根据需要返回不同的Embedding
相关推荐
badhope2 小时前
Mobile-Skills:移动端技能可视化的创新实践
开发语言·人工智能·git·智能手机·github
吴佳浩4 小时前
GPU 编号进阶:CUDA\_VISIBLE\_DEVICES、多进程与容器化陷阱
人工智能·pytorch·python
吴佳浩4 小时前
GPU 编号错乱踩坑指南:PyTorch cuda 编号与 nvidia-smi 不一致
人工智能·pytorch·nvidia
小饕4 小时前
苏格拉底式提问对抗315 AI投毒:实操指南
网络·人工智能
卧蚕土豆4 小时前
【有啥问啥】OpenClaw 安装与使用教程
人工智能·深度学习
GoCodingInMyWay4 小时前
开源好物 26/03
人工智能·开源
AI科技星4 小时前
全尺度角速度统一:基于 v ≡ c 的纯推导与验证
c语言·开发语言·人工智能·opencv·算法·机器学习·数据挖掘
zhangfeng11335 小时前
Windows 的 Git Bash 中使用 md5sum 命令非常简单 md5做文件完整性检测 WinRAR 可以计算文件的 MD5 值
人工智能·windows·git·bash
hjxu20165 小时前
【OpenClaw 龙虾养成笔记一】在远程服务器,使用Docker安装OpenClaw
服务器·笔记·docker
monsion5 小时前
OpenCode 学习指南
人工智能·vscode·架构