关于CLS与mean_pooling的一些笔记

在自然语言处理领域,使用 [CLS]Token 还是 mean_pooling 的选择完全取决于模型的架构和具体训练目标。

1. [CLS] Token (Pooler Output ,output[1])

在许多Transformer模型中,每个序列开头都会插入一个特殊的 [CLS] Token(分类)。

  • 逻辑: 该模型被训练为将整个句子的语义摘要集中到这个单一向量中。
  • 数学: 最终表示通常是最后一个隐藏状态(output[0])的第一个向量,通常通过线性层和 激活函数生成"output[1]"。
  • 最佳使用场景: 它在意图分类或命令匹配等需要高度区分的任务中非常有效。
  • 可辨别性: 它专门设计用来区分听起来相似但语义不同的输入。

2. Mean Pooling (平均池化)

Mean Pooling是一种对Token Emnedding 序列(Output [0])进行的手动操作。

  • 逻辑: 它计算序列中所有符号向量的数学平均值(不包括填充符号)。
  • 数学: 对于长度为 的序列,嵌入 计算为
  • 最佳使用场景: 它是专门训练语义文本相似性(STS)模型中的标准,如Sentence-BERT(SBERT)。
  • 冗余: 它之所以有效,是因为注意力机制确保每个标记都包含句子其他部分的一些上下文。

3. 当前偏好的比较

特点 [CLS] 代币策略 即池策略
主要资料来源 Output[1](或 output[0][0] output[0](平均值)
培训要求 需要基于CLS的损耗函数 需要基于池化的损耗函数
结果向量 高度"总结" 数学上的"平均值"
无关标记影响 低(忽略无关紧要的标记) 中等(平均所有代币)

实施摘要

如果模型配置明确启用 pooling_mode_cls_token 和禁用 pooling_mode_mean_tokens,使用mean pooling 在数学上与模型内部权重不一致。在这种情况下,[CLS]Token(Output[1])将提供更优越的准确性和可辨别性。

ONNX模型使用不同Embedding的代码:

复制代码
import numpy as np
import torch
from transformers import XLMRobertaModel, AutoModel, AutoTokenizer,XLMRobertaTokenizer, AlbertForMaskedLM
from sentence_transformers import SentenceTransformer
import onnx
import onnxruntime as ort

def l2_normalize(x, axis=-1, eps=1e-12):
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def mean_pooling(token_embeddings, attention_mask):
    """
    token_embeddings: np.ndarray [1, T, H]
    attention_mask:   np.ndarray [1, T] (0/1)
    """

    # Expand mask to [1, T, 1]
    mask = attention_mask[:, :, None].astype(np.float32)

    # Masked sum
    summed = np.sum(token_embeddings * mask, axis=1)

    # Count of valid tokens
    counts = np.sum(mask, axis=1)

    # Avoid division by zero
    pooled = summed / np.maximum(counts, 1e-9)

class OnnxModel:

    def __init__(self, onnx_model_path, tokenizer_path):
        print('model', onnx_model_path)
        print('tokenizer', tokenizer_path)
        self.onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(self.onnx_model)
        self.ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        self.input_name = self.ort_session.get_inputs()[0].name
        self.attention_mask_name = self.ort_session.get_inputs()[1].name

        self.encoder_session = self.ort_session
        self.enc_input_ids = self.encoder_session.get_inputs()[0].name
        self.enc_attention_mask = self.encoder_session.get_inputs()[1].name

    def encode(self, text: str, dim: int = 384) -> np.ndarray:
        # --- 1. Tokenize (NO padding) ---
        tokens = self.tokenizer(
            text,
            return_tensors="np",
            padding=False,
            truncation=False
        )

        input_ids = tokens["input_ids"].astype(np.int64)          # [1, T]
        attention_mask = tokens["attention_mask"].astype(np.int64)  # [1, T]

        # --- 2. Run encoder ---
        encoder_outputs = self.encoder_session.run(
            None,
            {
                self.enc_input_ids: input_ids,
                self.enc_attention_mask: attention_mask,
            }
        )

        
        # 3.1 sentence_embedding 
        sentence_embedding = encoder_outputs[1][0]
        sentence_embedding_norm = l2_normalize(sentence_embedding)

        # 3.2 cls_embedding # last_hidden_state: [1, T, H]
        token_embeddings = encoder_outputs[0]
        cls_embedding = token_embeddings[:, 0]
        cls_embedding_norm = l2_normalize(cls_embedding)

        # 3.3 mean_pooling_embedding 
        token_embeddings = encoder_outputs[0]
        # --- Pool to create 'sentence' embedding
        mean_pooling_embedding = mean_pooling(token_embeddings, attention_mask)
        mean_pooling_embedding_norm = l2_normalize(pooled)[0]

        # --- 4. Truncate to dim if needed ---
        # if embedding.shape[0] > dim:
        #     embedding = embedding[:dim]

        return sentence_embedding_norm # 根据需要返回不同的Embedding
相关推荐
会飞的老朱1 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º3 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
Codebee5 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º6 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys6 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56786 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子6 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
wdfk_prog6 小时前
[Linux]学习笔记系列 -- [drivers][input]input
linux·笔记·学习
ouliten6 小时前
cuda编程笔记(36)-- 应用Tensor Core加速矩阵乘法
笔记·cuda
智驱力人工智能6 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算