Transformers中从 logits 本质到问答系统中的字符定位机制

从 logits 本质到问答系统中的字符定位机制

先阅读代码,理解

python 复制代码
from transformers import  AutoTokenizer, AutoModelForQuestionAnswering
import torch

model = AutoModelForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad',
                                                           cache_dir='./model')
tokenizer = AutoTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad', cache_dir='./model', use_fast=True)
# print(tokenizer)
question = "What is the capital of France?"
context = "Paris is the capital of France. It is located in the north-central part of the country on the Seine River. Paris has been one of the world's major centres of finance, diplomacy, commerce, fashion, gastronomy, science, and arts..."
inputs = tokenizer(
    question,  # 问题
    context,  # 长上下文
    padding=True,
    truncation='only_second',  # 只截断长的上下文,保留完整问题
    max_length=384,
    stride=128,  # 允许重叠,以防答案在窗口边界
    return_offsets_mapping=True,  # 为了将答案映射回原文
    return_tensors='pt'
)

print(inputs)


# 提取模型需要的参数(排除offset_mapping)
model_inputs = {
    'input_ids': inputs['input_ids'],
    'attention_mask': inputs['attention_mask']
}

# 如果有token_type_ids
if 'token_type_ids' in inputs:
    model_inputs['token_type_ids'] = inputs['token_type_ids']

with torch.no_grad():
    outputs = model(**model_inputs)
    start = outputs.start_logits
    end = outputs.end_logits

# 找到答案的start和end位置
start_idx = torch.argmax(start)
end_idx = torch.argmax(end)

# 使用offset mapping将token位置映射回字符位置
offset_mapping = inputs['offset_mapping'][0]
start_char = offset_mapping[start_idx][0].item()
end_char = offset_mapping[end_idx][1].item()

# 提取答案
answer = context[start_char:end_char]
print(f"Answer: {answer}")  # 输出: Paris
🔹 第一阶段:logits 的本质与维度来源

你首先提出了一个非常根本性的问题:

"logits 到底是怎么来的?它的维度是谁决定的?"

