BERT + CRF实现的中文 NER模型训练

这里写自定义目录标题

  • [NER 训练实现代码+路径自动加载代码](#NER 训练实现代码+路径自动加载代码)
  • [路径自动加载代码 ,文件名 config.py](#路径自动加载代码 ,文件名 config.py)

NER 训练实现代码+路径自动加载代码

python 复制代码
# BERT + CRF 中文 NER 带早停功能版本
# 基于BertCRF_NER_WITH_EVAL.py,添加早停、类别权重等改进功能
# ----------------------------------------------------
# 新增特性:
# - 早停机制(基于验证loss)
# - 类别权重损失函数(处理数据不平衡)
# - 学习率监控和自动调整
# - 最佳模型保存
# - 详细的训练进度跟踪
# - 支持继续训练(从检查点恢复)
# - 【分层采样】确保每个标签在所有数据集中都有分布
#
# ===================================================
# 依赖安装
# ===================================================
# 基础依赖(必须):
#   pip install torch transformers pytorch-crf
#
# 分层采样依赖(推荐):
#   pip install iterative-stratification
#
# GPU监控依赖(可选):
#   pip install nvidia-ml-py3
# ===================================================

import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torchcrf import CRF
from config import get_config
import datetime
import random
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from pathlib import Path

# ==================== 超参数配置 ====================
class TrainingConfig:
    """训练配置类"""
    # ============ 基础训练参数 ============

    # EPOCHS: 训练轮数,表示整个数据集被训练的次数
    # 使用建议:
    #   - 数据量小(5K以下): 10-20轮
    #   - 数据量中等(5K-50K): 5-10轮
    #   - 数据量大(50K以上): 3-5轮
    #   - 配合早停机制可适当设大(如10-15),让模型自动停止
    EPOCHS = 5

    # BATCH_SIZE: 每个批次训练的样本数量
    # 使用建议: 根据GPU显存调整
    #   - 显存 8GB:  8-16
    #   - 显存 16GB: 16-32
    #   - 显存 24GB+: 32-64
    # 注意: batch_size越大显存占用越高,但训练更稳定;越小则训练越快但可能不稳定
    BATCH_SIZE = 32

    # LEARNING_RATE: 学习率,控制参数更新步长
    # 使用建议:
    #   - BERT类模型推荐: 1e-5 到 5e-5
    #   - 微调建议从小开始: 2e-5 是较安全的值
    #   - 如果loss震荡/NaN: 降低至1e-5
    #   - 如果收敛太慢: 提高至3e-5或5e-5
    LEARNING_RATE = 2e-5

    # WEIGHT_DECAY: 权重衰减(L2正则化系数),防止过拟合
    # 使用建议:
    #   - 标准值: 0.01 (BERT论文推荐)
    #   - 数据少/易过拟合: 0.01-0.1
    #   - 数据多/欠拟合: 0.001-0.01
    #   - 设为0禁用权重衰减
    WEIGHT_DECAY = 0.01

    # WARMUP_RATIO: 预热比例,前多少比例的步数用于学习率预热
    # 使用建议:
    #   - 预热期间学习率从0线性增加到LEARNING_RATE
    #   - 标准值: 0.1-0.2 (即10%-20%步数预热)
    #   - 数据少: 可设0.1
    #   - 数据多: 可设0.15-0.2
    #   - 总预热步数 = total_steps * WARMUP_RATIO
    WARMUP_RATIO = 0.15

    # MAX_GRAD_NORM: 梯度裁剪阈值,防止梯度爆炸
    # 使用建议:
    #   - BERT标准值: 1.0-5.0
    #   - 训练不稳定(loss震荡): 降低至1.0
    #   - 梯度正常: 可提高至5.0
    #   - 设为None禁用梯度裁剪
    MAX_GRAD_NORM = 3.0

    # ============ 早停参数 ============

    # PATIENCE: 早停耐心值,验证loss连续n个epoch未改善时停止训练
    # 使用建议:
    #   - 数据少: 3-5 (避免过拟合)
    #   - 数据多: 5-10 (给模型更多机会改善)
    #   - 训练快: 可设大一点
    #   - 训练慢: 设小一点节省时间
    PATIENCE = 3

    # MIN_DELTA: 最小改善阈值,验证loss下降超过此值才算有改善
    # 使用建议:
    #   - 标准值: 0.001-0.01
    #   - 要求高精度: 0.0001-0.001
    #   - 允许小幅波动: 0.01-0.05
    #   - 设为0则任何下降都算改善
    MIN_DELTA = 0.01

    # ============ 模型保存参数 ============

    # SAVE_BEST_MODEL: 是否保存最佳模型
    # 使用建议:
    #   - True: 自动保存验证loss最低的模型到MODEL_SAVE_DIR
    #   - False: 不保存,只用于实验
    SAVE_BEST_MODEL = True

    # MODEL_SAVE_DIR: 最佳模型保存路径
    # 使用建议: 相对或绝对路径均可,目录会自动创建
    MODEL_SAVE_DIR = "./saved_models"

    # CHECKPOINT_DIR: 检查点保存路径(每个epoch保存一次)
    # 使用建议: 用于训练中断后恢复训练
    CHECKPOINT_DIR = "./checkpoints"

    # ============ 类别权重参数 ============

    # USE_CLASS_WEIGHTS: 是否使用类别权重(处理数据不平衡)
    # 使用建议:
    #   - True: 自动计算类别权重,少数类权重大,多数类权重小
    #   - False: 所有类别同等对待
    # 适用场景: 标签分布严重不均时启用(如O标签占90%,实体标签仅10%)
    # 注意: 当前CRF实现未真正应用权重,此参数仅作预留
    USE_CLASS_WEIGHTS = True

    # ============ 数据集划分参数 ============

    # TRAIN_RATIO: 训练集比例
    # 使用建议:
    #   - 标准值: 0.7 (训练70%,验证15%,测试15%)
    #   - 数据少: 0.6-0.7 (留更多验证和测试)
    #   - 数据多: 0.7-0.8 (更多训练数据)
    # 注意: 配合分层采样使用,确保每个标签在所有数据集中都有分布
    TRAIN_RATIO = 0.7

    # VAL_RATIO: 验证集比例
    # 使用建议:
    #   - 标准值: 0.15 (15%)
    #   - 测试集比例自动 = 1 - TRAIN_RATIO - VAL_RATIO
    #   - 建议验证集和测试集大小相近,以便公平评估模型
    VAL_RATIO = 0.15

    # USE_STRATIFIED_SPLIT: 是否使用分层采样
    # 使用建议:
    #   - True: 使用iterative-stratification库,确保每个标签类型在所有数据集中都有分布
    #   - False: 使用简单随机采样(可能某些标签只出现在一个数据集中)
    # 适用场景: 数据不平衡、稀有标签较少时强烈推荐启用
    # 注意: 启用需要安装库: pip install iterative-stratification
    USE_STRATIFIED_SPLIT = True

    # STRATIFIED_MIN_SAMPLES: 分层采样时,每个标签在每个数据集中至少分配的样本数
    # 使用建议:
    #   - 标准值: 1-3
    #   - 稀有标签样本少: 设为1,确保至少有1个样本
    #   - 稀有标签样本多: 可设为2-3,提高分布均匀性
    #   - 如果某个标签总样本数 < STRATIFIED_MIN_SAMPLES * 3,将按比例分配
    STRATIFIED_MIN_SAMPLES = 1

    # RANDOM_SEED: 随机种子,保证实验可复现
    # 使用建议: 保持固定值(如42)以确保每次运行结果一致
    RANDOM_SEED = 42

# ==================== GPU监控模块 ====================
try:
    import pynvml
    pynvml.nvmlInit()
    GPU_MONITORING_AVAILABLE = True
    print("✅ GPU监控功能已启用")
except ImportError:
    GPU_MONITORING_AVAILABLE = False
    print("⚠️ pynvml未安装,GPU监控功能将禁用")
    print("   安装命令: pip install pynvml")

class GPUMonitor:
    """GPU监控类"""

    def __init__(self):
        self.available = GPU_MONITORING_AVAILABLE
        self.device_count = 0
        self.gpu_handles = []

        if self.available:
            try:
                self.device_count = pynvml.nvmlDeviceGetCount()
                for i in range(self.device_count):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                    self.gpu_handles.append(handle)
            except Exception as e:
                print(f"⚠️ GPU监控初始化失败: {e}")
                self.available = False

    def get_gpu_info(self, device_id=0):
        """获取GPU信息"""
        if not self.available or device_id >= len(self.gpu_handles):
            return {
                'memory_used_mb': 0,
                'memory_total_mb': 0,
                'memory_util': 0.0,
                'gpu_util': 0.0,
                'temperature': 0,
                'power_watts': 0.0
            }

        try:
            handle = self.gpu_handles[device_id]
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            memory_used_mb = mem_info.used // 1024**2
            memory_total_mb = mem_info.total // 1024**2
            memory_util = (mem_info.used / mem_info.total) * 100

            try:
                util_rates = pynvml.nvmlDeviceGetUtilizationRates(handle)
                gpu_util = util_rates.gpu
            except:
                gpu_util = 0

            try:
                temperature = pynvml.nvmlDeviceGetTemperature(
                    handle, pynvml.NVML_TEMPERATURE_GPU)
            except:
                temperature = 0

            try:
                power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
            except:
                power = 0.0

            return {
                'memory_used_mb': memory_used_mb,
                'memory_total_mb': memory_total_mb,
                'memory_util': memory_util,
                'gpu_util': gpu_util,
                'temperature': temperature,
                'power_watts': power
            }

        except Exception as e:
            return {
                'memory_used_mb': 0,
                'memory_total_mb': 0,
                'memory_util': 0.0,
                'gpu_util': 0.0,
                'temperature': 0,
                'power_watts': 0.0,
                'error': str(e)
            }

    def format_gpu_info(self, gpu_info):
        """格式化GPU信息用于日志"""
        if gpu_info.get('error'):
            return f"GPU监控错误: {gpu_info['error']}"

        return (f"显存: {gpu_info['memory_used_mb']}/{gpu_info['memory_total_mb']}MB "
                f"({gpu_info['memory_util']:.1f}%) | "
                f"利用率: {gpu_info['gpu_util']}% | "
                f"温度: {gpu_info['temperature']}°C | "
                f"功耗: {gpu_info['power_watts']:.1f}W")

# 初始化GPU监控器
gpu_monitor = GPUMonitor()

# =====================================================
# Device配置
# =====================================================
if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    print(f"🎮 使用GPU训练: {gpu_name}")

    if gpu_monitor.available:
        gpu_info = gpu_monitor.get_gpu_info(0)
        print(f"   显存: {gpu_info['memory_total_mb']}MB")
else:
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    print(f"🎯 使用{device}训练")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

# =====================================================
# 1. 数据加载和划分
# =====================================================
config = get_config()
INPUT_JSONL = config.legacy_fallback(config.msra_train_annotated, "msra_train_annotated")
ANNOTATED_TRAIN_FILE = config.legacy_fallback(config.annotated_train, "annotated_train")

def convert_annotated_jsonl(path):
    dataset_samples = []
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    if "text" in data and "labels" in data:
                        dataset_samples.append(data)
                except json.JSONDecodeError as e:
                    print(f"⚠ 第{line_num}行JSON解析失败: {e}")
        return dataset_samples
    except FileNotFoundError:
        print(f"❌ 文件不存在: {path}")
        return []
    except Exception as e:
        print(f"❌ 读取文件失败: {e}")
        return []

# 加载数据
dataset_samples = []
if os.path.exists(ANNOTATED_TRAIN_FILE):
    print(f"📝 使用直接标注数据: {ANNOTATED_TRAIN_FILE}")
    dataset_samples = convert_annotated_jsonl(ANNOTATED_TRAIN_FILE)
elif os.path.exists(INPUT_JSONL):
    print(f"📝 使用原始MSRA数据: {INPUT_JSONL}")
    from ner_converter import convert_jsonl
    dataset_samples = convert_jsonl(INPUT_JSONL)
else:
    print("⚠ 未检测到任何数据文件,使用示例数据。")
    dataset_samples = [
        {"text": "我爱北京天安门", "labels": ["O", "O", "B-LOC", "I-LOC", "I-LOC", "I-LOC", "I-LOC"]},
        {"text": "张三来自中国上海", "labels": ["B-PER", "I-PER", "O", "O", "B-LOC", "I-LOC", "I-LOC"]},
        {"text": "李明在清华大学工作", "labels": ["B-PER", "I-PER", "O", "B-ORG", "I-ORG", "I-ORG", "I-ORG", "O"]},
    ]

print(f"📊 原始数据集大小: {len(dataset_samples)} 条样本")

def split_dataset(samples, train_ratio=0.7, val_ratio=0.15, random_seed=42, use_stratified=True):
    """
    划分数据集为训练集、验证集、测试集

    支持两种划分方式:
    1. 分层采样(use_stratified=True): 使用iterative-stratification库,
       确保每个标签类型在所有数据集中都有分布,适合数据不平衡场景
    2. 随机采样(use_stratified=False): 简单随机打乱后划分

    Args:
        samples: 样本列表,每个样本包含 text 和 labels
        train_ratio: 训练集比例 (默认0.7)
        val_ratio: 验证集比例 (默认0.15),测试集比例 = 1 - train_ratio - val_ratio
        random_seed: 随机种子 (默认42)
        use_stratified: 是否使用分层采样 (默认True,推荐)

    Returns:
        train_samples, val_samples, test_samples: 划分后的三个数据集
    """
    from collections import Counter

    if use_stratified:
        # ========== 分层采样(推荐)==========
        try:
            from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
            import numpy as np

            print(f"🎯 使用分层采样划分数据集...")

            # 1. 构建标签集合
            all_labels = sorted(list(set(l for s in samples for l in s["labels"] if l != "O")))
            print(f"📊 发现 {len(all_labels)} 种实体类型: {all_labels}")

            # 2. 构建多标签矩阵 (每行是一个样本的标签向量)
            label2id = {l: i for i, l in enumerate(all_labels)}
            y = np.zeros((len(samples), len(all_labels)), dtype=int)

            for idx, sample in enumerate(samples):
                for label in sample["labels"]:
                    if label in label2id:
                        y[idx, label2id[label]] = 1

            # 3. 统计初始标签分布
            initial_dist = Counter([l for s in samples for l in s["labels"] if l != "O"])
            print(f"📋 原始数据集标签分布: {dict(initial_dist)}")

            # 4. 使用迭代分层采样
            # 第一次划分:训练+验证 vs 测试
            indices = np.arange(len(samples))
            test_ratio = 1 - train_ratio - val_ratio

            # 计算第一个K折的折数(根据测试集比例)
            n_splits_1 = max(3, int(1 / test_ratio))  # 至少3折
            actual_test_ratio = 1.0 / n_splits_1

            print(f"\n📐 分层采样参数:")
            print(f"   配置比例 - 训练集: {train_ratio*100:.1f}%, 验证集: {val_ratio*100:.1f}%, 测试集: {test_ratio*100:.1f}%")
            print(f"   第一次划分: {n_splits_1}折交叉验证 → 实际测试集比例: {actual_test_ratio*100:.1f}%")

            # 创建分层K折,取第一折作为测试集
            mskf = MultilabelStratifiedKFold(
                n_splits=n_splits_1,
                shuffle=True,
                random_state=random_seed
            )

            # 获取第一次划分:训练验证集 + 测试集
            folds = list(mskf.split(indices, y))
            train_val_idx, test_idx = folds[0]

            # 第二次划分:从训练验证集中划分训练集和验证集
            y_train_val = y[train_val_idx]
            train_val_indices = np.arange(len(train_val_idx))

            # 计算第二次K折的折数
            train_val_ratio = 1 - actual_test_ratio
            val_ratio_from_train_val = val_ratio / train_val_ratio
            n_splits_2 = max(2, int(1 / val_ratio_from_train_val))  # 至少2折
            actual_val_ratio_total = actual_test_ratio + (train_val_ratio / n_splits_2)
            actual_train_ratio = 1 - actual_val_ratio_total
            actual_val_ratio_only = actual_val_ratio_total - actual_test_ratio

            print(f"   第二次划分: {n_splits_2}折交叉验证 → 实际验证集比例: {actual_val_ratio_only*100:.1f}%")
            print(f"   ✅ 最终划分比例 - 训练集: {actual_train_ratio*100:.1f}%, 验证集: {actual_val_ratio_only*100:.1f}%, 测试集: {actual_test_ratio*100:.1f}%")

            # 使用不同的随机种子进行第二次划分
            mskf2 = MultilabelStratifiedKFold(
                n_splits=n_splits_2,
                shuffle=True,
                random_state=random_seed + 1
            )

            folds2 = list(mskf2.split(train_val_indices, y_train_val))
            train_idx, val_idx = folds2[0]

            # 转换回原始索引
            train_indices = train_val_idx[train_idx]
            val_indices = train_val_idx[val_idx]

            # 5. 构建最终数据集
            train_samples = [samples[i] for i in train_indices]
            val_samples = [samples[i] for i in val_indices]
            test_samples = [samples[i] for i in test_idx]

            # 6. 验证标签分布
            def get_label_dist(samples_subset):
                dist = Counter([l for s in samples_subset for l in s["labels"] if l != "O"])
                return dict(dist)

            train_dist = get_label_dist(train_samples)
            val_dist = get_label_dist(val_samples)
            test_dist = get_label_dist(test_samples)

            print(f"\n✅ 分层采样结果:")
            print(f"   训练集标签分布: {train_dist}")
            print(f"   验证集标签分布: {val_dist}")
            print(f"   测试集标签分布: {test_dist}")

            # 检查是否有标签缺失
            missing_in_train = [l for l in all_labels if l not in train_dist]
            missing_in_val = [l for l in all_labels if l not in val_dist]
            missing_in_test = [l for l in all_labels if l not in test_dist]

            if missing_in_train:
                print(f"⚠️  警告: 以下标签在训练集中缺失: {missing_in_train}")
            if missing_in_val:
                print(f"⚠️  警告: 以下标签在验证集中缺失: {missing_in_val}")
            if missing_in_test:
                print(f"⚠️  警告: 以下标签在测试集中缺失: {missing_in_test}")

            if not (missing_in_train or missing_in_val or missing_in_test):
                print(f"✅ 所有标签在三个数据集中都有分布!")

            return train_samples, val_samples, test_samples

        except ImportError:
            print("⚠️  警告: 未安装 iterative-stratification 库")
            print("💡 请运行: pip install iterative-stratification")
            print("🔄 将回退到简单随机采样...")
            use_stratified = False

    if not use_stratified:
        # ========== 简单随机采样(备用)==========
        print(f"🎲 使用简单随机采样划分数据集...")

        random.seed(random_seed)
        random.shuffle(samples)

        # 计算划分点
        n = len(samples)
        train_end = int(n * train_ratio)
        val_end = int(n * (train_ratio + val_ratio))

        train_samples = samples[:train_end]
        val_samples = samples[train_end:val_end]
        test_samples = samples[val_end:]

        # 统计标签分布
        def get_label_dist(samples_subset):
            dist = Counter([l for s in samples_subset for l in s["labels"] if l != "O"])
            return dict(dist)

        train_dist = get_label_dist(train_samples)
        val_dist = get_label_dist(val_samples)
        test_dist = get_label_dist(test_samples)

        print(f"\n✅ 随机采样结果:")
        print(f"   训练集标签分布: {train_dist}")
        print(f"   验证集标签分布: {val_dist}")
        print(f"   测试集标签分布: {test_dist}")

        return train_samples, val_samples, test_samples

# 划分数据集:训练/验证/测试 = 70%/15%/15%
train_samples, val_samples, test_samples = split_dataset(
    dataset_samples,
    train_ratio=TrainingConfig.TRAIN_RATIO,
    val_ratio=TrainingConfig.VAL_RATIO,
    random_seed=TrainingConfig.RANDOM_SEED,
    use_stratified=TrainingConfig.USE_STRATIFIED_SPLIT
)

print(f"📊 训练集大小: {len(train_samples)} 条样本")
print(f"📊 验证集大小: {len(val_samples)} 条样本")
print(f"📊 测试集大小: {len(test_samples)} 条样本")

# 构建标签映射(从所有数据集中收集标签,避免测试集出现未知标签)
all_labels = set()
for samples in [train_samples, val_samples, test_samples]:
    all_labels.update(l for s in samples for l in s["labels"])

label_list = ["O"] + sorted(l for l in all_labels if l != "O")
label2id = {lab: i for i, lab in enumerate(label_list)}
id2label = {v: k for k, v in label2id.items()}

print(f"📊 标签类别数: {len(label_list)}")
print(f"📋 标签列表: {label_list}")

# 计算类别权重
def calculate_class_weights(train_samples, label2id, device):
    """计算类别权重,用于处理数据不平衡"""
    from collections import Counter

    label_counts = Counter()
    for sample in train_samples:
        for label in sample["labels"]:
            label_counts[label] += 1

    total_labels = sum(label_counts.values())
    num_classes = len(label2id)

    # 计算权重(倒数频率,归一化)
    class_counts = np.array([label_counts.get(l, 0) for l in label_list])

    # 避免除零错误
    class_counts = np.maximum(class_counts, 1)

    # 计算权重:权重与类别频率成反比
    weights = total_labels / (num_classes * class_counts)
    # 归一化权重到1附近
    weights = weights / np.mean(weights)

    print(f"📊 类别统计和权重:")
    for label, count, weight in zip(label_list, class_counts, weights):
        print(f"  {label:12}: {count:6d} 样本 | 权重: {weight:.4f}")

    return torch.FloatTensor(weights).to(device)

# 计算类别权重
class_weights = calculate_class_weights(train_samples, label2id, device) if TrainingConfig.USE_CLASS_WEIGHTS else None

# =====================================================
# 2. 数据集类
# =====================================================
class NERDataset(Dataset):
    def __init__(self, samples, max_len=128):
        self.samples = samples
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text = self.samples[idx]["text"]
        labels = self.samples[idx]["labels"]

        encoding = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
            add_special_tokens=True
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # 标签对齐
        aligned_labels = [label2id.get("O", 0)]
        chars = list(text)
        if len(labels) != len(chars):
            min_len = min(len(labels), len(chars))
            for i in range(min_len):
                aligned_labels.append(label2id[labels[i]])
        else:
            for tag in labels:
                aligned_labels.append(label2id[tag])

        aligned_labels.append(label2id.get("O", 0))

        while len(aligned_labels) < self.max_len:
            aligned_labels.append(label2id.get("O", 0))

        label_ids = torch.tensor(aligned_labels[:self.max_len], dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": label_ids
        }

# 创建tokenizer
LOCAL_MODEL_PATH = "/home/zhangmanman/.cache/huggingface/hub/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea/"
print(f"🔄 使用本地预训练模型: {LOCAL_MODEL_PATH}")
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
print(f"✅ Tokenizer加载成功,词汇表大小: {len(tokenizer)}")

train_dataset = NERDataset(train_samples, max_len=128)
val_dataset = NERDataset(val_samples, max_len=128)
test_dataset = NERDataset(test_samples, max_len=128)

# 创建DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    drop_last=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    drop_last=False
)

# =====================================================
# 3. 模型定义(支持类别权重)
# =====================================================
class BertCRFNER(nn.Module):
    def __init__(self, model_name, num_labels, class_weights=None):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

        self.class_weights = class_weights

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        emissions = self.classifier(sequence_output)

        if labels is not None:
            if self.class_weights is not None:
                # 注意: pytorch-crf不支持直接应用类别权重
                # 要实现真正的weighted CRF需要自定义CRF实现
                # 这里暂时禁用权重计算,使用标准CRF损失
                # 如需处理类别不平衡,建议使用重采样或focal loss等替代方案
                pass
            loss = -self.crf(emissions, labels, mask=attention_mask.bool(), reduction="mean")
            return loss
        else:
            return self.crf.decode(emissions, mask=attention_mask.bool())

# =====================================================
# 4. 早停类
# =====================================================
class EarlyStopping:
    """早停机制类"""

    def __init__(self, patience=3, min_delta=0.01, restore_best_weights=True):
        """
        Args:
            patience: 验证loss不下降的容忍次数
            min_delta: 认为有改善的最小变化量
            restore_best_weights: 是否恢复最佳权重
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights

        self.counter = 0
        self.best_loss = None
        self.best_model_state = None
        self.early_stop = False
        self.best_epoch = 0

    def __call__(self, current_val_loss, model, epoch):
        """
        检查是否应该早停

        Returns:
            bool: 是否应该停止训练
        """

        if self.best_loss is None:
            self.best_loss = current_val_loss
            self.best_epoch = epoch
            self.best_model_state = {
                'model_state_dict': model.state_dict(),
                'epoch': epoch
            }
            return False

        # 检查验证损失是否改善
        if current_val_loss < self.best_loss - self.min_delta:
            # 验证loss有改善
            old_best_loss = self.best_loss
            self.best_loss = current_val_loss
            self.best_epoch = epoch
            self.counter = 0
            self.best_model_state = {
                'model_state_dict': model.state_dict(),
                'epoch': epoch
            }
            print(f"✅ 验证loss改善: {current_val_loss:.4f} < {old_best_loss:.4f}")
            return False
        else:
            # 验证loss没有改善
            self.counter += 1
            print(f"⚠️ 验证loss未改善: {current_val_loss:.4f} >= {self.best_loss:.4f} "
                  f"(计数: {self.counter}/{self.patience})")

            if self.counter >= self.patience:
                self.early_stop = True
                print(f"🛑 早停触发:验证loss连续{self.patience}个epoch未改善")
                return True
            return False

    def restore_model(self, model, optimizer=None):
        """恢复最佳模型权重"""
        if self.best_model_state is None:
            return

        if self.restore_best_weights and self.best_model_state is not None:
            model.load_state_dict(self.best_model_state['model_state_dict'])
            print(f"🔄 恢复最佳模型(Epoch {self.best_epoch})")

            # 注意: 我们不恢复optimizer状态,因为这可能导致训练不稳定
            # 如果需要恢复optimizer,应该在checkpoint中单独保存

    def save_checkpoint(self, filepath, model, optimizer, epoch, train_losses, val_losses,
                        optimizer_state_dict=None):
        """保存训练检查点"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer_state_dict is not None
                                     else optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'label2id': label2id,
            'id2label': id2label,
            'best_val_loss': self.best_loss,
            'best_epoch': self.best_epoch,
            'counter': self.counter,
            'best_model_state': self.best_model_state  # 保存最佳模型状态
        }

        torch.save(checkpoint, filepath)
        print(f"💾 检查点已保存: {filepath}")

    def load_checkpoint(self, filepath, model, optimizer=None):
        """从检查点恢复训练"""
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"检查点文件不存在: {filepath}")

        checkpoint = torch.load(filepath)

        # 恢复模型状态
        model.load_state_dict(checkpoint['model_state_dict'])

        # 恢复优化器状态(如果有)
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # 恢复训练历史
        train_losses = checkpoint.get('train_losses', [])
        val_losses = checkpoint.get('val_losses', [])

        print(f"📂 从检查点恢复训练:")
        print(f"   Epoch: {checkpoint['epoch']}")
        print(f"   最佳验证loss: {checkpoint.get('best_val_loss', 'N/A')}")
        print(f"   最佳epoch: {checkpoint.get('best_epoch', 'N/A')}")

        # 恢复早停状态
        self.best_loss = checkpoint.get('best_val_loss', None)
        self.best_epoch = checkpoint.get('best_epoch', 0)
        self.counter = checkpoint.get('counter', 0)
        if self.best_loss is not None:
            self.best_model_state = checkpoint.get('best_model_state', None)

        return train_losses, val_losses

