BertForTokenClassification类
BertForTokenclassification类是Hugging Face transformers库中专门为基于BERT的序列标注任务(如命名实体识别NER、词性标注POS)设计的模型类。它在BERT的基础上添加了一个线性分类层,用于对每个token进行分类。
1、特点
任务类型:专为Token-level分类设计,即对输入序列中的每一个token预测一个标签。典型应用有命名实体识别(NER)、词性标注(POS)、语义角色标注(SRL)
2、模型架构
BERT Base Model (bert-base-uncased等)
↓
[CLS] Token 1 Token 2 ... Token N [SEP] (输出隐藏状态)
↓
Dropout Layer (可选)
↓
Linear Classifier (hidden_size → num_labels)
↓
Softmax (输出每个 token 的标签概率)
3、关键组件
BERT编辑器:提取上下文相关的token表示(支持所有BERT变体)
分类头:将每个token的隐藏状态映射到标签空间(hidden_size→num_labels)
CRF层(可选):可通过扩展添加条件随机场层,提升标签间依赖建模(需自定义实现)
4、使用方法
(1)加载预训练模型
python
import torch
from transformers import BertForTokenClassification, BertTokenizerFast
model = BertForTokenClassification.from_pretrained(
'chinese-bert-wwm',
num_labels=10, # 标签数量
id2label={0: 'O', 1: 'B-质量差', 2: 'I-质量差', ......} # 标签映射
)
tokenizer = BertTokenizerFast.from_pretrained('chinese-bert-wwm')
(2)数据预处理
python
text = '容易碎裂。质量太差,不值这个价。'
input = tokenizer(
text,
return_tensor='pt',
trucation=True,
padding=True,
return_offsets_mapping=True
)
# 假设0=O,1=B-质量差,2=I-质量差,3=B-易碎裂,4=I-易碎裂
labels = [3, 4, 4, 4, 4, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0]
inputs["labels"] = torch.tensor([labels])
(3)模型推理
python
outputs = model(**inputs)
logits = outputs.logits # 形状:(batch_size, seq_len, num_labels)
# 获取预测标签
predictions = torch.argmax(logits, dim=-1)[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# 打印结果
for token, pred in zip(tokens, predictions):
print(f"{token:15}→{model.config.id2label.get(pred, 'UNK')}")
输出示例
python
[CLS] →O
容 →B-易碎裂
易 →I-易碎裂
碎 →I-易碎裂
裂 →I-易碎裂
。 →O
质 →B-质量差
量 →I-质量差
太 →I-质量差
差 →I-质量差
, →O
不 →O
值 →O
这 →O
个 →O
价 →O
。 →O
[SEP] →O