这个问题看似简单,实则触及了深度学习模型输出层的核心机制。我们通过详细剖析模型前向传播流程,揭示了 logits 的真实含义:

  • logits 是模型最后一层全连接层的原始输出,未经 softmax 归一化,代表模型对每个可能类别的"原始打分"。

  • 它来源于模型主干(如 BERT 编码器)输出的隐藏状态(如 [CLS] 向量或每个 token 的表示),经过一个线性变换:

    logits=�⋅ℎ+�logits=W⋅h+b
    其中 ℎh 是隐藏向量,�W 是分类权重矩阵,�b 是偏置项。

  • logits 的维度由任务类型和输出空间大小决定

    • 文本分类 :维度 = 类别数(如情感分析 3 分类 → [batch, 3]
    • 语言模型 :维度 = 词汇表大小(如 BERT 的 30522 → [batch, seq_len, 30522]
    • 命名实体识别(NER) :维度 = 标签数(如 10 种实体 → [batch, seq_len, 10]
    • 问答任务(QA) :输出两个 logits 向量 ------ start_logitsend_logits,维度为 [batch, seq_len]

你还了解到:训练时通常直接使用 logits 计算损失(如交叉熵),因为 PyTorch 的 F.cross_entropy 内部会自动处理 softmax 和 log 操作,数值更稳定。


🔹 第二阶段:问答任务中为何需要 start 和 end?

在理解了 logits 的基础上,你进一步将问题聚焦到具体任务场景 ------ 问答系统(Question Answering),并提出了一个极具洞察力的问题:

"为什么需要 start 和 end?我只要开始和结束就能找到对应字符?token 顺序会乱吗?"

这个问题直指问答模型的设计哲学。我们通过代码示例和原理分析,明确了以下几点:

  • 答案可能是多 token 的连续片段 ,因此不能仅靠一个位置确定。模型必须分别预测:
    • start_logits:每个 token 作为答案起始位置的可能性
    • end_logits:每个 token 作为答案结束位置的可能性
  • token 的顺序永远不会乱。BERT 的 tokenizer 保证分词后的 token 序列严格保持原文中的出现顺序。
  • 即使答案是一个完整的句子(如 "Paris is the capital of France"),模型也会通过训练学会将该句首尾 token 的得分调高,从而正确预测 start_idxend_idx

🔹 第三阶段:subword 分词与 offset_mapping 的作用

你敏锐地察觉到一个关键矛盾:

如果一个词被拆分成多个 subword token(如 "France" → "Fran" + "##ce"),它们还能连在一起吗?为什么 offset 是 (0,5), (6,8) 这种形式?

这引出了 NLP 中极为重要的 字符级对齐(character alignment)机制

我们通过一个完整的示例说明:

  • BERT 使用 WordPiece 分词器,会将未登录词(OOV)拆分为 subword。
  • 拆分后的 token 用 ## 标记延续性(如 ##ce 表示它是前一个 token 的后缀)。
  • offset_mapping 是一个列表,记录每个 token 在原始文本中的字符起止位置 ,格式为 (start_char, end_char),采用左闭右开区间。
  • 空格、标点、大小写字母都占用字符位置,因此:
    • "Paris"(0,5)(P=0, a=1, r=2, i=3, s=4,结束于 5)
    • 空格在位置 5
    • "is"(6,8)(i=6, s=7,结束于 8)

你进一步理解到:offset_mapping 的存在,使得模型即使在 subword 级别工作,也能精确还原原文中的字符位置,从而提取出保留原始大小写、空格和标点的答案。


🔹 第四阶段:offset_mapping 的具体使用与代码解析

你贴出了关键代码:

复制代码

Python

编辑

复制代码
offset_mapping = inputs['offset_mapping'][0]
start_char = offset_mapping[start_idx][0].item()
end_char = offset_mapping[end_idx][1].item()
answer = context[start_char:end_char]

我们逐行解析,明确了每一部分的作用:

  1. inputs['offset_mapping'][0]:取出第一个样本的 offset 映射表(去掉 batch 维度)
  2. offset_mapping[start_idx][0].item():获取答案起始 token 的起始字符位置,转为 Python 整数
  3. offset_mapping[end_idx][1].item():获取答案结束 token 的结束字符位置
  4. context[start_char:end_char]:直接从原始文本中切片,确保答案格式与原文一致

你还了解到:

  • 特殊 token(如 [CLS], [SEP], padding)的 offset 为 (0,0),需在提取时跳过
  • 需处理 start_idx > end_idx 的异常情况
  • 使用 offset_mappingconvert_tokens_to_string 更准确,因为它保留了原文格式

🔹 第五阶段:offset_mapping 的真实样例可视化

为了让你彻底理解 offset_mapping 的结构,我们构建了一个完整的示例,展示从原始文本到分词再到 offset 的全过程:

Token offset_mapping 原文对应部分
[CLS] (0,0)
what (0,4) "What"
is (5,7) "is"
paris (20,25) "Paris"
fran (44,48) "Fran"
##ce (48,50) "ce"
. (50,51) "."

通过这个表格,你直观地看到:

  • 即使 token 是小写,offset 仍指向原文大写位置
  • 被拆分的词在字符位置上是连续的
  • 空格和标点都占用位置,影响后续 token 的 offset

🎯 总结:你掌握了什么?

通过这次对话,你已经系统性地掌握了以下核心知识:

主题 掌握内容
logits 机制 理解了 logits 是模型的原始输出,其维度由任务决定,是后续 softmax 和 loss 计算的基础
问答模型结构 理解了 start_logitsend_logits 的设计原理,以及为何需要两个输出
token 与文本对齐 掌握了 offset_mapping 的作用,能够在 token 级预测和字符级答案之间自由转换
subword 分词影响 理解了 WordPiece 拆词机制及其对答案提取的影响,知道如何处理 ## 标记
实际工程技巧 学会了如何安全地提取答案,包括处理 padding、异常索引、大小写还原等细节

💡 最终洞见

你所探索的,不仅仅是"如何写几行代码提取答案",而是:

现代 NLP 模型如何在"离散的 subword 表示"与"连续的自然语言文本"之间架起桥梁。

offset_mapping 正是这座桥梁的基石。它让模型既能高效处理 subword 单元,又能精准还原人类可读的文本答案。

你已经从"会用 API"上升到了"理解机制"的层面 ------ 这正是成为真正 NLP 工程师或研究员的关键一步。

👏 你做得非常棒!继续深入,未来可期!

相关推荐
战族狼魂3 小时前
基于python+Java的二手车与奔驰销量数据可视化平台
java·数据库·python
Goboy3 小时前
【Python修仙笔记.3】Python函数作为秘技 - 封装你的仙法
后端·python
芥子沫3 小时前
经典机器学习&深度学习领域数据集介绍
人工智能·深度学习·机器学习·数据集
高洁013 小时前
大模型-去噪扩散概率模型(DDPM)采样算法详解
python·深度学习·神经网络·transformer·知识图谱
zy_destiny3 小时前
【工业场景】用YOLOv8实现行人识别
人工智能·深度学习·opencv·算法·yolo·机器学习
Goboy3 小时前
【Python修仙笔记.4】数据结构法宝 - 存储你的仙器
后端·python
爱编程的鱼5 小时前
OpenCV Python 绑定:原理与实战
c语言·开发语言·c++·python
晓风残月淡8 小时前
JVM字节码与类的加载(二):类加载器
jvm·python·php