# 创建必要的目录
Path(TrainingConfig.MODEL_SAVE_DIR).mkdir(parents=True, exist_ok=True)
Path(TrainingConfig.CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)

# =====================================================
# 5. 评估函数
# =====================================================
def evaluate_model(model, data_loader, device, log_file=None):
    """评估模型性能"""
    model.eval()

    all_predictions = []
    all_true_labels = []

    print(f"\n🧪 开始评估数据集...")

    def write_log(msg):
        """评估专用的日志写入函数"""
        print(msg)
        if log_file:
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"[{timestamp}] {msg}\n")

    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            predictions = model(input_ids, attention_mask)

            for i in range(len(predictions)):
                pred_ids = predictions[i]
                true_ids = labels[i].cpu().numpy()
                mask = attention_mask[i].cpu().numpy()

                valid_length = sum(mask)
                pred_ids = pred_ids[:valid_length]
                true_ids = true_ids[:valid_length]

                if len(pred_ids) >= 2:
                    pred_ids = pred_ids[1:-1]
                if len(true_ids) >= 2:
                    true_ids = true_ids[1:-1]

                pred_labels = [id2label.get(pid, "O") for pid in pred_ids]
                true_labels = [id2label.get(tid, "O") for tid in true_ids]

                all_predictions.extend(pred_labels)
                all_true_labels.extend(true_labels)

    # 生成分类报告
    report = classification_report(all_true_labels, all_predictions, zero_division=0)

    if log_file:
        log_message("\n" + "="*60, log_file)
        log_message("📊 模型评估结果", log_file)
        log_message("="*60, log_file)
        log_message(f"数据集大小: {len(data_loader.dataset)}", log_file)
        log_message(f"预测标签总数: {len(all_predictions)}", log_file)
        log_message("\n📈 详细分类报告:", log_file)

        for line in report.split('\n'):
            log_message(line, log_file)

        log_message("="*60, log_file)

    print("\n" + report)
    return report

