基于 BERT 模型实现命名实体识别(NER)任务

基于BERT的中文命名实体识别(NER)

📋 项目概述

本项目实现了一个基于BERT(Bidirectional Encoder Representations from Transformers)的中文命名实体识别系统。该系统能够从中文文本中自动识别并提取四类命名实体:

  • PERSON(人名)
  • LOCATION(地点)
  • TIME(时间)
  • ORGANIZATION(组织机构)

🎯 任务介绍

什么是命名实体识别(NER)?

命名实体识别(Named Entity Recognition, NER)是自然语言处理中的一项基础任务,属于序列标注问题。其目标是从非结构化的文本中识别出具有特定意义的实体,并将它们分类到预定义的类别中。

示例:

复制代码
输入:我在北京工作,明天要去见张三。
输出:
  - LOCATION: ["北京"]
  - TIME: ["明天"]
  - PERSON: ["张三"]

任务背景

在信息爆炸的时代,从海量文本中快速准确地提取关键信息变得越来越重要。命名实体识别技术广泛应用于:

  • 信息抽取:从新闻、文档中提取关键信息
  • 知识图谱构建:自动构建实体关系网络
  • 搜索引擎优化:提升搜索结果的准确性
  • 智能问答系统:理解用户问题中的关键实体
  • 文本分析:舆情分析、内容推荐等

传统的中文NER任务面临以下挑战:

  1. 中文分词歧义:中文没有明显的词边界
  2. 实体边界识别:需要准确判断实体的起始和结束位置
  3. 上下文理解:同一词汇在不同语境下可能属于不同实体类型
  4. 标注体系:需要选择合适的标注方案(如BIO、BIOES等)

💡 解决方案

本项目采用BERT + 序列标注的架构来解决中文NER任务,具体方案如下:

技术架构

复制代码
输入文本 → BERT编码器 → 线性分类层 → CRF层(可选)→ 标签序列

核心思路

  1. BERT作为特征提取器

    • 利用BERT预训练模型强大的上下文理解能力
    • BERT能够捕捉词汇的双向上下文信息,解决歧义问题
    • 使用中文BERT(bert-base-chinese)适配中文任务
  2. 序列标注方法

    • 采用BIO标注体系(Begin-Inside-Outside)
    • 将NER任务转化为对每个字符/词进行分类的问题
    • 每个位置预测一个标签(B-实体类型、I-实体类型、O)
  3. 可选CRF层

    • 通过CRF(条件随机场)层约束标签之间的转移
    • 确保标签序列的合理性(如I不能出现在B之前)
    • 提升模型的整体性能

CRF(条件随机场)详解

什么是CRF?

CRF(Conditional Random Field,条件随机场) 是一种用于序列标注的概率图模型。它能够建模序列中相邻标签之间的依赖关系,通过考虑标签转移概率来生成更合理的标签序列。

CRF的定义

CRF是一种无向图模型,用于对序列数据进行标注。在NER任务中,CRF层位于BERT编码器和最终标签预测之间,它:

  1. 学习标签转移概率:学习从一个标签转移到另一个标签的概率

    • 例如:从B-PERSON转移到I-PERSON的概率应该很高
    • O转移到I-PERSON的概率应该很低(因为I不能独立出现)
  2. 全局最优解码:使用Viterbi算法找到全局最优的标签序列,而不是简单地逐位置选择最大概率的标签

CRF的作用
  1. 约束标签序列的合法性

    • 防止出现不合法的标签组合(如I-PERSON出现在B-PERSON之前)
    • 确保BIO标注体系的基本规则得到遵守
    • 例如:O → I-PERSON这种转移是不允许的
  2. 考虑上下文依赖

    • 不仅考虑当前位置的特征,还考虑相邻位置的标签
    • 例如:如果前一个位置是B-LOCATION,当前位置更可能是I-LOCATION而不是O
  3. 提升模型性能

    • 通过全局优化,减少局部最优导致的错误
    • 在序列标注任务中,CRF通常能提升1-3%的F1分数
CRF在NER中的工作原理

训练阶段

  • CRF层学习一个转移矩阵 (Transition Matrix),大小为[class_num, class_num]
  • 矩阵中的每个元素T[i][j]表示从标签i转移到标签j的得分
  • 训练时,CRF损失函数会惩罚不合法的标签转移,奖励合法的转移

预测阶段

  • 使用Viterbi算法进行解码
  • Viterbi算法是一种动态规划算法,能够找到全局最优的标签序列
  • 相比简单的argmax方法,Viterbi考虑了整个序列的标签转移约束
CRF vs 不使用CRF的对比
对比维度 不使用CRF 使用CRF
预测方式 逐位置独立预测(argmax) 全局最优解码(Viterbi)
标签约束 无约束,可能出现非法序列 有约束,确保序列合法性
性能 基线性能 通常提升1-3%
计算复杂度 低(O(n)) 较高(O(n×m²),n为序列长度,m为标签数)
适用场景 简单任务,标签独立性较强 序列标注任务,标签依赖性强
本项目中的CRF实现

在本项目中,CRF层是可选的(通过config["use_crf"]控制):

  • 使用CRF时

    • 训练:计算CRF损失(负对数似然)
    • 预测:使用crf_layer.decode()进行Viterbi解码
  • 不使用CRF时

    • 训练:使用交叉熵损失
    • 预测:使用torch.argmax()逐位置选择最大概率标签

代码示例 (来自model.py):

python 复制代码
if self.use_crf:
    mask = target.gt(-1)  # 创建mask,忽略padding位置
    return - self.crf_layer(predict, target, mask, reduction="mean")  # CRF损失
else:
    return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))  # 交叉熵损失
何时使用CRF?

建议使用CRF的情况

  • 标签之间有明显的依赖关系(如BIO体系)
  • 需要确保标签序列的合法性
  • 追求更高的模型性能
  • 计算资源充足

可以不使用CRF的情况

  • 标签独立性较强
  • 需要快速推理速度
  • 计算资源有限
  • 模型已经表现很好,CRF提升不明显

标注体系(BIO)

本项目使用BIO标注体系,共9个标签:

BIO(Begin-Inside-Outside)标注用于序列标注任务,定义如下:

  • B-XXX(Begin) :表示某类实体的起始位置

  • I-XXX(Inside) :表示实体内部的后续位置

  • O(Outside):表示不属于任何实体的普通词/字

标注形式示例(句子"我在北京工作"):

  • 句子:我 在 北 京 工 作

  • 标签:O O B-LOCATION I-LOCATION O O

