关于CLS与mean_pooling的一些笔记

在自然语言处理领域,使用 CLSToken 还是 mean_pooling 的选择完全取决于模型的架构和具体训练目标。

1. CLS Token (Pooler Output ,output1)

在许多Transformer模型中,每个序列开头都会插入一个特殊的 [CLS] Token(分类)。

  • 逻辑: 该模型被训练为将整个句子的语义摘要集中到这个单一向量中。
  • 数学: 最终表示通常是最后一个隐藏状态(output0)的第一个向量,通常通过线性层和 激活函数生成"output1"。
  • 最佳使用场景: 它在意图分类或命令匹配等需要高度区分的任务中非常有效。
  • 可辨别性: 它专门设计用来区分听起来相似但语义不同的输入。

2. Mean Pooling (平均池化)

Mean Pooling是一种对Token Emnedding 序列(Output 0)进行的手动操作。

  • 逻辑: 它计算序列中所有符号向量的数学平均值(不包括填充符号)。
  • 数学: 对于长度为 的序列,嵌入 计算为
  • 最佳使用场景: 它是专门训练语义文本相似性(STS)模型中的标准,如Sentence-BERT(SBERT)。
  • 冗余: 它之所以有效,是因为注意力机制确保每个标记都包含句子其他部分的一些上下文。

3. 当前偏好的比较

特点 CLS 代币策略 即池策略
主要资料来源 Output1(或 output[0][0] output0(平均值)
培训要求 需要基于CLS的损耗函数 需要基于池化的损耗函数
结果向量 高度"总结" 数学上的"平均值"
无关标记影响 低(忽略无关紧要的标记) 中等(平均所有代币)

实施摘要

如果模型配置明确启用 pooling_mode_cls_token 和禁用 pooling_mode_mean_tokens,使用mean pooling 在数学上与模型内部权重不一致。在这种情况下,CLSToken(Output1)将提供更优越的准确性和可辨别性。

ONNX模型使用不同Embedding的代码:

复制代码
import numpy as np
import torch
from transformers import XLMRobertaModel, AutoModel, AutoTokenizer,XLMRobertaTokenizer, AlbertForMaskedLM
from sentence_transformers import SentenceTransformer
import onnx
import onnxruntime as ort

def l2_normalize(x, axis=-1, eps=1e-12):
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def mean_pooling(token_embeddings, attention_mask):
    """
    token_embeddings: np.ndarray [1, T, H]
    attention_mask:   np.ndarray [1, T] (0/1)
    """

    # Expand mask to [1, T, 1]
    mask = attention_mask[:, :, None].astype(np.float32)

    # Masked sum
    summed = np.sum(token_embeddings * mask, axis=1)

    # Count of valid tokens
    counts = np.sum(mask, axis=1)

    # Avoid division by zero
    pooled = summed / np.maximum(counts, 1e-9)

class OnnxModel:

    def __init__(self, onnx_model_path, tokenizer_path):
        print('model', onnx_model_path)
        print('tokenizer', tokenizer_path)
        self.onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(self.onnx_model)
        self.ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        self.input_name = self.ort_session.get_inputs()[0].name
        self.attention_mask_name = self.ort_session.get_inputs()[1].name

        self.encoder_session = self.ort_session
        self.enc_input_ids = self.encoder_session.get_inputs()[0].name
        self.enc_attention_mask = self.encoder_session.get_inputs()[1].name

    def encode(self, text: str, dim: int = 384) -> np.ndarray:
        # --- 1. Tokenize (NO padding) ---
        tokens = self.tokenizer(
            text,
            return_tensors="np",
            padding=False,
            truncation=False
        )

        input_ids = tokens["input_ids"].astype(np.int64)          # [1, T]
        attention_mask = tokens["attention_mask"].astype(np.int64)  # [1, T]

        # --- 2. Run encoder ---
        encoder_outputs = self.encoder_session.run(
            None,
            {
                self.enc_input_ids: input_ids,
                self.enc_attention_mask: attention_mask,
            }
        )

        
        # 3.1 sentence_embedding 
        sentence_embedding = encoder_outputs[1][0]
        sentence_embedding_norm = l2_normalize(sentence_embedding)

        # 3.2 cls_embedding # last_hidden_state: [1, T, H]
        token_embeddings = encoder_outputs[0]
        cls_embedding = token_embeddings[:, 0]
        cls_embedding_norm = l2_normalize(cls_embedding)

        # 3.3 mean_pooling_embedding 
        token_embeddings = encoder_outputs[0]
        # --- Pool to create 'sentence' embedding
        mean_pooling_embedding = mean_pooling(token_embeddings, attention_mask)
        mean_pooling_embedding_norm = l2_normalize(pooled)[0]

        # --- 4. Truncate to dim if needed ---
        # if embedding.shape[0] > dim:
        #     embedding = embedding[:dim]

        return sentence_embedding_norm # 根据需要返回不同的Embedding
相关推荐
小饕14 小时前
RAG学习之【向量数据库】Milvus 从入门到精通:索引、检索、混合搜索一篇打通(RAG 必备)
数据库·人工智能·学习·milvus
华奥系科技14 小时前
汛期城市内涝治理:智慧水务如何重塑防汛“安全感”?
大数据·运维·人工智能
aneasystone本尊14 小时前
给小龙虾配齐工具箱:OpenClaw 的工具体系
人工智能
m0_7186774914 小时前
EaseChart:免费的流程图编辑器和付费的AI流程图Agent
人工智能
不羁的木木14 小时前
HarmonyOS AI开发提效工具:DevEco Code & DevEco CLI - 跨设备调试与AI应用部署
人工智能·华为·harmonyos·鸿蒙
我的世界洛天依14 小时前
胡桃讲编程:麻宫雅典娜 97 RVCv2 第一代(R1)开源发布文档 | 经典复古分支
人工智能
zhangfeng113314 小时前
JupyterLab 里,JSON文件纯文本格式编辑 / 查看
人工智能·json
Bode_200214 小时前
智能协同与绿色数字孪生舱主要功能与关键技术
大数据·人工智能·制造·碳中和
daly52014 小时前
人工智能专业有哪些?2026高考报考指南(专业分类 + 课程 + 就业全解析)
人工智能·分类·高考
暗夜猎手-大魔王14 小时前
转载--AgentScope 生产最佳实践
人工智能