关于CLS与mean_pooling的一些笔记

在自然语言处理领域,使用 [CLS]Token 还是 mean_pooling 的选择完全取决于模型的架构和具体训练目标。

1. [CLS] Token (Pooler Output ,output[1])

在许多Transformer模型中,每个序列开头都会插入一个特殊的 [CLS] Token(分类)。

  • 逻辑: 该模型被训练为将整个句子的语义摘要集中到这个单一向量中。
  • 数学: 最终表示通常是最后一个隐藏状态(output[0])的第一个向量,通常通过线性层和 激活函数生成"output[1]"。
  • 最佳使用场景: 它在意图分类或命令匹配等需要高度区分的任务中非常有效。
  • 可辨别性: 它专门设计用来区分听起来相似但语义不同的输入。

2. Mean Pooling (平均池化)

Mean Pooling是一种对Token Emnedding 序列(Output [0])进行的手动操作。

  • 逻辑: 它计算序列中所有符号向量的数学平均值(不包括填充符号)。
  • 数学: 对于长度为 的序列,嵌入 计算为
  • 最佳使用场景: 它是专门训练语义文本相似性(STS)模型中的标准,如Sentence-BERT(SBERT)。
  • 冗余: 它之所以有效,是因为注意力机制确保每个标记都包含句子其他部分的一些上下文。

3. 当前偏好的比较

特点 [CLS] 代币策略 即池策略
主要资料来源 Output[1](或 output[0][0] output[0](平均值)
培训要求 需要基于CLS的损耗函数 需要基于池化的损耗函数
结果向量 高度"总结" 数学上的"平均值"
无关标记影响 低(忽略无关紧要的标记) 中等(平均所有代币)

实施摘要

如果模型配置明确启用 pooling_mode_cls_token 和禁用 pooling_mode_mean_tokens,使用mean pooling 在数学上与模型内部权重不一致。在这种情况下,[CLS]Token(Output[1])将提供更优越的准确性和可辨别性。

ONNX模型使用不同Embedding的代码:

复制代码
import numpy as np
import torch
from transformers import XLMRobertaModel, AutoModel, AutoTokenizer,XLMRobertaTokenizer, AlbertForMaskedLM
from sentence_transformers import SentenceTransformer
import onnx
import onnxruntime as ort

def l2_normalize(x, axis=-1, eps=1e-12):
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def mean_pooling(token_embeddings, attention_mask):
    """
    token_embeddings: np.ndarray [1, T, H]
    attention_mask:   np.ndarray [1, T] (0/1)
    """

    # Expand mask to [1, T, 1]
    mask = attention_mask[:, :, None].astype(np.float32)

    # Masked sum
    summed = np.sum(token_embeddings * mask, axis=1)

    # Count of valid tokens
    counts = np.sum(mask, axis=1)

    # Avoid division by zero
    pooled = summed / np.maximum(counts, 1e-9)

class OnnxModel:

    def __init__(self, onnx_model_path, tokenizer_path):
        print('model', onnx_model_path)
        print('tokenizer', tokenizer_path)
        self.onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(self.onnx_model)
        self.ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        self.input_name = self.ort_session.get_inputs()[0].name
        self.attention_mask_name = self.ort_session.get_inputs()[1].name

        self.encoder_session = self.ort_session
        self.enc_input_ids = self.encoder_session.get_inputs()[0].name
        self.enc_attention_mask = self.encoder_session.get_inputs()[1].name

    def encode(self, text: str, dim: int = 384) -> np.ndarray:
        # --- 1. Tokenize (NO padding) ---
        tokens = self.tokenizer(
            text,
            return_tensors="np",
            padding=False,
            truncation=False
        )

        input_ids = tokens["input_ids"].astype(np.int64)          # [1, T]
        attention_mask = tokens["attention_mask"].astype(np.int64)  # [1, T]

        # --- 2. Run encoder ---
        encoder_outputs = self.encoder_session.run(
            None,
            {
                self.enc_input_ids: input_ids,
                self.enc_attention_mask: attention_mask,
            }
        )

        
        # 3.1 sentence_embedding 
        sentence_embedding = encoder_outputs[1][0]
        sentence_embedding_norm = l2_normalize(sentence_embedding)

        # 3.2 cls_embedding # last_hidden_state: [1, T, H]
        token_embeddings = encoder_outputs[0]
        cls_embedding = token_embeddings[:, 0]
        cls_embedding_norm = l2_normalize(cls_embedding)

        # 3.3 mean_pooling_embedding 
        token_embeddings = encoder_outputs[0]
        # --- Pool to create 'sentence' embedding
        mean_pooling_embedding = mean_pooling(token_embeddings, attention_mask)
        mean_pooling_embedding_norm = l2_normalize(pooled)[0]

        # --- 4. Truncate to dim if needed ---
        # if embedding.shape[0] > dim:
        #     embedding = embedding[:dim]

        return sentence_embedding_norm # 根据需要返回不同的Embedding
相关推荐
副露のmagic15 小时前
深度学习基础复健
人工智能·深度学习
番茄大王sc15 小时前
2026年科研AI工具深度测评(一):文献调研与综述生成领域,维普科创助手领跑学术严谨性
人工智能·深度学习·考研·学习方法·论文笔记
傻小胖15 小时前
13.BTC-思考-北大肖臻老师客堂笔记
笔记·区块链
代码丰15 小时前
SpringAI+RAG向量库+知识图谱+多模型路由+Docker打造SmartHR智能招聘助手
人工智能·spring·知识图谱
独处东汉16 小时前
freertos开发空气检测仪之输入子系统结构体设计
数据结构·人工智能·stm32·单片机·嵌入式硬件·算法
乐迪信息16 小时前
乐迪信息:AI防爆摄像机在船舶监控的应用
大数据·网络·人工智能·算法·无人机
風清掦16 小时前
【江科大STM32学习笔记-04】0.96寸OLED显示屏
笔记·stm32·学习
风栖柳白杨16 小时前
【语音识别】soundfile使用方法
人工智能·语音识别
胡西风_foxww16 小时前
ObsidianAI_学习一个陌生知识领域_建立学习路径和知识库框架_写一本书
人工智能·笔记·学习·知识库·obsidian·notebooklm·写一本书
Hernon16 小时前
AI智能体 - 探索与发现 Clawdbot >> Moltbot
大数据·人工智能·ai智能体·ai开发框架