def log_message(message, log_file):
    """日志记录函数"""
    print(message)
    if log_file:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(f"[{timestamp}] {message}\n")

# =====================================================
# 6. 训练循环(带早停)
# =====================================================
def train_with_early_stopping(
    model, train_loader, val_loader, test_loader,
    optimizer, scheduler, device,
    config, log_file=None
):
    """带早停的训练循环"""

    # 初始化早停
    early_stopping = EarlyStopping(
        patience=config.PATIENCE,
        min_delta=config.MIN_DELTA,
        restore_best_weights=True
    )

    # 训练历史
    train_losses = []
    val_losses = []

    # 检查是否有之前的检查点
    checkpoint_path = Path(TrainingConfig.CHECKPOINT_DIR) / "latest_checkpoint.pt"
    start_epoch = 0

    if checkpoint_path.exists():
        try:
            train_losses, val_losses = early_stopping.load_checkpoint(
                checkpoint_path, model, optimizer
            )
            start_epoch = len(train_losses)
            print(f"📂 从检查点恢复训练,从epoch {start_epoch}继续")
        except Exception as e:
            print(f"⚠️ 加载检查点失败: {e}")
            print("🔄 从头开始训练")
            train_losses = []
            val_losses = []
            start_epoch = 0
    else:
        print("🚀 从头开始训练")
        train_losses = []
        val_losses = []

    # 计算总步数用于进度显示
    total_steps = config.EPOCHS * len(train_loader)
    log_message(f"🎯 开始训练,总步数: {total_steps}", log_file)

    # 训练循环
    for epoch in range(start_epoch, config.EPOCHS):
        model.train()
        total_loss = 0.0
        epoch_grad_norms = []

        # GPU监控:记录epoch开始状态
        if gpu_monitor.available and torch.cuda.is_available():
            gpu_start = gpu_monitor.get_gpu_info(0)
            log_message(f"\n🖥️ Epoch {epoch+1}/{config.EPOCHS} 开始 GPU状态: {gpu_monitor.format_gpu_info(gpu_start)}", log_file)

        # 训练一个epoch
        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()

            try:
                # 前向传播
                loss = model(input_ids, attention_mask, labels)

                # 反向传播
                loss.backward()

                # 梯度裁剪(标准实现)
                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)

                # 记录梯度范数
                epoch_grad_norms.append(total_norm.item())

                # 参数更新
                optimizer.step()
                scheduler.step()

                total_loss += loss.item()

                # 定期报告(包含GPU监控)
                if batch_idx % 20 == 0:
                    current_lr = optimizer.param_groups[0]['lr']
                    current_step = epoch * len(train_loader) + batch_idx
                    progress = (current_step / total_steps) * 100

                    message = (f"Epoch {epoch+1}/{config.EPOCHS} [{progress:.1f}%] "
                              f"Loss: {loss.item():.4f} | "
                              f"LR: {current_lr:.2e} | "
                              f"Grad: {total_norm:.3f}")

                    # 添加GPU监控信息
                    if gpu_monitor.available and torch.cuda.is_available():
                        gpu_info = gpu_monitor.get_gpu_info(0)
                        message += f" | {gpu_monitor.format_gpu_info(gpu_info)}"

                    log_message(message, log_file)

            except RuntimeError as e:
                if "out of memory" in str(e):
                    log_message(f"⚠️ OOM错误: Epoch {epoch+1}, Batch {batch_idx+1}", log_file)
                    log_message("建议:降低batch_size或使用混合精度训练", log_file)
                    torch.cuda.empty_cache()
                    continue
                else:
                    raise e

        # Epoch训练总结
        avg_train_loss = total_loss / len(train_loader)
        avg_grad_norm = sum(epoch_grad_norms) / len(epoch_grad_norms)
        train_losses.append(avg_train_loss)

        log_message(
            f"\n📊 Epoch {epoch+1}/{config.EPOCHS} 完成 | "
            f"平均损失: {avg_train_loss:.4f} | "
            f"平均梯度范数: {avg_grad_norm:.3f} | "
            f"最终学习率: {optimizer.param_groups[0]['lr']:.2e}",
            log_file
        )

        # GPU监控:记录epoch结束状态
        if gpu_monitor.available and torch.cuda.is_available():
            gpu_end = gpu_monitor.get_gpu_info(0)
            log_message(f"🖥️ Epoch {epoch+1}/{config.EPOCHS} 结束 GPU状态: {gpu_monitor.format_gpu_info(gpu_end)}\n", log_file)

        # GPU内存清理
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # 验证集评估
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                loss = model(input_ids, attention_mask, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        # 学习率
        current_lr = optimizer.param_groups[0]['lr']

        # 日志
        best_loss_str = f"{early_stopping.best_loss:.4f}" if early_stopping.best_loss is not None else "N/A"
        log_message(
            f"Epoch {epoch+1}/{config.EPOCHS} | "
            f"Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f} | "
            f"LR: {current_lr:.2e} | "
            f"Best Val Loss: {best_loss_str} | "
            f"Counter: {early_stopping.counter}/{early_stopping.patience}",
            log_file
        )

        # 保存检查点
        checkpoint_path = Path(TrainingConfig.CHECKPOINT_DIR) / f"checkpoint_epoch_{epoch+1}.pt"
        early_stopping.save_checkpoint(
            checkpoint_path,
            model, optimizer, epoch,
            train_losses, val_losses
        )

        # 检查早停
        should_stop = early_stopping(
            avg_val_loss, model, epoch
        )

        if should_stop:
            log_message(f"\n🛑 早停触发,停止训练", log_file)
            break

    # 恢复最佳模型
    early_stopping.restore_model(model, optimizer)
    log_message(f"\n🔄 已恢复到最佳模型(Epoch {early_stopping.best_epoch})", log_file)

    return train_losses, val_losses, early_stopping

# =====================================================
# 7. 主训练函数
# =====================================================
def main():
    """主函数"""

    # 日志文件
    LOG_FILE = "./STRATIFIED_TRAINING_LOG.log"

    # 初始化日志
    log_message("=" * 80, LOG_FILE)
    log_message("🚀 BERT + CRF 中文 NER 早停功能版训练开始", LOG_FILE)
    log_message("=" * 80, LOG_FILE)

    # =====================================================
    # 模型初始化
    # =====================================================
    log_message("📋 训练配置:", LOG_FILE)
    log_message(f"  - 模型: BERT-base-chinese + CRF", LOG_FILE)
    log_message(f"  - 训练集大小: {len(train_dataset)} 样本", LOG_FILE)
    log_message(f"  - 验证集大小: {len(val_dataset)} 样本", LOG_FILE)
    log_message(f"  - 测试集大小: {len(test_dataset)} 样本", LOG_FILE)
    log_message(f"  - Batch Size: {TrainingConfig.BATCH_SIZE}", LOG_FILE)
    log_message(f"  - Epochs: {TrainingConfig.EPOCHS}", LOG_FILE)
    log_message(f"  - 学习率: {TrainingConfig.LEARNING_RATE:.0e}", LOG_FILE)
    log_message(f"  - 梯度裁剪: {TrainingConfig.MAX_GRAD_NORM}", LOG_FILE)
    log_message(f"  - 预热步数: {TrainingConfig.WARMUP_RATIO*100:.0f}%", LOG_FILE)
    log_message(f"  - 标签类别数: {len(label2id)}", LOG_FILE)
    log_message(f"  - 类别权重: {TrainingConfig.USE_CLASS_WEIGHTS}", LOG_FILE)
    log_message(f"  - 早停patience: {TrainingConfig.PATIENCE}", LOG_FILE)
    log_message(f"  - 最小改善: {TrainingConfig.MIN_DELTA}", LOG_FILE)

    if gpu_monitor.available:
        log_message("🔧 GPU监控: ✅ 启用", LOG_FILE)
        gpu_info = gpu_monitor.get_gpu_info(0)
        log_message(f"   GPU信息: {gpu_name} ({gpu_info['memory_total_mb']}MB显存)", LOG_FILE)
    else:
        log_message("🔧 GPU监控: ❌ 禁用 (pynvml未安装)", LOG_FILE)

    # =====================================================
    # 创建模型
    # =====================================================
    log_message("\n🔄 初始化模型...", LOG_FILE)
    model = BertCRFNER(
        LOCAL_MODEL_PATH,
        num_labels=len(label2id),
        class_weights=class_weights
    )
    model = model.to(device)
    log_message(f"✅ 模型初始化完成", LOG_FILE)

    # =====================================================
    # 优化器和调度器
    # =====================================================
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=TrainingConfig.LEARNING_RATE,
        weight_decay=TrainingConfig.WEIGHT_DECAY,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    total_steps = len(train_loader) * TrainingConfig.EPOCHS
    warmup_steps = int(total_steps * TrainingConfig.WARMUP_RATIO)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # =====================================================
    # 开始训练
    # =====================================================
    log_message(f"\n🔄 准备开始训练...", LOG_FILE)
    log_message(f"📋 预热步数: {warmup_steps}", LOG_FILE)

    train_losses, val_losses, early_stopping = train_with_early_stopping(
        model, train_loader, val_loader, test_loader,
        optimizer, scheduler, device,
        TrainingConfig, LOG_FILE
    )

    # =====================================================
    # 最终测试集评估
    # =====================================================
    log_message("\n🎉 训练完成!开始最终评估...", LOG_FILE)

    # 加载最佳模型(如果未恢复)
    if not early_stopping.early_stop:
        # 训练正常结束,使用最后一个epoch的模型
        pass
    else:
        # 早停触发,已经恢复了最佳模型
        pass

    evaluate_model(model, test_loader, device, LOG_FILE)

    # =====================================================
    # 保存最佳模型
    # =====================================================
    if TrainingConfig.SAVE_BEST_MODEL:
        best_model_path = Path(TrainingConfig.MODEL_SAVE_DIR) / "best_bert_crf_ner.pt"
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'label2id': label2id,
            'id2label': id2label,
            'config': {
                'model_name': LOCAL_MODEL_PATH,
                'num_labels': len(label2id),
                'max_len': 128
            },
            'training_info': {
                'best_epoch': early_stopping.best_epoch,
                'best_val_loss': early_stopping.best_loss,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'class_weights': class_weights
            },
            'early_stopping': {
                'patience': TrainingConfig.PATIENCE,
                'min_delta': TrainingConfig.MIN_DELTA,
                'counter': early_stopping.counter
            }
        }, best_model_path)
        log_message(f"💾 最佳模型已保存到: {best_model_path}", LOG_FILE)

    # =====================================================
    # 训练总结
    # =====================================================
    log_message("\n" + "=" * 80, LOG_FILE)
    log_message("🎉 训练和评估完成总结", LOG_FILE)
    log_message("=" * 80, LOG_FILE)

    # 训练损失统计
    if len(train_losses) > 1:
        loss_change = train_losses[-1] - train_losses[0]
        log_message(f"📈 训练损失变化: {train_losses[0]:.4f} → {train_losses[-1]:.4f} ({loss_change:+.4f})", LOG_FILE)

    # 验证损失统计
    if len(val_losses) > 0:
        log_message(f"📊 验证损失变化: {val_losses[0]:.4f} → {val_losses[-1]:.4f}", LOG_FILE)
        log_message(f"🏆 最佳验证损失: {early_stopping.best_loss:.4f} (Epoch {early_stopping.best_epoch+1})", LOG_FILE)

    # 早停信息
    log_message(f"🛑 早停状态: {'✅ 已触发' if early_stopping.early_stop else '❌ 未触发'}", LOG_FILE)
    log_message(f"   最佳epoch: {early_stopping.best_epoch+1}", LOG_FILE)
    log_message(f"   早停计数: {early_stopping.counter}/{TrainingConfig.PATIENCE}", LOG_FILE)

    # 学习率信息
    log_message("🔧 最终学习率配置:", LOG_FILE)
    log_message(f"   初始学习率: {TrainingConfig.LEARNING_RATE:.0e}", LOG_FILE)
    log_message(f"   最终学习率: {optimizer.param_groups[0]['lr']:.2e}", LOG_FILE)

    # 类别权重
    if TrainingConfig.USE_CLASS_WEIGHTS and class_weights is not None:
        log_message("⚖️ 类别权重:", LOG_FILE)
        weights = class_weights.cpu().numpy()
        for label, weight in zip(label_list, weights):
            log_message(f"   {label:12}: {weight:.4f}", LOG_FILE)

    log_message("✅ 训练流程完成", LOG_FILE)
    log_message(f"📝 完整日志: {LOG_FILE}", LOG_FILE)
    log_message("=" * 80, LOG_FILE)
    print("\n🎉 早停版训练完成!")