说明 :实体必须以 B- 开头,后续连续的实体位置用 I- 标注,不属于实体的部分使用 O

标签 含义 示例
B-PERSON 人名开始 "张"(在"张三"中)
I-PERSON 人名内部 "三"(在"张三"中)
B-LOCATION 地点开始 "北"(在"北京"中)
I-LOCATION 地点内部 "京"(在"北京"中)
B-TIME 时间开始 "明"(在"明天"中)
I-TIME 时间内部 "天"(在"明天"中)
B-ORGANIZATION 组织开始 "清"(在"清华大学"中)
I-ORGANIZATION 组织内部 "华"、"大"、"学"(在"清华大学"中)
O 非实体 其他所有字符

📁 项目结构

复制代码
week9-序列标注/
├── loader.py          # 数据加载模块
├── model.py           # BERT模型定义
├── main.py            # 训练主程序
├── evaluate.py        # 模型评估模块
├── config.py          # 配置文件
├── ner_data/          # 数据目录
│   ├── train          # 训练集
│   ├── dev            # 验证集
│   ├── test           # 测试集
│   └── schema.json    # 标签映射文件
└── README.md          # 项目文档

🔧 代码详解

1. loader.py - 数据加载模块

功能:负责从磁盘读取NER数据,并将其转换为模型可用的张量格式。

核心类:DataGenerator

主要方法:
  1. load() - 数据加载主函数

    • 输入格式 :CoNLL格式的文本文件

      复制代码
      我 O
      爱 O
      北 B-LOCATION
      京 I-LOCATION
      
      张 B-PERSON
      三 I-PERSON
    • 处理流程

      • 按空行(\n\n)切分句子
      • 逐行解析:token<空格>label
      • 将标签字符串映射为数字ID(通过schema.json
      • 将字符序列编码为BERT词表ID
      • 对序列进行padding/truncate到固定长度(max_length
    • 输出 :每条样本包含:

      • input_ids: [max_length] 的token ID序列
      • labels: [max_length] 的标签ID序列(padding用-1)
  2. encode_sentence() - 字符编码

    • 将字符列表转换为BERT词表ID
    • 查不到的词用[UNK]替代
    • 自动padding到max_length
  3. padding() - 序列对齐

    • 截断超长序列
    • 补齐短序列
    • 确保batch内序列长度一致

关键设计

  • 使用-1作为labels的padding值,在计算loss时会被忽略(ignore_index=-1
  • 保存原始句子字符串(self.sentences),用于评估时对齐预测结果
完整代码:
python 复制代码
# -*- coding: utf-8 -*-

import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from transformers import BertTokenizer
"""
数据加载
"""


class DataGenerator:
    def __init__(self, data_path, config):
        self.config = config
        self.path = data_path
        self.tokenizer = load_vocab(config["bert_path"])
        self.sentences = []
        self.schema = self.load_schema(config["schema_path"])
        self.load()

    def load(self):
        """
        从磁盘读取序列标注数据,并转换为可训练的张量样本。

        期望的数据文件格式(类似 CoNLL / NER 常用格式):
        - 一行一个 token(通常是"字")及其标签:  token<空格>label
          例如:
            我 O
            爱 O
            北 B-LOCATION
            京 I-LOCATION
        - 句子与句子之间用"空行"分隔(即连续两个换行符 \n\n)

        最终每条样本的数据形态:
        - input_ids: LongTensor,[max_length],token id 序列(截断/补齐到固定长度)
        - labels:    LongTensor,[max_length],标签 id 序列(截断/补齐到固定长度;padding 用 -1)
        """
        self.data = []
        with open(self.path, encoding="utf8") as f:
            # 1) 先按空行切分出"句子块"。每个 segment 对应一句(或一段)
            #
            # 例:假设你的原始文件内容长这样(两句话,中间空一行):
            #   我 O
            #   爱 O
            #   北 B-LOCATION
            #   京 I-LOCATION
            #
            #   张 B-PERSON
            #   三 I-PERSON
            #   来 O
            #   了 O
            #
            # 那么 f.read().split("\n\n") 得到:
            #   segments = [
            #     "我 O\n爱 O\n北 B-LOCATION\n京 I-LOCATION",
            #     "张 B-PERSON\n三 I-PERSON\n来 O\n了 O"
            #   ]
            segments = f.read().split("\n\n")
            for segment in segments:
                # sentenece: 当前句子的 token 列表(这里按"字"处理)
                sentenece = []
                # labels: 与 token 对齐的标签 id 列表(长度应与 sentenece 一致)
                labels = []
                for line in segment.split("\n"):
                    if line.strip() == "":
                        continue
                    # 每行必须恰好两列:token 和 label
                    char, label = line.split()
                    sentenece.append(char)
                    # 将字符串标签(如 "B-PERSON")映射为数字 id
                    labels.append(self.schema[label])

                # sentence: 仅用于评估/对齐的原始句子字符串(例如 "我爱北京")
                #
                # 例:对第一句话:
                #   sentenece = ["我","爱","北","京"]
                #   labels(未padding前) = [8, 8, 0, 4]   # 假设 schema: O->8, B-LOCATION->0, I-LOCATION->4
                #   sentence = "我爱北京"
                sentence = "".join(sentenece)
                self.sentences.append(sentence)

                # 2) token -> vocab id,并 padding 到 max_length
                #
                # 例:假设 max_length=6,且 tokenizer.vocab 映射示意:
                #   "我"->1001, "爱"->1002, "北"->1003, "京"->1004, "[PAD]"->0
                # 则 encode_sentence(sentenece) 可能得到(实际 id 以你的 bert vocab 为准):
                #   input_ids(已padding后) = [1001, 1002, 1003, 1004, 0, 0]   # 长度固定为 6
                input_ids = self.encode_sentence(sentenece)
                # 3) labels 同样 padding 到 max_length;pad 值用 -1(通常用于 loss ignore)
                #
                # 例:labels padding 后:
                #   labels(已padding后) = [8, 8, 0, 4, -1, -1]   # 长度固定为 max_length
                labels = self.padding(labels, -1)

                # 4) 存成 torch 张量,DataLoader 取出后 batch 维度会变成 [batch_size, max_length]
                self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
        return

    def encode_sentence(self, text, padding=True):
        """
        将 token 列表编码为词表 id 列表。

        - text: List[str],例如 ["我","爱","北","京"]
        - 返回: List[int],例如 [101, 4263, ...](具体 id 取决于 tokenizer.vocab)

        注意:这里是"逐 token 查 vocab",不是 BERT 的 WordPiece 自动分词流程。
        """
        # 查不到的 token 使用 [UNK]
        input_ids = [self.tokenizer.vocab.get(char, self.tokenizer.vocab["[UNK]"]) for char in text]
        if padding:
            input_ids = self.padding(input_ids, self.tokenizer.vocab["[PAD]"])
        return input_ids
    

    #补齐或截断输入的序列,使其可以在一个batch内运算
    def padding(self, input_id, pad_token=0):
        """
        将序列截断或补齐到固定长度 config["max_length"]。

        - input_id: List[int],长度不定
        - pad_token: 补齐用的值
          - 对 input_ids 通常用 [PAD] 的 id
          - 对 labels 这里用 -1(常用于 loss 的 ignore_index)
        """
        input_id = input_id[:self.config["max_length"]]
        input_id += [pad_token] * (self.config["max_length"] - len(input_id))
        return input_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def load_schema(self, path):
        with open(path, encoding="utf8") as f:
            return json.load(f)

def load_vocab(vocab_path):
    return BertTokenizer.from_pretrained(vocab_path)


#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
    dg = DataGenerator(data_path, config)
    dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
    return dl



if __name__ == "__main__":
    from config import Config
    dg = DataGenerator("ner_data/train", Config)
    dl = DataLoader(dg, batch_size=32)  
    for x,y in dl:
        print(x.shape, y.shape)
        print(x[1], y[1])
        input()

2. model.py - 模型定义

功能:定义基于BERT的NER模型架构。

核心类:TorchModel

模型结构:
python 复制代码
输入: [batch_size, max_length]  # token IDs
  ↓
BERT编码器: BertModel.from_pretrained()
  ↓
输出: [batch_size, max_length, hidden_size]  # BERT隐藏层表示
  ↓
线性分类层: nn.Linear(hidden_size, class_num)
  ↓
输出: [batch_size, max_length, class_num]  # 每个位置的标签概率分布
  ↓
CRF层(可选): CRF层约束标签转移
  ↓
最终标签序列
关键组件:
  1. BERT编码器self.bert

    • 使用预训练的中文BERT模型(bert-base-chinese
    • 输出每个位置的上下文表示向量(768维)
  2. 分类层self.classify

    • 全连接层:hidden_size → class_num(9类)
    • 将BERT输出映射到标签空间
  3. CRF层self.crf_layer,可选)

    • 使用torchcrf库实现
    • 学习标签之间的转移概率
    • 确保标签序列的合法性
  4. 损失函数self.loss

    • CrossEntropyLoss(ignore_index=-1)
    • 忽略padding位置的损失
forward()方法:
  • 训练模式target is not None):

    • 如果使用CRF:计算CRF损失(考虑标签转移)
    • 否则:计算交叉熵损失
  • 预测模式target is None):

    • 如果使用CRF:使用Viterbi算法解码最优标签序列
    • 否则:返回每个位置的标签概率分布

优化器选择

  • choose_optimizer()函数支持Adam和SGD两种优化器
完整代码:
python 复制代码
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
from transformers import BertModel
"""
建立网络模型结构
"""

class TorchModel(nn.Module):
    def __init__(self, config):
        super(TorchModel, self).__init__()
        max_length = config["max_length"]
        class_num = config["class_num"]
        # self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        # self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
        self.bert = BertModel.from_pretrained(config["bert_path"], return_dict=False)
        self.classify = nn.Linear(self.bert.config.hidden_size, class_num)
        self.crf_layer = CRF(class_num, batch_first=True)
        self.use_crf = config["use_crf"]
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵损失

    #当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, target=None):
        # x = self.embedding(x)  #input shape:(batch_size, sen_len)
        # x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)

        x, _ = self.bert(x)
        predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)

        if target is not None:
            if self.use_crf:
                mask = target.gt(-1)
                return - self.crf_layer(predict, target, mask, reduction="mean")
            else:
                #(number, class_num), (number)
                return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
        else:
            if self.use_crf:
                return self.crf_layer.decode(predict)
            else:
                return predict


