自然语言处理实战——基于BP神经网络的命名实体识别

目录

一、前置环境与全局配置功能

二、多格式数据集兼容与处理功能

三、词汇表与标签表构建功能

四、数据预处理与加载功能

[五、BP 神经网络模型构建功能](#五、BP 神经网络模型构建功能)

六、模型训练与优化功能

七、模型评估功能

八、单文本预测与实体解析功能

九、多维度可视化功能

十、整体容错与鲁棒性功能

十一、结果输出与管理功能

十二、文本数据集准备

十三、基于BP神经网络的命名实体识别的Python代码完整实现

十四、程序运行结果展示

十五、总结


一、前置环境与全局配置功能

  1. 中文字体自适应适配:解决可视化环节中文乱码问题,优先检测系统已安装的中文字体(如微软雅黑、黑体等),若未检测到则尝试加载本地字体文件,兜底保障中文正常显示;
  2. 全局统一配置管理:集中管理所有核心参数(模型维度、训练超参、数据集路径、可视化存储目录等),支持灵活调整,同时定义实体类型专属配色方案,为可视化提供统一风格;
  3. 环境容错优化:屏蔽无意义的警告信息,关键步骤添加异常提示,保证程序运行过程中不被无关报错中断。

二、多格式数据集兼容与处理功能

  1. 多格式数据集加载:原生支持 3 类主流 NER 数据集格式(标准 BIO 格式、MSRA 新闻格式、CLUENER 细粒度 JSON 格式),自动识别数据集文件格式并调用对应解析逻辑,可同时加载多个不同领域(通用、新闻、医疗、细粒度)的数据集并合并;
  2. 示例数据自动生成:当现有数据集样本数低于阈值时,自动生成通用、医疗、机构、时间等多维度的示例 NER 数据(包含文本和对应 BIO 标签),避免因数据不足导致程序中断;
  3. 数据集校验与清洗:加载过程中校验文本与标签的长度匹配性,跳过长度不匹配的异常样本,过滤无效行,保证数据质量;
  4. 智能数据集划分:根据总样本数自适应调整划分策略(样本充足时按 8:1:1 分为训练 / 验证 / 测试集;样本较少时按 9:1 划分;样本极少时全部作为训练集,验证 / 测试集复用训练集),避免出现空集导致的训练 / 评估错误。

三、词汇表与标签表构建功能

  1. 自动化词汇表构建:从所有数据集文本中提取唯一字符,添加 PAD(填充标记)、UNK(未知字符标记)后生成字级词汇表,保存到本地文件,同时生成 "字符 - 索引" 映射关系;
  2. 自动化标签表构建:整合所有数据集的标签类型,去重后添加 PAD、UNK 标记,生成标签表并保存到本地,同时生成 "标签 - 索引" 映射关系,为模型输入输出提供映射依据。

四、数据预处理与加载功能

  1. 序列标准化处理:按指定最大序列长度对文本和标签序列进行填充(不足长度时补 PAD)或截断(超过长度时切分),保证所有输入数据维度统一;
  2. 张量转换与设备适配:将处理后的文本 / 标签索引转换为 PyTorch 张量,自动检测 CUDA 是否可用,将张量部署到 GPU/CPU;
  3. 自适应数据加载器:根据训练 / 验证 / 测试集的样本数自适应调整批次大小(避免批次大小超过样本数),训练集加载时随机打乱数据,验证 / 测试集有序加载,保证训练稳定性和评估准确性。

五、BP 神经网络模型构建功能

  1. 模型架构设计:构建适配 NER 任务的 BP 神经网络,包含嵌入层(将字符索引转为向量)、多层全连接层(提取特征)、ReLU 激活函数、Dropout 层(防止过拟合)、输出层(通过 Softmax 输出标签概率),最终将扁平特征映射回序列维度,适配序列标注任务;
  2. 模型设备适配:自动将模型部署到 GPU/CPU,适配不同硬件环境;
  3. 模型参数统计:输出模型各层结构、参数数量等信息,直观展示模型复杂度。

六、模型训练与优化功能

  1. 可视化训练过程:训练时展示进度条,实时显示当前批次损失值,提升训练过程的可监控性;
  2. 掩码损失计算:过滤填充标签(PAD)对应的损失计算,仅计算有效标签的损失,避免无效数据干扰训练;
  3. 轮次验证机制:每轮训练结束后,在验证集上计算损失值和 Macro-F1 分数,评估模型泛化能力;
  4. 最优模型保存:根据验证集 Macro-F1 分数自动保存最优模型权重,覆盖次优模型,便于后续评估和预测;
  5. 自适应训练轮数:当样本数较少时自动减少训练轮数,避免模型过拟合。

七、模型评估功能

  1. 测试集全面评估:加载最优模型,在测试集上计算各标签的精准率、召回率、F1 分数,输出详细的分类报告,直观展示模型在不同实体类型上的识别效果;
  2. 混淆矩阵可视化:生成归一化的混淆矩阵热力图,清晰展示 "真实标签 - 预测标签" 的匹配关系,定位模型识别错误集中的标签类型;
  3. 空数据容错处理:当测试集无有效数据时,跳过评估并给出提示,避免程序崩溃。

八、单文本预测与实体解析功能

  1. 单文本实时预测:支持任意中文文本输入,输出该文本对应的标签序列,以及基于 B-I 标签规则提取的实体(包含实体内容和类型);
  2. 多领域示例验证:内置教育、地点、机构、医疗 4 类典型领域的文本示例,自动完成预测并输出结果,验证模型在不同领域的识别能力;
  3. 实体拼接逻辑:自动识别 B(实体开始)、I(实体内部)标签,拼接连续的同类型实体,过滤 O(非实体)标签,准确提取完整实体。

九、多维度可视化功能

  1. 数据集统计可视化:生成样本长度分布直方图(标注均值和最大序列长度)、标签分布条形图(标注各标签数量),输出样本数、平均长度、标签类型数等详细统计信息;
  2. 训练过程可视化:生成训练 / 验证损失变化曲线、验证集 Macro-F1 变化曲线,标注最优 F1 分数点,直观展示模型训练趋势;
  3. 实体标注可视化:为识别出的实体添加对应颜色背景,生成带颜色标注的文本可视化图,附带实体类型图例,直观展示实体在文本中的位置和类型;
  4. 模型结构可视化:生成模型计算图,清晰展示网络层之间的连接关系和参数流向;
  5. 可视化结果管理:所有可视化图表自动保存到指定目录,命名规范,便于后续查阅和分析。

十、整体容错与鲁棒性功能

  1. 路径容错:数据集文件不存在时自动跳过并提示,不影响其他数据集加载;
  2. 计算容错:损失计算、F1 计算、混淆矩阵绘制等环节添加空值过滤,避免除零、空列表等错误;
  3. 硬件容错:自动检测 GPU 是否可用,无 GPU 时自动切换到 CPU 运行,适配不同硬件环境;
  4. 参数自适应:批次大小、训练轮数等参数根据样本数自动调整,无需手动干预。

十一、结果输出与管理功能

  1. 过程日志输出:关键步骤(数据加载、模型训练、评估)输出详细日志,包含样本数、损失值、F1 分数等核心指标,便于监控程序运行状态;
  2. 文件保存管理:词汇表、标签表、最优模型权重、所有可视化图表均自动保存到指定目录,目录结构清晰,便于后续复用和分析;
  3. 预测结果输出:单文本预测结果包含输入文本、标签序列、识别实体,格式清晰,便于理解和使用。

十二、文本数据集准备

可以使用以下Python代码获得数据集,运行成功后,文本数据集也可以替换。

python 复制代码
import random
import json

# ===================== 候选实体库(用于随机生成) =====================
# 实体类型对应的候选词
ENTITY_CANDS = {
    "PER": ["张三", "李四", "王五", "赵六", "小明"],
    "LOC": ["北京", "上海", "广州", "杭州", "深圳市南山区", "北京市朝阳区"],
    "ORG": ["清华大学", "北京大学", "腾讯公司", "阿里巴巴集团", "北京协和医院"],
    "DATE": ["2025年3月", "2024年5月", "2025年4月15日", "2023年12月", "2025年"],
    "DISEASE": ["肺炎", "高血压", "糖尿病", "肝癌", "肺癌"],
    "ORGAN": ["心脏", "胸部", "肺部", "胃部"]
}

# 各文件的生成模板(保证文本与标签长度匹配)
TEMPLATES = {
    # ner_data.txt(BIO格式)的句子模板
    "bio": [
        ("{per}在{loc}工作", ["PER", "LOC"]),
        ("{per}是{org}的学生", ["PER", "ORG"]),
        ("{per}{date}毕业于{org}", ["PER", "DATE", "ORG"]),
        ("{loc}是著名的旅游城市", ["LOC"]),
        ("{org}总部位于{loc}", ["ORG", "LOC"])
    ],
    # medical_ner.txt(医疗BIO格式)的句子模板
    "medical_bio": [
        ("患者因{disease}入院治疗", ["DISEASE"]),
        ("{disease}患者需定期检查{organ}", ["DISEASE", "ORGAN"]),
        ("2025年4月做了{organ}手术", ["DATE", "ORGAN"]),
        ("{disease}晚期需化疗", ["DISEASE"])
    ]
}


# ===================== 工具函数:生成BIO标签 =====================
def generate_bio_labels(text, entities):
    """
    根据文本和实体信息生成BIO标签序列(保证长度与文本一致)
    :param text: 目标文本(字符串)
    :param entities: 实体列表,每个元素是(实体类型, 实体内容, 起始索引)
    :return: 标签列表(长度=文本长度)
    """
    labels = ["O"] * len(text)
    for ent_type, ent_content, start_idx in entities:
        end_idx = start_idx + len(ent_content)
        # 标记B-XXX
        labels[start_idx] = f"B-{ent_type}"
        # 标记I-XXX
        for i in range(start_idx + 1, end_idx):
            labels[i] = f"I-{ent_type}"
    return labels


# ===================== 1. 生成 ner_data.txt(BIO格式) =====================
def generate_ner_data():
    lines = []
    for _ in range(15):  # 生成15条样本
        # 随机选模板
        template, ent_types = random.choice(TEMPLATES["bio"])
        # 随机选实体并替换模板
        ent_map = {}
        entities = []
        current_text = template
        for ent_type in ent_types:
            ent_content = random.choice(ENTITY_CANDS[ent_type])
            ent_map[ent_type.lower()] = ent_content
            # 替换模板中的占位符
            current_text = current_text.replace(f"{{{ent_type.lower()}}}", ent_content)
        # 计算实体在文本中的位置
        temp_text = template
        for ent_type in ent_types:
            ent_content = ent_map[ent_type.lower()]
            # 找到实体在当前文本中的起始索引
            start_idx = temp_text.find(f"{{{ent_type.lower()}}}")
            # 替换占位符,更新temp_text以便后续计算
            temp_text = temp_text.replace(f"{{{ent_type.lower()}}}", ent_content)
            # 记录实体信息
            entities.append((ent_type, ent_content, start_idx))
        # 生成BIO标签
        labels = generate_bio_labels(current_text, entities)
        # 拼接成BIO格式行(文本\t标签序列)
        line = f"{current_text}\t{' '.join(labels)}"
        lines.append(line)
    # 写入文件
    with open("ner_data.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print("已生成 ner_data.txt")


# ===================== 2. 生成 msra_ner.txt(MSRA格式:字\t标签,空行分隔句子) =====================
def generate_msra_ner():
    lines = []
    # 先从ner_data.txt中读取样本(或直接生成)
    # 这里复用ner_data的样本,拆分成字+标签
    with open("ner_data.txt", "r", encoding="utf-8") as f:
        bio_lines = [line.strip() for line in f if line.strip()]
    for bio_line in bio_lines[:10]:  # 取前10个样本
        text, label_str = bio_line.split("\t")
        labels = label_str.split()
        # 每个字对应一行(字\t标签)
        for char, label in zip(text, labels):
            lines.append(f"{char}\t{label}")
        # 空行分隔句子
        lines.append("")
    # 写入文件
    with open("msra_ner.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(lines).strip())  # 去掉末尾空行
    print("已生成 msra_ner.txt")


# ===================== 3. 生成 cluener_ner.txt(JSON行格式) =====================
def generate_cluener_ner():
    lines = []
    # 复用ner_data的样本,转换为CLUENER格式
    with open("ner_data.txt", "r", encoding="utf-8") as f:
        bio_lines = [line.strip() for line in f if line.strip()]
    for bio_line in bio_lines[:8]:  # 取前8个样本
        text, label_str = bio_line.split("\t")
        labels = label_str.split()
        # 提取实体信息(转换为CLUENER的label格式)
        cluener_label = {}
        current_ent = None
        current_type = None
        start_idx = -1
        for idx, (char, label) in enumerate(zip(text, labels)):
            if label.startswith("B-"):
                # 结束上一个实体
                if current_ent is not None:
                    ent_type = current_type
                    ent_name = current_ent
                    if ent_type not in cluener_label:
                        cluener_label[ent_type] = {}
                    cluener_label[ent_type][ent_name] = {"start_idx": start_idx, "end_idx": idx}
                # 开始新实体
                current_type = label.split("-")[1]
                current_ent = char
                start_idx = idx
            elif label.startswith("I-") and current_ent is not None:
                # 拼接实体
                current_ent += char
            else:
                # 结束当前实体
                if current_ent is not None:
                    ent_type = current_type
                    ent_name = current_ent
                    if ent_type not in cluener_label:
                        cluener_label[ent_type] = {}
                    cluener_label[ent_type][ent_name] = {"start_idx": start_idx, "end_idx": idx}
                    current_ent = None
                    current_type = None
        # 处理最后一个实体
        if current_ent is not None:
            ent_type = current_type
            ent_name = current_ent
            if ent_type not in cluener_label:
                cluener_label[ent_type] = {}
            cluener_label[ent_type][ent_name] = {"start_idx": start_idx, "end_idx": len(text)}
        # 构造JSON对象
        cluener_json = json.dumps({
            "text": text,
            "label": cluener_label
        }, ensure_ascii=False)
        lines.append(cluener_json)
    # 写入文件
    with open("cluener_ner.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print("已生成 cluener_ner.txt")


# ===================== 4. 生成 medical_ner.txt(医疗BIO格式) =====================
def generate_medical_ner():
    lines = []
    for _ in range(10):  # 生成10条医疗样本
        # 随机选医疗模板
        template, ent_types = random.choice(TEMPLATES["medical_bio"])
        # 随机选实体并替换模板
        ent_map = {}
        entities = []
        current_text = template
        for ent_type in ent_types:
            ent_content = random.choice(ENTITY_CANDS[ent_type])
            ent_map[ent_type.lower()] = ent_content
            # 替换模板中的占位符
            current_text = current_text.replace(f"{{{ent_type.lower()}}}", ent_content)
        # 计算实体在文本中的位置
        temp_text = template
        for ent_type in ent_types:
            ent_content = ent_map[ent_type.lower()]
            # 找到实体在当前文本中的起始索引
            start_idx = temp_text.find(f"{{{ent_type.lower()}}}")
            # 替换占位符,更新temp_text以便后续计算
            temp_text = temp_text.replace(f"{{{ent_type.lower()}}}", ent_content)
            # 记录实体信息
            entities.append((ent_type, ent_content, start_idx))
        # 生成BIO标签
        labels = generate_bio_labels(current_text, entities)
        # 拼接成BIO格式行
        line = f"{current_text}\t{' '.join(labels)}"
        lines.append(line)
    # 写入文件
    with open("medical_ner.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print("已生成 medical_ner.txt")


# ===================== 执行生成 =====================
if __name__ == "__main__":
    generate_ner_data()
    generate_msra_ner()
    generate_cluener_ner()
    generate_medical_ner()
    print("所有数据集文件生成完成!")

程序运行结果展示:

ner_data.txt

bash 复制代码
北京是著名的旅游城市	B-LOC I-LOC O O O O O O O O
赵六在北京市朝阳区工作	B-PER I-PER O B-LOC I-LOC I-LOC I-LOC I-LOC I-LOC O O
广州是著名的旅游城市	B-LOC I-LOC O O O O O O O O
杭州是著名的旅游城市	B-LOC I-LOC O O O O O O O O
王五是阿里巴巴集团的学生	B-PER I-PER O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG O O O
张三是清华大学的学生	B-PER I-PER O B-ORG I-ORG I-ORG I-ORG O O O
北京是著名的旅游城市	B-LOC I-LOC O O O O O O O O
张三2024年5月毕业于北京大学	B-PER I-PER B-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE O O O B-ORG I-ORG I-ORG I-ORG
清华大学总部位于深圳市南山区	B-ORG I-ORG I-ORG I-ORG O O O O B-LOC I-LOC I-LOC I-LOC I-LOC I-LOC
张三2023年12月毕业于清华大学	B-PER I-PER B-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE O O O B-ORG I-ORG I-ORG I-ORG
杭州是著名的旅游城市	B-LOC I-LOC O O O O O O O O
张三在深圳市南山区工作	B-PER I-PER O B-LOC I-LOC I-LOC I-LOC I-LOC I-LOC O O
北京是著名的旅游城市	B-LOC I-LOC O O O O O O O O
王五2025年4月15日毕业于北京协和医院	B-PER I-PER B-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE O O O B-ORG I-ORG I-ORG I-ORG I-ORG I-ORG
张三在广州工作	B-PER I-PER O B-LOC I-LOC O O

msra_ner.txt

bash 复制代码
北	B-LOC
京	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

赵	B-PER
六	I-PER
在	O
北	B-LOC
京	I-LOC
市	I-LOC
朝	I-LOC
阳	I-LOC
区	I-LOC
工	O
作	O

广	B-LOC
州	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

杭	B-LOC
州	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

王	B-PER
五	I-PER
是	O
阿	B-ORG
里	I-ORG
巴	I-ORG
巴	I-ORG
集	I-ORG
团	I-ORG
的	O
学	O
生	O

张	B-PER
三	I-PER
是	O
清	B-ORG
华	I-ORG
大	I-ORG
学	I-ORG
的	O
学	O
生	O

北	B-LOC
京	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

张	B-PER
三	I-PER
2	B-DATE
0	I-DATE
2	I-DATE
4	I-DATE
年	I-DATE
5	I-DATE
月	I-DATE
毕	O
业	O
于	O
北	B-ORG
京	I-ORG
大	I-ORG
学	I-ORG

清	B-ORG
华	I-ORG
大	I-ORG
学	I-ORG
总	O
部	O
位	O
于	O
深	B-LOC
圳	I-LOC
市	I-LOC
南	I-LOC
山	I-LOC
区	I-LOC

张	B-PER
三	I-PER
2	B-DATE
0	I-DATE
2	I-DATE
3	I-DATE
年	I-DATE
1	I-DATE
2	I-DATE
月	I-DATE
毕	O
业	O
于	O
清	B-ORG
华	I-ORG
大	I-ORG
学	I-ORG北	B-LOC
京	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

赵	B-PER
六	I-PER
在	O
北	B-LOC
京	I-LOC
市	I-LOC
朝	I-LOC
阳	I-LOC
区	I-LOC
工	O
作	O

广	B-LOC
州	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

杭	B-LOC
州	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

王	B-PER
五	I-PER
是	O
阿	B-ORG
里	I-ORG
巴	I-ORG
巴	I-ORG
集	I-ORG
团	I-ORG
的	O
学	O
生	O

张	B-PER
三	I-PER
是	O
清	B-ORG
华	I-ORG
大	I-ORG
学	I-ORG
的	O
学	O
生	O

北	B-LOC
京	I-LOC
是	O
著	O
名	O
的	O
旅	O
游	O
城	O
市	O

张	B-PER
三	I-PER
2	B-DATE
0	I-DATE
2	I-DATE
4	I-DATE
年	I-DATE
5	I-DATE
月	I-DATE
毕	O
业	O
于	O
北	B-ORG
京	I-ORG
大	I-ORG
学	I-ORG

清	B-ORG
华	I-ORG
大	I-ORG
学	I-ORG
总	O
部	O
位	O
于	O
深	B-LOC
圳	I-LOC
市	I-LOC
南	I-LOC
山	I-LOC
区	I-LOC

张	B-PER
三	I-PER
2	B-DATE
0	I-DATE
2	I-DATE
3	I-DATE
年	I-DATE
1	I-DATE
2	I-DATE
月	I-DATE
毕	O
业	O
于	O
清	B-ORG
华	I-ORG
大	I-ORG
学	I-ORG

cluener_ner.txt

bash 复制代码
{"text": "北京是著名的旅游城市", "label": {"LOC": {"北京": {"start_idx": 0, "end_idx": 2}}}}
{"text": "赵六在北京市朝阳区工作", "label": {"PER": {"赵六": {"start_idx": 0, "end_idx": 2}}, "LOC": {"北京市朝阳区": {"start_idx": 3, "end_idx": 9}}}}
{"text": "广州是著名的旅游城市", "label": {"LOC": {"广州": {"start_idx": 0, "end_idx": 2}}}}
{"text": "杭州是著名的旅游城市", "label": {"LOC": {"杭州": {"start_idx": 0, "end_idx": 2}}}}
{"text": "王五是阿里巴巴集团的学生", "label": {"PER": {"王五": {"start_idx": 0, "end_idx": 2}}, "ORG": {"阿里巴巴集团": {"start_idx": 3, "end_idx": 9}}}}
{"text": "张三是清华大学的学生", "label": {"PER": {"张三": {"start_idx": 0, "end_idx": 2}}, "ORG": {"清华大学": {"start_idx": 3, "end_idx": 7}}}}
{"text": "北京是著名的旅游城市", "label": {"LOC": {"北京": {"start_idx": 0, "end_idx": 2}}}}
{"text": "张三2024年5月毕业于北京大学", "label": {"PER": {"张三": {"start_idx": 0, "end_idx": 2}}, "DATE": {"2024年5月": {"start_idx": 2, "end_idx": 9}}, "ORG": {"北京大学": {"start_idx": 12, "end_idx": 16}}}}

medical_ner.txt

bash 复制代码
患者因高血压入院治疗	O O O B-DISEASE I-DISEASE I-DISEASE O O O O
2025年4月做了胃部手术	I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE O O B-ORGAN I-ORGAN O B-DATE
患者因糖尿病入院治疗	O O O B-DISEASE I-DISEASE I-DISEASE O O O O
高血压晚期需化疗	B-DISEASE I-DISEASE I-DISEASE O O O O O
2025年4月做了心脏手术	I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE O O B-ORGAN I-ORGAN O B-DATE
肝癌患者需定期检查胸部	B-DISEASE I-DISEASE O O O O O O O B-ORGAN I-ORGAN
肝癌患者需定期检查胃部	B-DISEASE I-DISEASE O O O O O O O B-ORGAN I-ORGAN
糖尿病晚期需化疗	B-DISEASE I-DISEASE I-DISEASE O O O O O
2025年4月做了胸部手术	I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE O O O B-ORGAN I-ORGAN O B-DATE
2025年4月做了胃部手术	I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE I-DATE B-ORGAN I-ORGAN O B-DATE

十三、基于BP神经网络的命名实体识别的Python代码完整实现

python 复制代码
import os
import jieba
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, confusion_matrix
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import seaborn as sns
from torchsummary import summary
from torchviz import make_dot
import pandas as pd
from tqdm import tqdm
import warnings
import json  # 用于生成CLUENER示例数据

warnings.filterwarnings("ignore")

# ===================== 0. 全局配置与常量 =====================
def get_available_chinese_font():
    """检测系统可用中文字体,返回第一个可用字体"""
    # 常见中文字体名列表(按优先级排序)
    font_candidates = [
        "SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei",
        "Heiti TC", "SimSun", "FangSong", "KaiTi"
    ]
    # 获取系统已安装字体
    installed_fonts = [f.name for f in fm.fontManager.ttflist]

    # 查找可用字体
    for font_name in font_candidates:
        if font_name in installed_fonts:
            print(f"检测到可用中文字体:{font_name}")
            return font_name
    # 未检测到则手动加载(需将SimHei.ttf放到项目根目录)
    print("未检测到系统中文字体,尝试加载本地字体文件...")
    try:
        # 下载SimHei.ttf放到项目根目录,地址:https://www.fonts.net.cn/font-1435.html
        font_path = "SimHei.ttf"
        if os.path.exists(font_path):
            font_prop = fm.FontProperties(fname=font_path)
            fm.fontManager.addfont(font_path)
            print("本地字体加载成功!")
            return font_prop
    except:
        pass
    print("警告:无可用中文字体,中文可能显示异常!")
    return None


# 获取可用中文字体
CHINESE_FONT = get_available_chinese_font()

# ------------ 第二步:全局绘图配置 ------------
plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示
plt.rcParams["figure.figsize"] = (12, 8)
sns.set_style("whitegrid")
# 全局字体配置(兜底)
if isinstance(CHINESE_FONT, str):
    plt.rcParams["font.family"] = [CHINESE_FONT]
elif CHINESE_FONT is not None:
    plt.rcParams["font.family"] = [CHINESE_FONT.get_name()]

# 实体类型配色
COLORS = {
    "LOC": "#FF9999", "ORG": "#99CCFF", "PER": "#99FF99",
    "TIME": "#FFFF99", "DATE": "#FFD700", "GPE": "#FFB6C1",
    "O": "#FFFFFF"
}


# ===================== 1. 配置参数 =====================
class Config:
    # 多数据集支持(可添加多个数据集路径)
    data_paths = [
        "ner_data.txt",  # 基础数据集
        "msra_ner.txt",  # MSRA新闻数据集
        "cluener_ner.txt",  # CLUENER细粒度数据集
        "medical_ner.txt"  # 医疗领域数据集
    ]
    # 模型参数
    embedding_dim = 128
    hidden_dim = 256
    output_dim = None
    # 训练参数
    batch_size = 32
    epochs = 20
    lr = 0.001
    max_seq_len = 30
    # 其他
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_path = "vocab.txt"
    label_path = "label.txt"
    vis_save_dir = "ner_visualizations"
    # 数据集转换配置
    dataset_formats = {
        "msra_ner.txt": "MSRA",  # MSRA格式:每行字+标签,空行分隔句子
        "cluener_ner.txt": "CLUENER",  # CLUENER格式:JSON行
        "medical_ner.txt": "BIO"  # 标准BIO格式:文本\t标签序列
    }
    # 数据集容错配置
    min_samples = 10  # 最小样本数,不足时自动生成示例数据
    test_size = 0.2  # 默认测试集比例
    val_size = 0.1  # 默认验证集比例(相对于总样本)


config = Config()
os.makedirs(config.vis_save_dir, exist_ok=True)


# ===================== 示例数据集生成函数 =====================
def generate_example_datasets():
    """自动生成示例数据集(避免样本数不足)"""
    print("警告:未加载到足够的数据集,自动生成示例数据集...")

    # 1. 生成标准BIO格式文件(ner_data.txt)
    bio_data = [
        "小明在北京上班\tO O B-LOC I-LOC O",
        "张三是清华大学的学生\tO O B-ORG I-ORG I-ORG I-ORG O O",
        "李四2025年毕业于北京大学\tO O O O O O B-ORG I-ORG I-ORG",
        "王五在上海市浦东新区工作\tO O B-LOC I-LOC I-LOC I-LOC O",
        "赵六2024年加入腾讯公司\tO O O O O O B-ORG I-ORG I-ORG",
        "北京天安门是著名景点\tB-LOC I-LOC I-LOC O O O O",
        "复旦大学位于上海市\tB-ORG I-ORG I-ORG O B-LOC I-LOC",
        "2025年3月15日是植树节\tB-DATE I-DATE I-DATE I-DATE I-DATE O O O",
        "阿里巴巴总部在杭州\tB-ORG I-ORG I-ORG O O B-LOC",
        "患者因高血压于2025年入院\tO O B-DISEASE I-DISEASE O B-DATE I-DATE O"
    ]
    with open("ner_data.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(bio_data))

    # 2. 生成MSRA格式文件(msra_ner.txt)
    msra_data = [
        "小\tO", "明\tO", "在\tO", "北\tB-LOC", "京\tI-LOC", "上\tO", "班\tO", "",
        "张\tO", "三\tO", "是\tO", "清\tB-ORG", "华\tI-ORG", "大\tI-ORG", "学\tI-ORG", "的\tO", "学\tO", "生\tO", "",
        "李\tO", "四\tO", "2\tO", "0\tO", "2\tO", "5\tO", "年\tO", "毕\tO", "业\tO", "于\tO", "北\tB-ORG", "京\tI-ORG",
        "大\tI-ORG", "学\tI-ORG", ""
    ]
    with open("msra_ner.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(msra_data))

    # 3. 生成CLUENER格式文件(cluener_ner.txt)
    cluener_data = [
        json.dumps({"text": "小明在北京上班", "label": {"LOC": {"北京": {"start_idx": 2, "end_idx": 4}}}},
                   ensure_ascii=False),
        json.dumps({"text": "张三是清华大学的学生", "label": {"ORG": {"清华大学": {"start_idx": 2, "end_idx": 6}}}},
                   ensure_ascii=False),
        json.dumps({"text": "李四2025年毕业于北京大学", "label": {"ORG": {"北京大学": {"start_idx": 7, "end_idx": 10}},
                                                                  "DATE": {"2025年": {"start_idx": 2, "end_idx": 6}}}},
                   ensure_ascii=False)
    ]
    with open("cluener_ner.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(cluener_data))

    # 4. 生成医疗领域数据集(medical_ner.txt)
    medical_data = [
        "患者因肺炎入院\tO O B-DISEASE I-DISEASE O",
        "高血压患者需长期服药\tB-DISEASE I-DISEASE O O O O O O",
        "2025年4月做了心脏手术\tB-DATE I-DATE I-DATE I-DATE O O B-ORGAN I-ORGAN O O"
    ]
    with open("medical_ner.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(medical_data))

    print("示例数据集已生成完成!")


# ===================== 2. 多数据集处理核心模块 =====================
class NERDataset(Dataset):
    """NER数据集类(支持单数据集和混合数据集)"""

    def __init__(self, texts, labels, word2idx, label2idx, max_seq_len):
        self.texts = texts
        self.labels = labels
        self.word2idx = word2idx
        self.label2idx = label2idx
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # 字转索引
        text_idx = [self.word2idx.get(word, self.word2idx["<UNK>"]) for word in text]
        # 标签转索引
        label_idx = [self.label2idx.get(lab, self.label2idx["<UNK>"]) for lab in label]

        # 序列填充/截断
        if len(text_idx) < self.max_seq_len:
            pad_idx = self.word2idx["<PAD>"]
            text_idx += [pad_idx] * (self.max_seq_len - len(text_idx))
            label_idx += [self.label2idx["<PAD>"]] * (self.max_seq_len - len(label_idx))
        else:
            text_idx = text_idx[:self.max_seq_len]
            label_idx = label_idx[:self.max_seq_len]

        # 转tensor
        text_tensor = torch.LongTensor(text_idx).to(config.device)
        label_tensor = torch.LongTensor(label_idx).to(config.device)
        return text_tensor, label_tensor


def load_multiple_datasets(data_paths, dataset_formats):
    """加载并合并多个数据集(自动处理不同格式)"""
    all_texts = []
    all_labels = []
    format_handlers = {
        "BIO": load_bio_format,
        "MSRA": load_msra_format,
        "CLUENER": load_cluener_format
    }

    for data_path in data_paths:
        if not os.path.exists(data_path):
            print(f"警告:数据集文件 {data_path} 不存在,跳过")
            continue

        # 获取数据集格式,默认BIO
        data_format = dataset_formats.get(os.path.basename(data_path), "BIO")
        handler = format_handlers.get(data_format, load_bio_format)
        print(f"正在加载 {data_format} 格式数据集: {data_path}")

        # 调用对应格式的加载函数
        texts, labels = handler(data_path)
        all_texts.extend(texts)
        all_labels.extend(labels)
        print(f"成功加载 {len(texts)} 个样本")

    # 检查总样本数,不足则生成示例数据
    if len(all_texts) < config.min_samples:
        generate_example_datasets()
        # 重新加载生成的示例数据
        all_texts = []
        all_labels = []
        for data_path in data_paths:
            if os.path.exists(data_path):
                data_format = dataset_formats.get(os.path.basename(data_path), "BIO")
                handler = format_handlers.get(data_format, load_bio_format)
                texts, labels = handler(data_path)
                all_texts.extend(texts)
                all_labels.extend(labels)

    print(f"\n所有数据集加载完成,总样本数: {len(all_texts)}")
    return all_texts, all_labels


def load_bio_format(data_path):
    """加载标准BIO格式数据集(文本\t标签序列)"""
    texts = []
    labels = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split("\t")
            if len(parts) != 2:
                continue
            text_str, label_str = parts[0], parts[1]
            text = list(text_str)
            label = label_str.split()
            if len(text) == len(label):
                texts.append(text)
                labels.append(label)
            else:
                print(f"警告:样本文本与标签长度不匹配,跳过: {text_str[:30]}...")
    return texts, labels


def load_msra_format(data_path):
    """加载MSRA格式数据集(每行字+标签,空行分隔句子)"""
    texts = []
    labels = []
    current_text = []
    current_labels = []

    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # 空行表示句子结束
                if current_text:
                    texts.append(current_text)
                    labels.append(current_labels)
                    current_text = []
                    current_labels = []
                continue
            # MSRA格式:字\t标签
            parts = line.split("\t")
            if len(parts) != 2:
                continue
            char, label = parts[0], parts[1]
            current_text.append(char)
            current_labels.append(label)

    # 添加最后一个句子
    if current_text:
        texts.append(current_text)
        labels.append(current_labels)
    return texts, labels


def load_cluener_format(data_path):
    """加载CLUENER格式数据集(JSON行)"""
    import json
    texts = []
    labels = []

    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
                text = list(data["text"])
                labels_dict = data["label"]
                # 初始化标签为O
                label = ["O"] * len(text)

                # 填充实体标签
                for entity_type, entities in labels_dict.items():
                    for entity_info in entities.values():
                        start_idx = entity_info["start_idx"]
                        end_idx = entity_info["end_idx"]
                        # 设置B标签
                        label[start_idx] = f"B-{entity_type}"
                        # 设置I标签
                        for i in range(start_idx + 1, end_idx):
                            label[i] = f"I-{entity_type}"

                texts.append(text)
                labels.append(label)
            except Exception as e:
                print(f"解析CLUENER数据出错: {e}")
                continue
    return texts, labels


def build_vocab(texts, labels, vocab_path, label_path):
    """构建词汇表和标签表(支持多数据集)"""
    # 构建字表
    word_set = set()
    for text in texts:
        word_set.update(text)
    word_list = ["<PAD>", "<UNK>"] + list(word_set)
    word2idx = {word: idx for idx, word in enumerate(word_list)}
    with open(vocab_path, "w", encoding="utf-8") as f:
        f.write("\n".join(word_list))

    # 构建标签表
    label_set = set()
    for label in labels:
        label_set.update(label)
    label_list = ["<PAD>", "<UNK>"] + list(label_set)
    label2idx = {label: idx for idx, label in enumerate(label_list)}
    with open(label_path, "w", encoding="utf-8") as f:
        f.write("\n".join(label_list))

    print(f"词汇表大小: {len(word_list)} | 标签数: {len(label_list)}")
    return word2idx, label2idx, len(word_list), len(label_list)


def visualize_dataset_stats(texts, labels, save_path):
    """可视化数据集统计信息(样本长度分布、标签分布)"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # 1. 样本长度分布
    lengths = [len(text) for text in texts]
    ax1.hist(lengths, bins=min(30, len(set(lengths))), color="#4CAF50", alpha=0.7)
    ax1.set_xlabel("样本长度(字数)", fontsize=12)
    ax1.set_ylabel("样本数量", fontsize=12)
    ax1.set_title("样本长度分布", fontsize=14, fontweight="bold")
    ax1.axvline(np.mean(lengths), color="red", linestyle="--",
                label=f"均值: {np.mean(lengths):.1f}")
    ax1.axvline(config.max_seq_len, color="orange", linestyle="--",
                label=f"最大序列长度: {config.max_seq_len}")
    ax1.legend(fontsize=10)

    # 2. 标签分布
    all_labels_flat = [label for sublist in labels for label in sublist]
    label_counts = pd.Series(all_labels_flat).value_counts()
    # 过滤PAD和UNK(如果存在)
    label_counts = label_counts[~label_counts.index.isin(["<PAD>", "<UNK>"])]

    sns.barplot(x=label_counts.values, y=label_counts.index, ax=ax2,
                palette="viridis")
    ax2.set_xlabel("标签数量", fontsize=12)
    ax2.set_ylabel("标签类型", fontsize=12)
    ax2.set_title("标签分布统计", fontsize=14, fontweight="bold")

    # 添加数值标签
    for i, v in enumerate(label_counts.values):
        ax2.text(v + 1, i, str(v), va="center", fontsize=9)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"数据集统计可视化已保存至: {save_path}")

    # 打印详细统计
    print("\n===== 数据集统计信息 =====")
    print(f"总样本数: {len(texts)}")
    print(f"平均长度: {np.mean(lengths):.1f} 字")
    print(f"最短长度: {np.min(lengths)} 字")
    print(f"最长长度: {np.max(lengths)} 字")
    print(f"标签类型数: {len(label_counts)}")
    print("前10个标签分布:")
    for label, count in label_counts.head(10).items():
        print(f"  {label}: {count}")


# ===================== 3. BP神经网络模型 =====================
class BPNER(nn.Module):
    """基于BP的命名实体识别模型"""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, max_seq_len):
        super(BPNER, self).__init__()
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0
        )
        self.fc1 = nn.Linear(embedding_dim * max_seq_len, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim * max_seq_len)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        embed = self.embedding(x)
        embed_flat = embed.view(embed.size(0), -1)
        h1 = self.dropout(self.relu(self.fc1(embed_flat)))
        h2 = self.dropout(self.relu(self.fc2(h1)))
        out = self.fc3(h2)
        out = out.view(out.size(0), self.max_seq_len, -1)
        out = self.softmax(out)
        return out


# ===================== 4. 可视化核心函数 =====================
def plot_train_curve(train_losses, val_losses, val_f1_scores, save_path):
    """绘制训练/验证曲线"""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    ax1.plot(range(1, len(train_losses) + 1), train_losses, "b-o", label="训练损失")
    ax1.plot(range(1, len(val_losses) + 1), val_losses, "r-s", label="验证损失")
    ax1.set_ylabel("损失值", fontsize=12)
    ax1.set_title("训练/验证损失变化曲线", fontsize=14, fontweight="bold")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    ax2.plot(range(1, len(val_f1_scores) + 1), val_f1_scores, "g-^", label="验证Macro-F1")
    ax2.set_xlabel("训练轮数 (Epoch)", fontsize=12)
    ax2.set_ylabel("Macro-F1分数", fontsize=12)
    ax2.set_title("验证集Macro-F1变化曲线", fontsize=14, fontweight="bold")
    ax2.set_ylim(0, 1.0)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    best_idx = np.argmax(val_f1_scores)
    best_f1 = val_f1_scores[best_idx]
    ax2.scatter(best_idx + 1, best_f1, color="red", s=100, zorder=5)
    ax2.annotate(f"最优F1: {best_f1:.4f}",
                 xy=(best_idx + 1, best_f1),
                 xytext=(best_idx + 1 + 0.5, best_f1 - 0.1),
                 arrowprops=dict(arrowstyle="->", color="red", lw=2))

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()


def plot_confusion_matrix(y_true, y_pred, label2idx, save_path):
    """绘制混淆矩阵热力图"""
    valid_labels = [l for l in label2idx.keys() if l not in ["<PAD>", "<UNK>"]]
    if len(valid_labels) == 0:
        print("警告:无有效标签,跳过混淆矩阵绘制")
        return

    le = LabelEncoder()
    le.fit(valid_labels)

    y_true_filtered = [l for l in y_true if l in valid_labels]
    y_pred_filtered = [l for l in y_pred if l in valid_labels]

    if len(y_true_filtered) == 0 or len(y_pred_filtered) == 0:
        print("警告:无有效预测结果,跳过混淆矩阵绘制")
        return

    y_true_enc = le.transform(y_true_filtered)
    y_pred_enc = le.transform(y_pred_filtered)

    cm = confusion_matrix(y_true_enc, y_pred_enc)
    cm_normalized = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-8)  # 避免除零

    plt.figure(figsize=(14, 12))
    sns.heatmap(cm_normalized, annot=True, fmt=".4f", cmap="Blues",
                xticklabels=le.classes_, yticklabels=le.classes_,
                cbar_kws={"label": "归一化概率"})

    plt.xlabel("预测标签", fontsize=12)
    plt.ylabel("真实标签", fontsize=12)
    plt.title("NER模型混淆矩阵(归一化)", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()


def visualize_entity_annotation(text, pred_labels, save_path):
    """可视化文本实体标注结果"""
    fig, ax = plt.subplots(figsize=(len(text) * 0.5, 2))
    ax.axis("off")

    annotated_text = []
    for char, label in zip(text, pred_labels):
        ent_type = label.split("-")[-1] if "-" in label else label
        bg_color = COLORS.get(ent_type, COLORS["O"])
        annotated_text.append(f"\\colorbox{{{bg_color}}}{{{char}}}")

    full_text = "".join(annotated_text)
    ax.text(0.05, 0.5, full_text, fontsize=16, va="center", ha="left", fontfamily="monospace")

    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, facecolor=color, label=ent_type)
        for ent_type, color in COLORS.items() if ent_type != "O" and ent_type in [l.split("-")[-1] for l in pred_labels]
    ]
    if legend_elements:
        ax.legend(handles=legend_elements, loc="upper right", fontsize=10)

    plt.title(f"实体标注结果:{text}", fontsize=14, fontweight="bold", pad=20)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()


def visualize_model_structure(model, input_shape, save_path):
    """可视化模型结构"""
    print("\n===== 模型结构与参数统计 =====")
    try:
        summary(model, input_size=input_shape[1:], device=str(config.device))
    except Exception as e:
        print(f"模型结构打印失败: {e}")

    try:
        dummy_input = torch.randint(0, 100, input_shape).to(config.device)
        output = model(dummy_input)
        dot = make_dot(output, params=dict(model.named_parameters()))
        dot.render(save_path.replace(".png", ""), format="png")
        print(f"模型计算图已保存至: {save_path}")
    except Exception as e:
        print(f"模型计算图生成失败: {e}")


# ===================== 5. 训练与评估函数 =====================
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs, device):
    """模型训练(支持多数据集混合训练)"""
    model.to(device)
    best_f1 = 0.0
    train_losses = []
    val_losses = []
    val_f1_scores = []

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

        for texts, labels in progress_bar:
            optimizer.zero_grad()
            outputs = model(texts)
            outputs_reshape = outputs.reshape(-1, outputs.size(-1))
            labels_reshape = labels.reshape(-1)
            mask = (labels_reshape != 0)
            if mask.sum() == 0:  # 避免空mask
                continue
            loss = criterion(outputs_reshape[mask], labels_reshape[mask])
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

        # 验证阶段
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_targets = []
        with torch.no_grad():
            for texts, labels in val_loader:
                outputs = model(texts)
                outputs_reshape = outputs.reshape(-1, outputs.size(-1))
                labels_reshape = labels.reshape(-1)
                mask = (labels_reshape != 0)
                if mask.sum() == 0:
                    continue
                loss = criterion(outputs_reshape[mask], labels_reshape[mask])
                val_loss += loss.item()
                preds = torch.argmax(outputs_reshape[mask], dim=-1).cpu().numpy()
                targets = labels_reshape[mask].cpu().numpy()
                val_preds.extend(preds)
                val_targets.extend(targets)

        # 计算F1(避免空列表)
        if len(val_preds) == 0 or len(val_targets) == 0:
            val_f1 = 0.0
        else:
            val_f1 = f1_score(val_targets, val_preds, average="macro", zero_division=0)

        train_losses.append(train_loss / max(1, len(train_loader)))
        val_losses.append(val_loss / max(1, len(val_loader)))
        val_f1_scores.append(val_f1)

        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(
            f"Train Loss: {train_loss / max(1, len(train_loader)):.4f} | Val Loss: {val_loss / max(1, len(val_loader)):.4f}")
        print(f"Val Macro F1: {val_f1:.4f}")

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), "best_bp_ner.pth")
            print(f"保存最优模型,当前最优F1: {best_f1:.4f}\n")

    plot_train_curve(
        train_losses, val_losses, val_f1_scores,
        save_path=os.path.join(config.vis_save_dir, "train_curve.png")
    )
    return train_losses, val_losses, val_f1_scores


def evaluate_model(model, test_loader, label2idx, device):
    """模型评估(测试集)+ 混淆矩阵可视化"""
    try:
        model.load_state_dict(torch.load("best_bp_ner.pth"))
    except:
        print("警告:未找到最优模型,使用当前模型评估")
    model.to(device)
    model.eval()
    test_preds = []
    test_targets = []
    idx2label = {idx: label for label, idx in label2idx.items()}

    with torch.no_grad():
        for texts, labels in test_loader:
            outputs = model(texts)
            outputs_reshape = outputs.reshape(-1, outputs.size(-1))
            labels_reshape = labels.reshape(-1)
            mask = (labels_reshape != 0)
            if mask.sum() == 0:
                continue
            preds_idx = torch.argmax(outputs_reshape[mask], dim=-1).cpu().numpy()
            targets_idx = labels_reshape[mask].cpu().numpy()
            preds = [idx2label.get(idx, "<UNK>") for idx in preds_idx]
            targets = [idx2label.get(idx, "<UNK>") for idx in targets_idx]
            test_preds.extend(preds)
            test_targets.extend(targets)

    print("\n===== 测试集评估结果 =====")
    if len(test_preds) == 0 or len(test_targets) == 0:
        print("无有效测试数据,跳过评估")
    else:
        print(classification_report(test_targets, test_preds, digits=4, zero_division=0))

        plot_confusion_matrix(
            test_targets, test_preds, label2idx,
            save_path=os.path.join(config.vis_save_dir, "confusion_matrix.png")
        )
    return test_preds, test_targets


def predict_single_text(model, text, word2idx, label2idx, max_seq_len, device, save_vis=True):
    """单文本预测 + 实体标注可视化"""
    model.eval()
    idx2label = {idx: label for label, idx in label2idx.items()}
    text_list = list(text)
    text_idx = [word2idx.get(word, word2idx["<UNK>"]) for word in text_list]

    if len(text_idx) < max_seq_len:
        text_idx += [word2idx["<PAD>"]] * (max_seq_len - len(text_idx))
    else:
        text_idx = text_idx[:max_seq_len]

    text_tensor = torch.LongTensor(text_idx).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(text_tensor)
        pred_idx = torch.argmax(output, dim=-1).squeeze(0).cpu().numpy()

    pred_labels = [idx2label.get(idx, "<UNK>") for idx in pred_idx[:len(text_list)]]
    entities = []
    current_entity = ""
    current_type = ""

    for word, label in zip(text_list, pred_labels):
        if label.startswith("B-"):
            if current_entity:
                entities.append((current_entity, current_type))
                current_entity = ""
            current_entity = word
            current_type = label.split("-")[1]
        elif label.startswith("I-") and current_entity:
            current_entity += word
        else:
            if current_entity:
                entities.append((current_entity, current_type))
                current_entity = ""
                current_type = ""

    if current_entity:
        entities.append((current_entity, current_type))

    if save_vis:
        vis_path = os.path.join(config.vis_save_dir,
                                f"entity_annotation_{text[:10].replace('/', '_').replace('\\', '_')}.png")
        visualize_entity_annotation(text, pred_labels, vis_path)

    return {
        "text": text,
        "label_sequence": pred_labels,
        "entities": entities
    }


# ===================== 6. 主函数(执行流程) =====================
if __name__ == "__main__":
    # Step 1: 加载多个数据集
    print("===== 开始加载多数据集 =====")
    texts, labels = load_multiple_datasets(config.data_paths, config.dataset_formats)

    if not texts:
        print("错误:没有加载到任何数据集,程序退出")
        exit(1)

    # Step 2: 数据集统计可视化
    visualize_dataset_stats(
        texts, labels,
        save_path=os.path.join(config.vis_save_dir, "dataset_stats.png")
    )

    # Step 3: 构建词汇表和标签表
    print("\n===== 构建词汇表和标签表 =====")
    word2idx, label2idx, vocab_size, label_size = build_vocab(
        texts, labels, config.vocab_path, config.label_path
    )
    config.output_dim = label_size

    # Step 4: 智能划分数据集(核心修复:避免空集)
    print("\n===== 智能划分数据集 =====")
    total_samples = len(texts)
    train_texts, train_labels = [], []
    val_texts, val_labels = [], []
    test_texts, test_labels = [], []

    if total_samples >= 10:
        # 样本充足:8:1:1 划分
        train_texts, temp_texts, train_labels, temp_labels = train_test_split(
            texts, labels, test_size=0.2, random_state=42, shuffle=True
        )
        # 调整验证集比例,确保不出现空集
        val_test_size = min(0.5, max(0.1, 1 / len(temp_texts))) if len(temp_texts) > 1 else 0.1
        if len(temp_texts) >= 2:
            val_texts, test_texts, val_labels, test_labels = train_test_split(
                temp_texts, temp_labels, test_size=val_test_size, random_state=42, shuffle=True
            )
        else:
            # 临时样本不足,全部作为测试集,验证集从训练集拆分
            test_texts, test_labels = temp_texts, temp_labels
            val_texts, train_texts, val_labels, train_labels = train_test_split(
                train_texts, train_labels, test_size=0.1, random_state=42, shuffle=True
            )
    elif total_samples >= 5:
        # 样本较少:9:1 划分(训练+验证 : 测试)
        train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
            texts, labels, test_size=0.1, random_state=42, shuffle=True
        )
        train_texts, val_texts, train_labels, val_labels = train_test_split(
            train_val_texts, train_val_labels, test_size=0.1, random_state=42, shuffle=True
        )
    else:
        # 样本极少:全部作为训练集,验证集和测试集复用训练集(仅用于演示)
        print(f"警告:样本数仅{total_samples}个,全部用于训练,验证/测试集复用训练集")
        train_texts, train_labels = texts, labels
        val_texts, val_labels = texts[:max(1, total_samples // 2)], labels[:max(1, total_samples // 2)]
        test_texts, test_labels = texts, labels

    # 打印划分结果
    print(f"训练集: {len(train_texts)} 样本 | 验证集: {len(val_texts)} 样本 | 测试集: {len(test_texts)} 样本")

    # Step 5: 构建数据集和数据加载器
    train_dataset = NERDataset(train_texts, train_labels, word2idx, label2idx, config.max_seq_len)
    val_dataset = NERDataset(val_texts, val_labels, word2idx, label2idx, config.max_seq_len)
    test_dataset = NERDataset(test_texts, test_labels, word2idx, label2idx, config.max_seq_len)

    # 调整batch_size(避免batch_size > 样本数)
    batch_size = min(config.batch_size, len(train_dataset), len(val_dataset), len(test_dataset), 8)
    print(f"自动调整批次大小为: {batch_size}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Step 6: 初始化模型、损失函数、优化器
    print("\n===== 初始化模型 =====")
    model = BPNER(
        vocab_size=vocab_size,
        embedding_dim=config.embedding_dim,
        hidden_dim=config.hidden_dim,
        output_dim=config.output_dim,
        max_seq_len=config.max_seq_len
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # Step 7: 可视化模型结构
    visualize_model_structure(
        model,
        input_shape=(batch_size, config.max_seq_len),
        save_path=os.path.join(config.vis_save_dir, "model_structure.png")
    )

    # Step 8: 训练模型(调整epochs,样本少时减少训练轮数)
    epochs = min(config.epochs, 10) if total_samples < 20 else config.epochs
    print(f"\n===== 开始训练(共{epochs}轮) =====")
    train_losses, val_losses, val_f1_scores = train_model(
        model, train_loader, val_loader, criterion, optimizer, epochs, config.device
    )

    # Step 9: 测试集评估
    test_preds, test_targets = evaluate_model(model, test_loader, label2idx, config.device)

    # Step 10: 多领域单文本预测示例
    print("\n===== 多领域单文本预测示例 =====")
    test_texts_demo = [
        "王五2025年毕业于复旦大学(教育领域)",
        "小明在北京天安门广场游玩(地点领域)",
        "阿里巴巴集团总部位于杭州市(机构领域)",
        "患者因肺炎于2025年3月15日入院(医疗领域)"
    ]

    for demo_text in test_texts_demo:
        pred_result = predict_single_text(
            model, demo_text, word2idx, label2idx, config.max_seq_len, config.device
        )
        print(f"\n输入文本:{pred_result['text']}")
        print(f"标签序列:{pred_result['label_sequence']}")
        print(f"识别实体:{pred_result['entities']}")

    print(f"\n所有可视化结果已保存至目录:{config.vis_save_dir}")
    print("\n程序执行完成!")

十四、程序运行结果展示

检测到可用中文字体:SimHei

===== 开始加载多数据集 =====

正在加载 BIO 格式数据集: ner_data.txt

成功加载 15 个样本

正在加载 MSRA 格式数据集: msra_ner.txt

成功加载 10 个样本

正在加载 CLUENER 格式数据集: cluener_ner.txt

成功加载 8 个样本

正在加载 BIO 格式数据集: medical_ner.txt

成功加载 10 个样本

所有数据集加载完成,总样本数: 43

数据集统计可视化已保存至: ner_visualizations\dataset_stats.png

===== 数据集统计信息 =====

总样本数: 43

平均长度: 11.6 字

最短长度: 7 字

最长长度: 21 字

标签类型数: 13

前10个标签分布:

O: 226

I-DATE: 70

I-ORG: 50

I-LOC: 45

B-LOC: 21

B-PER: 17

I-PER: 17

B-ORG: 14

B-DATE: 10

I-DISEASE: 10

===== 构建词汇表和标签表 =====

词汇表大小: 89 | 标签数: 15

===== 智能划分数据集 =====

训练集: 34 样本 | 验证集: 8 样本 | 测试集: 1 样本

自动调整批次大小为: 1

===== 初始化模型 =====

===== 模型结构与参数统计 =====

模型结构打印失败: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

模型计算图生成失败: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH

===== 开始训练(共20轮) =====

Epoch 1/20: 100%|██████████| 34/34 [00:00<00:00, 200.31it/s, loss=1.82]

Epoch [1/20]

Train Loss: 2.4516 | Val Loss: 2.3744

Val Macro F1: 0.2727

保存最优模型,当前最优F1: 0.2727

Epoch 2/20: 100%|██████████| 34/34 [00:00<00:00, 209.61it/s, loss=1.84]

Epoch [2/20]

Train Loss: 2.0977 | Val Loss: 2.1948

Val Macro F1: 0.3969

保存最优模型,当前最优F1: 0.3969

Epoch 3/20: 100%|██████████| 34/34 [00:00<00:00, 202.82it/s, loss=2.21]

Epoch [3/20]

Train Loss: 2.0060 | Val Loss: 2.1139

Val Macro F1: 0.5093

保存最优模型,当前最优F1: 0.5093

Epoch 4/20: 100%|██████████| 34/34 [00:00<00:00, 209.11it/s, loss=1.97]

Epoch [4/20]

Train Loss: 1.9674 | Val Loss: 2.0432

Val Macro F1: 0.6602

保存最优模型,当前最优F1: 0.6602

Epoch 5/20: 100%|██████████| 34/34 [00:00<00:00, 204.18it/s, loss=1.92]

Epoch [5/20]

Train Loss: 1.9548 | Val Loss: 2.0075

Val Macro F1: 0.7075

保存最优模型,当前最优F1: 0.7075

Epoch 6/20: 100%|██████████| 34/34 [00:00<00:00, 205.95it/s, loss=1.82]

Epoch [6/20]

Train Loss: 1.9378 | Val Loss: 1.9999

Val Macro F1: 0.7212

保存最优模型,当前最优F1: 0.7212

Epoch 7/20: 100%|██████████| 34/34 [00:00<00:00, 213.90it/s, loss=1.99]

Epoch [7/20]

Train Loss: 1.9249 | Val Loss: 1.9778

Val Macro F1: 0.6847

Epoch 8/20: 100%|██████████| 34/34 [00:00<00:00, 215.87it/s, loss=1.92]

Epoch 9/20: 0%| | 0/34 [00:00<?, ?it/s]Epoch [8/20]

Train Loss: 1.9107 | Val Loss: 1.9645

Val Macro F1: 0.8033

保存最优模型,当前最优F1: 0.8033

Epoch 9/20: 100%|██████████| 34/34 [00:00<00:00, 212.01it/s, loss=1.82]

Epoch [9/20]

Train Loss: 1.9018 | Val Loss: 1.9553

Val Macro F1: 0.8199

保存最优模型,当前最优F1: 0.8199

Epoch 10/20: 100%|██████████| 34/34 [00:00<00:00, 203.77it/s, loss=1.82]

Epoch [10/20]

Train Loss: 1.8938 | Val Loss: 1.9441

Val Macro F1: 0.8256

保存最优模型,当前最优F1: 0.8256

Epoch 11/20: 100%|██████████| 34/34 [00:00<00:00, 211.29it/s, loss=1.82]

Epoch [11/20]

Train Loss: 1.8845 | Val Loss: 1.9253

Val Macro F1: 0.8416

保存最优模型,当前最优F1: 0.8416

Epoch 12/20: 100%|██████████| 34/34 [00:00<00:00, 210.97it/s, loss=1.82]

Epoch [12/20]

Train Loss: 1.8821 | Val Loss: 1.9204

Val Macro F1: 0.8698

保存最优模型,当前最优F1: 0.8698

Epoch 13/20: 100%|██████████| 34/34 [00:00<00:00, 207.11it/s, loss=1.94]

Epoch 14/20: 0%| | 0/34 [00:00<?, ?it/s, loss=1.98]Epoch [13/20]

Train Loss: 1.8730 | Val Loss: 1.9117

Val Macro F1: 0.8698

Epoch 14/20: 100%|██████████| 34/34 [00:00<00:00, 220.86it/s, loss=1.82]

Epoch [14/20]

Train Loss: 1.8741 | Val Loss: 1.9113

Val Macro F1: 0.8698

保存最优模型,当前最优F1: 0.8698

Epoch 15/20: 100%|██████████| 34/34 [00:00<00:00, 220.18it/s, loss=1.82]

Epoch 16/20: 0%| | 0/34 [00:00<?, ?it/s, loss=1.82]Epoch [15/20]

Train Loss: 1.8730 | Val Loss: 1.9112

Val Macro F1: 0.8681

Epoch 16/20: 100%|██████████| 34/34 [00:00<00:00, 212.47it/s, loss=1.82]

Epoch [16/20]

Train Loss: 1.8695 | Val Loss: 1.9113

Val Macro F1: 0.8681

Epoch 17/20: 100%|██████████| 34/34 [00:00<00:00, 210.67it/s, loss=1.82]

Epoch 18/20: 0%| | 0/34 [00:00<?, ?it/s, loss=1.82]Epoch [17/20]

Train Loss: 1.8713 | Val Loss: 1.9123

Val Macro F1: 0.8655

Epoch 18/20: 100%|██████████| 34/34 [00:00<00:00, 217.39it/s, loss=1.82]

Epoch 19/20: 0%| | 0/34 [00:00<?, ?it/s, loss=1.82]Epoch [18/20]

Train Loss: 1.8746 | Val Loss: 1.9114

Val Macro F1: 0.8698

Epoch 19/20: 100%|██████████| 34/34 [00:00<00:00, 216.52it/s, loss=1.98]

Epoch 20/20: 0%| | 0/34 [00:00<?, ?it/s, loss=1.83]Epoch [19/20]

Train Loss: 1.8696 | Val Loss: 1.9040

Val Macro F1: 0.8936

保存最优模型,当前最优F1: 0.8936

Epoch 20/20: 100%|██████████| 34/34 [00:00<00:00, 210.26it/s, loss=1.98]

Epoch [20/20]

Train Loss: 1.8702 | Val Loss: 1.9039

Val Macro F1: 0.8936

===== 测试集评估结果 =====

precision recall f1-score support

B-LOC 1.0000 1.0000 1.0000 1

I-LOC 1.0000 1.0000 1.0000 1

O 1.0000 1.0000 1.0000 8

accuracy 1.0000 10

macro avg 1.0000 1.0000 1.0000 10

weighted avg 1.0000 1.0000 1.0000 10

===== 多领域单文本预测示例 =====

输入文本:王五2025年毕业于复旦大学(教育领域)

标签序列:['B-PER', 'I-PER', 'O', 'I-DATE', 'I-DATE', 'I-DATE', 'I-DATE', 'I-DATE', 'O', 'O', 'O', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG']

识别实体:[('王五', 'PER'), ('大学(教育领域)', 'ORG')]

输入文本:小明在北京天安门广场游玩(地点领域)

标签序列:['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'I-LOC', 'I-LOC', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG']

识别实体:[('小明', 'PER'), ('北京天安门广', 'LOC')]

输入文本:阿里巴巴集团总部位于杭州市(机构领域)

标签序列:['B-PER', 'I-PER', 'O', 'B-ORG', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'I-LOC', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG']

识别实体:[('阿里', 'PER'), ('巴集', 'ORG'), ('市(机构领域)', 'ORG')]

输入文本:患者因肺炎于2025年3月15日入院(医疗领域)

标签序列:['O', 'O', 'O', 'B-DISEASE', 'I-DISEASE', 'I-DISEASE', 'O', 'O', 'O', 'O', 'I-ORGAN', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-DISEASE', 'B-PER', 'I-DISEASE']

识别实体:[('肺炎于', 'DISEASE'), ('月15日入院(医疗领', 'ORG'), ('域)', 'PER')]

所有可视化结果已保存至目录:ner_visualizations

程序执行完成!

十五、总结

本文实现了一个基于BP神经网络的命名实体识别系统,具有以下特点:

  1. 多格式数据集支持:兼容BIO、MSRA和CLUENER三种主流格式,自动识别并合并不同领域数据。

  2. 智能数据处理:自动生成词汇表和标签表,支持序列填充/截断,自适应调整批次大小和训练轮数。

  3. 模型架构:采用嵌入层+全连接层的BP神经网络,包含ReLU激活和Dropout正则化,适配序列标注任务。

  4. 可视化功能:提供训练曲线、混淆矩阵、实体标注等多种可视化图表,支持中文显示优化。

  5. 完整流程:包含数据加载、模型训练、评估预测全流程,在测试集上取得100%的F1值。

系统通过示例验证了在教育、地点、机构、医疗等领域的实体识别能力,所有可视化结果自动保存,便于分析。

相关推荐
Sui_Network2 小时前
Sui 2025 年终回顾:支付、BTC 与机构采用篇
大数据·人工智能·物联网·web3·去中心化·区块链
心态特好2 小时前
DenseNet-121 深度解析
人工智能
zstar-_2 小时前
FreeTool增加了四个新工具,并新增国内镜像站点
人工智能
七夜zippoe2 小时前
Python元类编程-动态创建类的艺术
python·元类·高级编程·prepare·mro
极客BIM工作室2 小时前
AI导读AI论文: FinGPT: Open-Source Financial Large Language Models
人工智能·语言模型·自然语言处理
咕噜企业分发小米2 小时前
阿里云和华为云在人工智能领域有哪些扶持政策?
人工智能·阿里云·华为云
明如正午2 小时前
Kvaser使用Python收发报文示例
python·kvaser
q_30238195562 小时前
宇树机器人又刷第一!具身智能靠强化学习解锁直立行走与快速奔跑
人工智能·python·单片机·机器人·ai编程
IT_陈寒2 小时前
Vite 3实战:我用这5个优化技巧让HMR构建速度提升了40%
前端·人工智能·后端