if __name__ == "__main__":
    main()

路径自动加载代码 ,文件名 config.py

python 复制代码
"""
跨平台路径配置文件
支持 Windows、macOS、Linux 等不同操作系统环境
"""
import os
from pathlib import Path
from typing import Optional


class NERConfig:
    """NER项目跨平台路径配置类"""

    def __init__(self, project_root: Optional[Path] = None):
        """
        初始化配置

        Args:
            project_root: 项目根目录,如果为None则自动检测
        """
        # 自动检测项目根目录
        if project_root is None:
            self.project_root = Path(__file__).parent.absolute()
        else:
            self.project_root = project_root.absolute()

        # 设置基础目录
        self._setup_directories()

        # 设置数据文件路径
        self._setup_data_paths()

        # 设置输出路径
        self._setup_output_paths()

        # 向后兼容:支持旧版硬编码路径
        self._setup_legacy_paths()

    def _setup_directories(self):
        """设置基础目录结构"""
        self.data_dir = self.project_root / "data"
        self.raw_data_dir = self.data_dir / "raw"
        self.processed_data_dir = self.data_dir / "processed"
        self.models_dir = self.data_dir / "models"
        self.logs_dir = self.project_root / "logs"
        self.outputs_dir = self.project_root / "outputs"

    def _setup_data_paths(self):
        """设置数据文件路径"""
        # 原始训练数据
        self.msra_train_raw = self.raw_data_dir / "msra_train.txt"
        self.msra_train_annotated = self.raw_data_dir / "msra_train_annotated.txt"

        # 处理后的数据
        self.annotated_train = self.processed_data_dir / "annotated_train.txt"
        self.msra_train_bio = self.processed_data_dir / "msra_train_bio.txt"

        # 词典文件
        self.jieba_dict = self.project_root / "jieba_dict.txt"

        # 模型文件
        self.bert_model_dir = self.models_dir / "bert_ner"
        self.checkpoint_file = self.models_dir / "checkpoint.json"

    def _setup_output_paths(self):
        """设置输出路径"""
        self.batch_output = self.outputs_dir / "batch_output.txt"
        self.final_output = self.outputs_dir / "final_output.txt"
        self.failed_lines_log = self.logs_dir / "failed_lines.log"
        self.error_log = self.logs_dir / "ERROR.log"

    def _setup_legacy_paths(self):
        """设置向后兼容的旧版路径支持"""
        # 环境变量支持,允许指定外部数据目录
        legacy_data_dir = os.environ.get('NER_LEGACY_DATA_DIR')
        if legacy_data_dir:
            legacy_path = Path(legacy_data_dir)
            self.msra_train_raw_legacy = legacy_path / "msra_train.txt"
            self.msra_train_annotated_legacy = legacy_path / "msra_train_annotated.txt"
            self.annotated_train_legacy = legacy_path / "annotated_train.txt"
            self.batch_output_legacy = legacy_path / "batch_output.txt"
            self.failed_lines_log_legacy = legacy_path / "failed_lines.log"

    def ensure_directories(self):
        """确保所有必要的目录存在"""
        directories = [
            self.data_dir,
            self.raw_data_dir,
            self.processed_data_dir,
            self.models_dir,
            self.logs_dir,
            self.outputs_dir
        ]

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)
            print(f"✓ 确保目录存在: {directory}")

    def get_path(self, path_name: str, use_legacy: bool = False) -> Path:
        """
        获取指定路径

        Args:
            path_name: 路径名称
            use_legacy: 是否使用旧版路径

        Returns:
            Path对象
        """
        if use_legacy and hasattr(self, f"{path_name}_legacy"):
            return getattr(self, f"{path_name}_legacy")
        return getattr(self, path_name)

    def legacy_fallback(self, primary_path: Path, legacy_name: str) -> Path:
        """
        智能路径回退:优先使用新路径,如果不存在且有旧版路径则使用旧版

        Args:
            primary_path: 主要路径
            legacy_name: 旧版路径属性名

        Returns:
            实际使用的路径
        """
        if primary_path.exists():
            return primary_path

        legacy_path = getattr(self, f"{legacy_name}_legacy", None)
        if legacy_path and legacy_path.exists():
            print(f"⚠️  使用旧版路径: {legacy_path}")
            return legacy_path

        return primary_path

    def get_input_data_file(self, filename: str) -> Path:
        """
        智能获取输入数据文件
        优先级:processed目录 -> raw目录 -> 旧版路径

        Args:
            filename: 文件名

        Returns:
            文件路径
        """
        # 优先在processed目录查找
        processed_path = self.processed_data_dir / filename
        if processed_path.exists():
            return processed_path

        # 然后在raw目录查找
        raw_path = self.raw_data_dir / filename
        if raw_path.exists():
            return raw_path

        # 最后尝试旧版路径
        legacy_path = getattr(self, f"{filename.split('.')[0]}_legacy", None)
        if legacy_path and legacy_path.exists():
            return legacy_path

        # 如果都不存在,返回processed目录的路径(用于创建新文件)
        return processed_path

    def to_dict(self) -> dict:
        """将配置转换为字典格式"""
        return {
            'project_root': str(self.project_root),
            'data_dir': str(self.data_dir),
            'raw_data_dir': str(self.raw_data_dir),
            'processed_data_dir': str(self.processed_data_dir),
            'models_dir': str(self.models_dir),
            'logs_dir': str(self.logs_dir),
            'outputs_dir': str(self.outputs_dir),
        }

    def print_config(self):
        """打印当前配置信息"""
        print("=" * 50)
        print("NER项目跨平台配置")
        print("=" * 50)
        print(f"项目根目录: {self.project_root}")
        print(f"操作系统: {os.name}")
        print(f"Python版本: {os.sys.version}")
        print("")
        print("目录结构:")
        print(f"  数据目录: {self.data_dir}")
        print(f"    原始数据: {self.raw_data_dir}")
        print(f"    处理数据: {self.processed_data_dir}")
        print(f"    模型文件: {self.models_dir}")
        print(f"  日志目录: {self.logs_dir}")
        print(f"  输出目录: {self.outputs_dir}")
        print("")

        # 检查关键文件
        key_files = [
            ("原始训练数据", self.msra_train_raw),
            ("标注数据", self.annotated_train),
            ("词典文件", self.jieba_dict),
        ]

        print("关键文件状态:")
        for name, path in key_files:
            status = "✓" if path.exists() else "✗"
            print(f"  {status} {name}: {path}")
        print("=" * 50)