def choose_optimizer(config, model):
    optimizer = config["optimizer"]
    learning_rate = config["learning_rate"]
    if optimizer == "adam":
        return Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "sgd":
        return SGD(model.parameters(), lr=learning_rate)


if __name__ == "__main__":
    from config import Config
    model = TorchModel(Config)

3. main.py - 训练主程序

功能:模型训练的完整流程控制。

训练流程:
  1. 初始化阶段

    • 创建模型保存目录
    • 加载训练数据(load_data()
    • 初始化BERT模型(TorchModel
    • 检测GPU可用性,自动迁移到GPU
    • 选择优化器(Adam/SGD)
    • 初始化评估器(Evaluator
  2. 训练循环(每个epoch):

    复制代码
    for epoch in range(config["epoch"]):
        model.train()  # 设置为训练模式
        for batch_data in train_data:
            optimizer.zero_grad()  # 清零梯度
            input_id, labels = batch_data
            loss = model(input_id, labels)  # 前向传播,计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
        evaluator.eval(epoch)  # 在验证集上评估
  3. 关键特性

    • 自动GPU检测和迁移
    • 每个epoch结束后在验证集上评估
    • 记录训练loss,便于监控训练过程
    • 支持模型保存(代码中已注释,可按需启用)
完整代码:
python 复制代码
# -*- coding: utf-8 -*-

import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

"""
模型训练主程序
"""

def main(config):
    #创建保存模型的目录
    if not os.path.isdir(config["model_path"]):
        os.mkdir(config["model_path"])
    #加载训练数据
    train_data = load_data(config["train_data_path"], config)
    #加载模型
    model = TorchModel(config)
    # 标识是否使用gpu
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用,迁移模型至gpu")
        model = model.cuda()
    #加载优化器
    optimizer = choose_optimizer(config, model)
    #加载效果测试类
    evaluator = Evaluator(config, model, logger)
    #训练
    for epoch in range(config["epoch"]):
        epoch += 1
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []
        for index, batch_data in enumerate(train_data):
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            loss = model(input_id, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)
        logger.info("epoch average loss: %f" % np.mean(train_loss))
        evaluator.eval(epoch)
    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
    # torch.save(model.state_dict(), model_path)
    return model, train_data

if __name__ == "__main__":
    model, train_data = main(Config)

4. evaluate.py - 模型评估模块

功能:评估模型在验证集上的表现,计算准确率、召回率和F1分数。

核心类:Evaluator

主要方法:
  1. eval() - 评估主函数

    • 遍历验证集的每个batch
    • 使用模型进行预测(model.eval() + torch.no_grad()
    • 调用write_stats()统计预测结果
    • 调用show_stats()显示评估指标
  2. write_stats() - 统计预测结果

    • 将标签序列解码为实体字典(decode()
    • 统计每个实体类型的:
      • 正确识别数:预测实体中与真实实体匹配的数量
      • 样本实体数:真实标签中的实体总数
      • 识别出实体数:模型预测出的实体总数
  3. decode() - 标签序列解码

    • 输入:原始句子 + 标签ID序列
    • 处理
      • 将标签ID序列转换为字符串(如[8,8,0,4]"8844"
      • 使用正则表达式匹配BIO模式:
        • "(04+)" → LOCATION(0=B-LOCATION, 4=I-LOCATION)
        • "(15+)" → ORGANIZATION
        • "(26+)" → PERSON
        • "(37+)" → TIME
    • 输出 :实体字典,如{"LOCATION": ["北京"], "PERSON": ["张三"], ...}
  4. show_stats() - 计算并显示评估指标

    • Precision(准确率) = 正确识别数 / 识别出实体数
    • Recall(召回率) = 正确识别数 / 样本实体数
    • F1分数 = 2 × (Precision × Recall) / (Precision + Recall)
    • Macro-F1:对所有类别F1分数求平均
    • Micro-F1:将所有类别统计合并后计算F1

评估指标说明

  • 准确率:模型预测的实体中有多少是正确的
  • 召回率:真实实体中有多少被模型识别出来
  • F1分数:准确率和召回率的调和平均,综合评估指标
完整代码:
python 复制代码
# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data

"""
模型效果测试
"""

class Evaluator:
    """
    模型评估器类,用于评估命名实体识别模型的效果
    计算准确率、召回率和F1分数
    """
    def __init__(self, config, model, logger):
        """
        初始化评估器
        :param config: 配置字典,包含模型和数据相关配置
        :param model: 训练好的模型
        :param logger: 日志记录器
        """
        self.config = config
        self.model = model
        self.logger = logger
        # 加载验证数据,shuffle=False确保数据顺序固定,便于索引对应
        self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)


    def eval(self, epoch):
        """
        评估模型在验证集上的表现
        :param epoch: 当前训练轮次
        """
        self.logger.info("开始测试第%d轮模型效果:" % epoch)
        # 初始化统计字典,用于记录各类实体的识别情况
        self.stats_dict = {"LOCATION": defaultdict(int),
                           "TIME": defaultdict(int),
                           "PERSON": defaultdict(int),
                           "ORGANIZATION": defaultdict(int)}
        # 将模型设置为评估模式(关闭dropout等)
        self.model.eval()
        # 遍历验证数据的每个批次
        for index, batch_data in enumerate(self.valid_data):
            # 获取当前批次对应的原始句子
            # 注意:这里假设数据未打乱(shuffle=False),索引才能正确对应
            sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
            # 如果GPU可用,将数据转移到GPU
            if torch.cuda.is_available():
                batch_data = [d.cuda() for d in batch_data]
            # 解包批次数据:input_id是输入序列,labels是真实标签
            # 注意:输入变化时这里需要修改,比如多输入、多输出的情况
            input_id, labels = batch_data
            # 禁用梯度计算,节省内存和加速推理
            with torch.no_grad():
                # 不输入labels,使用模型当前参数进行预测
                pred_results = self.model(input_id)
            # 统计当前批次的预测结果
            self.write_stats(labels, pred_results, sentences)
        # 显示最终的评估统计结果
        self.show_stats()
        return

    def write_stats(self, labels, pred_results, sentences):
        """
        统计预测结果,计算各类实体的识别情况
        :param labels: 真实标签(ground truth)
        :param pred_results: 模型预测结果
        :param sentences: 原始句子列表
        """
        # 确保三个列表长度一致
        assert len(labels) == len(pred_results) == len(sentences)
        # 如果未使用CRF层,需要从概率分布中取最大值的索引作为预测标签
        if not self.config["use_crf"]:
            pred_results = torch.argmax(pred_results, dim=-1)
        # 逐句处理:真实标签、预测标签、原始句子
        for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
            # 如果未使用CRF,将tensor转换为列表
            if not self.config["use_crf"]:
                pred_label = pred_label.cpu().detach().tolist()
            # 将真实标签tensor转换为列表
            true_label = true_label.cpu().detach().tolist()
            # 将标签序列解码为实体字典(真实实体和预测实体)
            true_entities = self.decode(sentence, true_label)
            pred_entities = self.decode(sentence, pred_label)
            # 统计各类实体的识别情况
            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
                # 统计正确识别的实体数(预测实体中与真实实体匹配的数量)
                self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
                # 统计样本中的真实实体总数
                self.stats_dict[key]["样本实体数"] += len(true_entities[key])
                # 统计模型识别出的实体总数
                self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
        return

    def show_stats(self):
        """
        显示并计算最终的评估指标:准确率、召回率、F1分数
        包括Macro-F1(宏平均)和Micro-F1(微平均)
        """
        F1_scores = []
        # 计算每个实体类别的评估指标
        for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
            # 准确率(Precision)= 识别出的正确实体数 / 识别出的实体数
            # 加1e-5防止除零错误
            precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
            # 召回率(Recall)= 识别出的正确实体数 / 样本的实体数
            recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
            # F1分数 = 2 * (准确率 * 召回率) / (准确率 + 召回率)
            F1 = (2 * precision * recall) / (precision + recall + 1e-5)
            F1_scores.append(F1)
            # 输出每个类别的评估指标
            self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
        # Macro-F1:对所有类别的F1分数求平均(宏平均)
        self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
        # 计算Micro-F1:将所有类别的统计合并后计算
        # 所有类别正确识别的实体总数
        correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        # 所有类别识别出的实体总数
        total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        # 所有类别样本中的真实实体总数
        true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        # Micro准确率:所有类别合并后的准确率
        micro_precision = correct_pred / (total_pred + 1e-5)
        # Micro召回率:所有类别合并后的召回率
        micro_recall = correct_pred / (true_enti + 1e-5)
        # Micro-F1:所有类别合并后的F1分数(微平均)
        micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
        self.logger.info("Micro-F1 %f" % micro_f1)
        self.logger.info("--------------------")
        return

    '''
    标签到数字的映射关系(BIO标注体系):
    {
      "B-LOCATION": 0,      # 地点实体的开始
      "B-ORGANIZATION": 1,  # 组织实体的开始
      "B-PERSON": 2,        # 人名实体的开始
      "B-TIME": 3,          # 时间实体的开始
      "I-LOCATION": 4,      # 地点实体的内部
      "I-ORGANIZATION": 5,  # 组织实体的内部
      "I-PERSON": 6,        # 人名实体的内部
      "I-TIME": 7,          # 时间实体的内部
      "O": 8                # 非实体(Outside)
    }
    '''
    def decode(self, sentence, labels):
        """
        将标签序列解码为实体字典
        使用正则表达式匹配BIO标签模式,提取实体文本
        
        :param sentence: 原始句子字符串,如 "我在北京工作"
        :param labels: 标签列表,如 [8, 8, 0, 4, 8, 8](对应"我"、"在"、"北"、"京"、"工"、"作")
        :return: 字典,包含各类实体列表,如 {"LOCATION": ["北京"], "PERSON": [], ...}
        """
        # 将标签列表转换为字符串,只取与句子长度相同的部分
        # 例如:[8, 8, 0, 4, 8, 8] -> "884488"
        labels = "".join([str(x) for x in labels[:len(sentence)]])
        # 初始化结果字典,每个实体类型对应一个列表
        results = defaultdict(list)
        
        # 提取LOCATION(地点)实体
        # 正则表达式 "(04+)" 匹配:0(B-LOCATION)后跟一个或多个4(I-LOCATION)
        # 例如:"044" 或 "0444" 表示一个地点实体
        for location in re.finditer("(04+)", labels):
            s, e = location.span()  # 获取匹配的起始位置s和结束位置e
            results["LOCATION"].append(sentence[s:e])  # 提取对应位置的文本作为地点实体
        
        # 提取ORGANIZATION(组织)实体
        # 正则表达式 "(15+)" 匹配:1(B-ORGANIZATION)后跟一个或多个5(I-ORGANIZATION)
        # 例如:"155" 或 "1555" 表示一个组织实体
        for location in re.finditer("(15+)", labels):
            s, e = location.span()  # 获取匹配的起始位置s和结束位置e
            results["ORGANIZATION"].append(sentence[s:e])  # 提取对应位置的文本作为组织实体
        
        # 提取PERSON(人名)实体
        # 正则表达式 "(26+)" 匹配:2(B-PERSON)后跟一个或多个6(I-PERSON)
        # 例如:"266" 或 "2666" 表示一个人名实体
        for location in re.finditer("(26+)", labels):
            s, e = location.span()  # 获取匹配的起始位置s和结束位置e
            results["PERSON"].append(sentence[s:e])  # 提取对应位置的文本作为人名实体
        
        # 提取TIME(时间)实体
        # 正则表达式 "(37+)" 匹配:3(B-TIME)后跟一个或多个7(I-TIME)
        # 例如:"377" 或 "3777" 表示一个时间实体
        for location in re.finditer("(37+)", labels):
            s, e = location.span()  # 获取匹配的起始位置s和结束位置e
            results["TIME"].append(sentence[s:e])  # 提取对应位置的文本作为时间实体
        
        return results

5. config.py - 配置文件

功能:集中管理所有超参数和路径配置。

主要配置项:
配置项 说明 默认值
model_path 模型保存路径 "model_output"
schema_path 标签映射文件路径 "ner_data/schema.json"
train_data_path 训练集路径 "ner_data/train"
valid_data_path 验证集路径 "ner_data/test"
bert_path BERT模型路径 r"F:\pretrain_models\bert-base-chinese"
max_length 序列最大长度 100
batch_size 批次大小 16
epoch 训练轮数 20
optimizer 优化器类型 "adam"
learning_rate 学习率 1e-4
use_crf 是否使用CRF层 False
class_num 标签类别数 9

使用方式

python 复制代码
from config import Config
model = TorchModel(Config)
完整代码:
python 复制代码
# -*- coding: utf-8 -*-

"""
配置参数信息
"""

Config = {
    "model_path": "model_output",
    "schema_path": "ner_data/schema.json",
    "train_data_path": "ner_data/train",
    "valid_data_path": "ner_data/test",
    "vocab_path":"chars.txt",
    "max_length": 100,
    "hidden_size": 256,
    "num_layers": 2,
    "epoch": 20,
    "batch_size": 16,
    "optimizer": "adam",
    "learning_rate": 1e-4,
    "use_crf": False,
    "class_num": 9,
    "bert_path": r"F:\pretrain_models\bert-base-chinese"
}

🚀 如何使用

环境要求

bash 复制代码
torch>=1.8.0
transformers>=4.0.0
torchcrf>=1.1.0
numpy

数据准备

  1. 数据格式 :CoNLL格式,每行token<空格>label,句子间空行分隔
  2. 标签体系 :BIO标注,标签定义在schema.json
  3. 数据目录 :将训练/验证/测试数据放在ner_data/目录下

训练模型

bash 复制代码
python main.py

评估模型

模型训练过程中会自动在验证集上评估,每个epoch结束后输出:

  • 每个实体类型的Precision、Recall、F1
  • Macro-F1和Micro-F1

📈 评估指标详解

NER任务中常用的评估指标及其计算公式如下:

符号说明

在介绍具体指标之前,先统一说明公式中使用的符号含义:

符号 含义 说明
TPTPTP (True Positive) 正确识别的实体数 模型预测的实体中,与真实实体匹配的数量
FPFPFP (False Positive) 错误识别的实体数 模型预测出但实际不存在的实体(误报)
FNFNFN (False Negative) 漏识别的实体数 真实存在但模型未识别出的实体(漏报)
PrecisionPrecisionPrecision 准确率/精确率 模型预测的实体中,有多少是正确的
RecallRecallRecall 召回率 真实实体中,有多少被模型识别出来
F1F1F1 F1分数 准确率和召回率的调和平均数
TPPERSONTP_{PERSON}TPPERSON PERSON类别的TP PERSON类别中正确识别的实体数
TPLOCATIONTP_{LOCATION}TPLOCATION LOCATION类别的TP LOCATION类别中正确识别的实体数
TPTIMETP_{TIME}TPTIME TIME类别的TP TIME类别中正确识别的实体数
TPORGANIZATIONTP_{ORGANIZATION}TPORGANIZATION ORGANIZATION类别的TP ORGANIZATION类别中正确识别的实体数
FPPERSONFP_{PERSON}FPPERSON PERSON类别的FP PERSON类别中错误识别的实体数
FPLOCATIONFP_{LOCATION}FPLOCATION LOCATION类别的FP LOCATION类别中错误识别的实体数
FPTIMEFP_{TIME}FPTIME TIME类别的FP TIME类别中错误识别的实体数
FPORGANIZATIONFP_{ORGANIZATION}FPORGANIZATION ORGANIZATION类别的FP ORGANIZATION类别中错误识别的实体数
FNPERSONFN_{PERSON}FNPERSON PERSON类别的FN PERSON类别中漏识别的实体数
FNLOCATIONFN_{LOCATION}FNLOCATION LOCATION类别的FN LOCATION类别中漏识别的实体数
FNTIMEFN_{TIME}FNTIME TIME类别的FN TIME类别中漏识别的实体数
FNORGANIZATIONFN_{ORGANIZATION}FNORGANIZATION ORGANIZATION类别的FN ORGANIZATION类别中漏识别的实体数
F1PERSONF1_{PERSON}F1PERSON PERSON类别的F1分数 PERSON类别的F1值
F1LOCATIONF1_{LOCATION}F1LOCATION LOCATION类别的F1分数 LOCATION类别的F1值
F1TIMEF1_{TIME}F1TIME TIME类别的F1分数 TIME类别的F1值
F1ORGANIZATIONF1_{ORGANIZATION}F1ORGANIZATION ORGANIZATION类别的F1分数 ORGANIZATION类别的F1值
TPtotalTP_{total}TPtotal 所有类别的TP总和 TPPERSON+TPLOCATION+TPTIME+TPORGANIZATIONTP_{PERSON} + TP_{LOCATION} + TP_{TIME} + TP_{ORGANIZATION}TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION
FPtotalFP_{total}FPtotal 所有类别的FP总和 FPPERSON+FPLOCATION+FPTIME+FPORGANIZATIONFP_{PERSON} + FP_{LOCATION} + FP_{TIME} + FP_{ORGANIZATION}FPPERSON+FPLOCATION+FPTIME+FPORGANIZATION
FNtotalFN_{total}FNtotal 所有类别的FN总和 FNPERSON+FNLOCATION+FNTIME+FNORGANIZATIONFN_{PERSON} + FN_{LOCATION} + FN_{TIME} + FN_{ORGANIZATION}FNPERSON+FNLOCATION+FNTIME+FNORGANIZATION

1. 基础指标

Precision(准确率/精确率)

定义:模型预测的实体中,有多少是正确的。

计算公式

Precision=TPTP+FP=正确识别的实体数模型识别出的实体总数Precision = \frac{TP}{TP + FP} = \frac{\text{正确识别的实体数}}{\text{模型识别出的实体总数}}Precision=TP+FPTP=模型识别出的实体总数正确识别的实体数

其中符号含义见上方"符号说明"表格。

示例

  • 真实实体:["北京", "张三"]
  • 预测实体:["北京", "上海", "张三"]
  • Precision = 2 / 3 = 0.667(预测了3个,其中2个正确)

Recall(召回率)

定义:真实实体中,有多少被模型识别出来。

计算公式

Recall=TPTP+FN=正确识别的实体数样本中的真实实体总数Recall = \frac{TP}{TP + FN} = \frac{\text{正确识别的实体数}}{\text{样本中的真实实体总数}}Recall=TP+FNTP=样本中的真实实体总数正确识别的实体数

其中符号含义见上方"符号说明"表格。

示例

  • 真实实体:["北京", "张三", "李四"]
  • 预测实体:["北京", "张三"]
  • Recall = 2 / 3 = 0.667(真实有3个,只识别出2个)

F1分数(F1-Score)

定义:准确率和召回率的调和平均数,综合评估模型性能。

计算公式

F1=2×Precision×RecallPrecision+Recall=2×TP2×TP+FP+FNF1 = \frac{2 \times Precision \times Recall}{Precision + Recall} = \frac{2 \times TP}{2 \times TP + FP + FN}F1=Precision+Recall2×Precision×Recall=2×TP+FP+FN2×TP

特点

  • F1分数同时考虑了准确率和召回率
  • 取值范围:[0, 1],值越大越好
  • 当Precision和Recall都高时,F1分数才会高

示例

  • Precision = 0.8, Recall = 0.6
  • F1=2×0.8×0.60.8+0.6=0.686F1 = \frac{2 \times 0.8 \times 0.6}{0.8 + 0.6} = 0.686F1=0.8+0.62×0.8×0.6=0.686

2. 宏平均与微平均

Macro-F1(宏平均F1)

定义:Macro-F1(宏平均F1)是对所有类别的F1分数求算术平均。它先计算每个类别的F1分数,然后对所有类别的F1分数求平均。

计算公式

Macro-F1=F1PERSON+F1LOCATION+F1TIME+F1ORGANIZATION4Macro\text{-}F1 = \frac{F1_{PERSON} + F1_{LOCATION} + F1_{TIME} + F1_{ORGANIZATION}}{4}Macro-F1=4F1PERSON+F1LOCATION+F1TIME+F1ORGANIZATION

作用与特点

  1. 平等对待所有类别

    • 每个类别的权重相等,无论该类别样本数量多少
    • 即使某个类别只有10个样本,另一个类别有1000个样本,它们在Macro-F1中的权重也是相同的
  2. 评估类别平衡性

    • 能够反映模型在各个类别上的平均表现
    • 如果某个类别表现很差,会显著拉低Macro-F1分数
    • 适合评估模型是否对所有类别都有较好的识别能力
  3. 适用场景

    • 类别不平衡的数据集:当各类别样本数量差异很大时,Macro-F1能更好地反映模型在少数类别上的表现
    • 需要关注所有类别:当业务要求所有实体类型都同等重要时(如医疗诊断中,所有疾病类型都重要)
    • 评估模型公平性:判断模型是否存在对某些类别的偏见
  4. 局限性

    • 可能被样本很少但表现很差的类别"拖累"
    • 如果某个类别样本极少且识别困难,Macro-F1可能偏低

示例

  • F1PERSON=0.90F1_{PERSON} = 0.90F1PERSON=0.90
  • F1LOCATION=0.85F1_{LOCATION} = 0.85F1LOCATION=0.85
  • F1TIME=0.95F1_{TIME} = 0.95F1TIME=0.95
  • F1ORGANIZATION=0.80F1_{ORGANIZATION} = 0.80F1ORGANIZATION=0.80
  • Macro-F1=0.90+0.85+0.95+0.804=0.875Macro\text{-}F1 = \frac{0.90 + 0.85 + 0.95 + 0.80}{4} = 0.875Macro-F1=40.90+0.85+0.95+0.80=0.875

Micro-F1(微平均F1)

定义:Micro-F1(微平均F1)是将所有类别的TP、FP、FN统计合并后,再统一计算Precision、Recall和F1分数。它把多分类问题看作一个整体的二分类问题。

计算公式

Micro-Precision=TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION(TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION)+(FPPERSON+FPLOCATION+FPTIME+FPORGANIZATION)Micro\text{-}Precision = \frac{TP_{PERSON} + TP_{LOCATION} + TP_{TIME} + TP_{ORGANIZATION}}{(TP_{PERSON} + TP_{LOCATION} + TP_{TIME} + TP_{ORGANIZATION}) + (FP_{PERSON} + FP_{LOCATION} + FP_{TIME} + FP_{ORGANIZATION})}Micro-Precision=(TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION)+(FPPERSON+FPLOCATION+FPTIME+FPORGANIZATION)TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION

Micro-Recall=TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION(TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION)+(FNPERSON+FNLOCATION+FNTIME+FNORGANIZATION)Micro\text{-}Recall = \frac{TP_{PERSON} + TP_{LOCATION} + TP_{TIME} + TP_{ORGANIZATION}}{(TP_{PERSON} + TP_{LOCATION} + TP_{TIME} + TP_{ORGANIZATION}) + (FN_{PERSON} + FN_{LOCATION} + FN_{TIME} + FN_{ORGANIZATION})}Micro-Recall=(TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION)+(FNPERSON+FNLOCATION+FNTIME+FNORGANIZATION)TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION

Micro-F1=2×Micro-Precision×Micro-RecallMicro-Precision+Micro-RecallMicro\text{-}F1 = \frac{2 \times Micro\text{-}Precision \times Micro\text{-}Recall}{Micro\text{-}Precision + Micro\text{-}Recall}Micro-F1=Micro-Precision+Micro-Recall2×Micro-Precision×Micro-Recall

简化形式(设总TP、总FP、总FN):

TPtotal=TPPERSON+TPLOCATION+TPTIME+TPORGANIZATIONTP_{total} = TP_{PERSON} + TP_{LOCATION} + TP_{TIME} + TP_{ORGANIZATION}TPtotal=TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION

FPtotal=FPPERSON+FPLOCATION+FPTIME+FPORGANIZATIONFP_{total} = FP_{PERSON} + FP_{LOCATION} + FP_{TIME} + FP_{ORGANIZATION}FPtotal=FPPERSON+FPLOCATION+FPTIME+FPORGANIZATION

FNtotal=FNPERSON+FNLOCATION+FNTIME+FNORGANIZATIONFN_{total} = FN_{PERSON} + FN_{LOCATION} + FN_{TIME} + FN_{ORGANIZATION}FNtotal=FNPERSON+FNLOCATION+FNTIME+FNORGANIZATION

Micro-Precision=TPtotalTPtotal+FPtotalMicro\text{-}Precision = \frac{TP_{total}}{TP_{total} + FP_{total}}Micro-Precision=TPtotal+FPtotalTPtotal

Micro-Recall=TPtotalTPtotal+FNtotalMicro\text{-}Recall = \frac{TP_{total}}{TP_{total} + FN_{total}}Micro-Recall=TPtotal+FNtotalTPtotal

Micro-F1=2×Micro-Precision×Micro-RecallMicro-Precision+Micro-RecallMicro\text{-}F1 = \frac{2 \times Micro\text{-}Precision \times Micro\text{-}Recall}{Micro\text{-}Precision + Micro\text{-}Recall}Micro-F1=Micro-Precision+Micro-Recall2×Micro-Precision×Micro-Recall

作用与特点

  1. 反映整体性能

    • 将所有类别的预测结果合并,计算整体的准确率和召回率
    • 相当于把"识别实体"和"不识别实体"看作二分类问题
    • 能够反映模型在整个数据集上的综合表现
  2. 样本数量加权

    • 样本多的类别对Micro-F1的影响更大
    • 如果PERSON类别有1000个样本,LOCATION只有10个样本,PERSON的表现对Micro-F1的影响远大于LOCATION
    • 更接近实际应用场景(常见实体类型更重要)
  3. 适用场景

    • 类别不平衡但关注主流类别:当主要关注样本多的类别时(如新闻中,人名和地名比组织名更常见)
    • 评估整体准确率:需要知道模型在所有数据上的整体表现
    • 生产环境评估:实际应用中,常见实体类型的识别准确率更重要
  4. 局限性

    • 可能掩盖少数类别表现差的问题
    • 如果某个类别样本很少但很重要,Micro-F1可能无法反映其真实表现

示例

假设各类别的统计:

  • PERSON: TP=30, FP=5, FN=3
  • LOCATION: TP=25, FP=8, FN=4
  • TIME: TP=20, FP=3, FN=2
  • ORGANIZATION: TP=25, FP=4, FN=6

则:

  • TPtotal=30+25+20+25=100TP_{total} = 30 + 25 + 20 + 25 = 100TPtotal=30+25+20+25=100
  • FPtotal=5+8+3+4=20FP_{total} = 5 + 8 + 3 + 4 = 20FPtotal=5+8+3+4=20
  • FNtotal=3+4+2+6=15FN_{total} = 3 + 4 + 2 + 6 = 15FNtotal=3+4+2+6=15
  • Micro-Precision=100100+20=0.833Micro\text{-}Precision = \frac{100}{100 + 20} = 0.833Micro-Precision=100+20100=0.833
  • Micro-Recall=100100+15=0.870Micro\text{-}Recall = \frac{100}{100 + 15} = 0.870Micro-Recall=100+15100=0.870
  • Micro-F1=2×0.833×0.8700.833+0.870=0.851Micro\text{-}F1 = \frac{2 \times 0.833 \times 0.870}{0.833 + 0.870} = 0.851Micro-F1=0.833+0.8702×0.833×0.870=0.851

3. Macro-F1 vs Micro-F1 对比

对比维度 Macro-F1(宏平均) Micro-F1(微平均)
计算方式 先计算各类别F1,再求平均 先合并所有类别统计,再计算F1
类别权重 所有类别权重相等 样本多的类别权重更大
关注点 各类别的平均表现 整体数据集的综合表现
样本不平衡影响 不受样本数量影响,少数类别同样重要 受样本数量影响,多数类别更重要
适用场景 所有类别同等重要(如医疗诊断) 关注主流类别(如新闻NER)
优势 能发现少数类别表现差的问题 更接近实际应用的整体性能
劣势 可能被少数困难类别拖累 可能掩盖少数类别的问题

选择建议

  • 使用Macro-F1:当所有实体类型都同等重要,需要确保模型对所有类别都有良好表现时
  • 使用Micro-F1:当关注整体性能,常见实体类型更重要时
  • 同时使用:在实际项目中,建议同时报告两个指标,全面评估模型性能

4. 其他指标对比

指标 关注点 适用场景
Precision 预测的准确性 需要高准确率的场景(如医疗诊断)
Recall 识别的完整性 需要尽可能找出所有实体的场景
F1 准确率与召回率的平衡 需要综合评估的场景

5. 实体匹配规则

在本项目中,实体匹配采用严格匹配(Exact Match)方式:

  • 匹配条件:预测实体的文本和位置必须与真实实体完全一致
  • 示例
    • 真实实体:{"LOCATION": ["北京"]}
    • 预测实体:{"LOCATION": ["北京"]} → ✅ 匹配成功
    • 预测实体:{"LOCATION": ["北", "京"]} → ❌ 匹配失败(位置不对)
    • 预测实体:{"LOCATION": ["北京", "上海"]} → ❌ 匹配失败(多预测了"上海")

6. 代码实现

本项目中的评估指标计算在 evaluate.pyshow_stats() 方法中实现:

对每个类别计算 Precision、Recall、F1

Precision=正确识别数识别出实体数+10−5Precision = \frac{\text{正确识别数}}{\text{识别出实体数} + 10^{-5}}Precision=识别出实体数+10−5正确识别数

Recall=正确识别数样本实体数+10−5Recall = \frac{\text{正确识别数}}{\text{样本实体数} + 10^{-5}}Recall=样本实体数+10−5正确识别数

F1=2×Precision×RecallPrecision+Recall+10−5F1 = \frac{2 \times Precision \times Recall}{Precision + Recall + 10^{-5}}F1=Precision+Recall+10−52×Precision×Recall

Macro-F1:各类别F1的平均

Macro-F1=F1PERSON+F1LOCATION+F1TIME+F1ORGANIZATION4Macro\text{-}F1 = \frac{F1_{PERSON} + F1_{LOCATION} + F1_{TIME} + F1_{ORGANIZATION}}{4}Macro-F1=4F1PERSON+F1LOCATION+F1TIME+F1ORGANIZATION

Micro-F1:合并所有类别后计算

TPtotal=TPPERSON+TPLOCATION+TPTIME+TPORGANIZATIONTP_{total} = TP_{PERSON} + TP_{LOCATION} + TP_{TIME} + TP_{ORGANIZATION}TPtotal=TPPERSON+TPLOCATION+TPTIME+TPORGANIZATION

FPtotal=FPPERSON+FPLOCATION+FPTIME+FPORGANIZATIONFP_{total} = FP_{PERSON} + FP_{LOCATION} + FP_{TIME} + FP_{ORGANIZATION}FPtotal=FPPERSON+FPLOCATION+FPTIME+FPORGANIZATION

FNtotal=FNPERSON+FNLOCATION+FNTIME+FNORGANIZATIONFN_{total} = FN_{PERSON} + FN_{LOCATION} + FN_{TIME} + FN_{ORGANIZATION}FNtotal=FNPERSON+FNLOCATION+FNTIME+FNORGANIZATION

Micro-Precision=TPtotalTPtotal+FPtotal+10−5Micro\text{-}Precision = \frac{TP_{total}}{TP_{total} + FP_{total} + 10^{-5}}Micro-Precision=TPtotal+FPtotal+10−5TPtotal

Micro-Recall=TPtotalTPtotal+FNtotal+10−5Micro\text{-}Recall = \frac{TP_{total}}{TP_{total} + FN_{total} + 10^{-5}}Micro-Recall=TPtotal+FNtotal+10−5TPtotal

Micro-F1=2×Micro-Precision×Micro-RecallMicro-Precision+Micro-Recall+10−5Micro\text{-}F1 = \frac{2 \times Micro\text{-}Precision \times Micro\text{-}Recall}{Micro\text{-}Precision + Micro\text{-}Recall + 10^{-5}}Micro-F1=Micro-Precision+Micro-Recall+10−52×Micro-Precision×Micro-Recall

注意 :公式中的 10−510^{-5}10−5(即 1e-5)是为了防止除零错误。


📊 模型性能

模型在验证集上的表现(示例输出):

复制代码
PERSON类实体,准确率:0.923456, 召回率: 0.876543, F1: 0.899234
LOCATION类实体,准确率:0.912345, 召回率: 0.887654, F1: 0.899876
TIME类实体,准确率:0.945678, 召回率: 0.923456, F1: 0.934321
ORGANIZATION类实体,准确率:0.890123, 召回率: 0.856789, F1: 0.873123
Macro-F1: 0.901639
Micro-F1: 0.904567

🔍 技术亮点

  1. BERT预训练模型:利用大规模预训练知识,提升模型性能
  2. 双向上下文理解:BERT能够同时利用前后文信息,解决歧义问题
  3. 可选CRF层:通过标签转移约束,提升序列标注的合理性
  4. 完整的评估体系:Macro-F1和Micro-F1多维度评估
  5. 灵活的配置管理:集中管理超参数,便于调优

📝 总结

本项目实现了一个完整的中文命名实体识别系统,采用BERT作为特征提取器,结合序列标注方法,能够有效识别文本中的人名、地点、时间和组织机构等实体。通过模块化的代码设计,实现了数据加载、模型训练、评估等完整流程,为中文NER任务提供了一个可用的解决方案。

相关推荐
兔兔爱学习兔兔爱学习13 小时前
DeepSeek-OCR及其他主流OCR调研
人工智能
啊巴矲13 小时前
小白从零开始勇闯人工智能:计算机视觉初级篇(OpenCV进阶操作(上))
人工智能·opencv·计算机视觉
科学计算技术爱好者13 小时前
NVIDIA GPU 系列用途分类梳理
人工智能·算法·gpu算力
GatiArt雷13 小时前
AI自动化测试落地指南:基于LangChain+TestGPT的实操实现与效能验证
人工智能·langchain
数说星榆18113 小时前
音乐创作新生态:AI作曲与个性化音乐体验
人工智能
源创力环形导轨13 小时前
环形导轨:自动化生产线的核心传输解决方案
运维·人工智能·自动化
不会飞的鲨鱼13 小时前
腾讯录音文件语音识别 python api接口
人工智能·python·语音识别
楚来客13 小时前
AI基础概念之十三:Transformer 算法结构相比传统神经网络的改进
深度学习·神经网络·transformer
wengad13 小时前
豆包的深入研究的浅析-应用于股市投顾
人工智能