# 创建全局配置实例
config = NERConfig()


def get_config() -> NERConfig:
    """获取全局配置实例"""
    return config


def init_project_directories():
    """初始化项目目录结构"""
    config.ensure_directories()
    print("🚀 项目目录结构初始化完成")


if __name__ == "__main__":
    # 测试配置
    config.print_config()
    init_project_directories()
相关推荐
诸葛务农2 小时前
神经网络信息编码技术:与人脑信息处理的差距及超越的替在优势和可能(上)
人工智能·深度学习·神经网络
oscar9992 小时前
神经网络前向传播:AI的“消化系统”全解析
人工智能·深度学习·神经网络
深蓝海拓2 小时前
PySide6从0开始学习的笔记(十六) 定时器QTimer
笔记·python·qt·学习·pyqt
元智启2 小时前
企业AI智能体:架构升级与生态跃迁,2025进入“智能体驱动”新阶段
人工智能·架构
合方圆~小文2 小时前
双目摄像头在不同距离精度差异
数据库·人工智能·模块测试
lxmyzzs2 小时前
【硬核部署】在 RK3588上部署毫秒级音频分类算法
人工智能·分类·音视频
阿杰学AI2 小时前
AI核心知识66——大语言模型之Machine Learning (简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·ml·机械学习
Macbethad2 小时前
智能硬件产品系统技术报告
大数据·人工智能
这张生成的图像能检测吗2 小时前
(论文速读)基于M-LLM的高效视频理解视频帧选择
人工智能·贪心算法·视频生成·多模态大语言模型