# BERT + CRF Chinese NER, version with early stopping
# Based on BertCRF_NER_WITH_EVAL.py, adding early stopping, class weights, and other improvements
# ----------------------------------------------------
# New features:
# - Early stopping (based on validation loss)
# - Class-weighted loss (to handle data imbalance)
# - Learning-rate monitoring and automatic adjustment
# - Best-model saving
# - Detailed training progress tracking
# - Resumable training (restore from checkpoint)
# - [Stratified sampling] ensures every label is represented in all splits
#
# ===================================================
# Dependencies
# ===================================================
# Core dependencies (required):
# pip install torch transformers pytorch-crf
#
# Stratified sampling dependency (recommended):
# pip install iterative-stratification
#
# GPU monitoring dependency (optional):
# pip install nvidia-ml-py3
# ===================================================
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torchcrf import CRF
from config import get_config
import datetime
import random
from sklearn.metrics import classification_report
import numpy as np
from pathlib import Path
# ==================== Hyperparameter configuration ====================
class TrainingConfig:
    """Training configuration."""
    # ============ Basic training parameters ============
    # EPOCHS: number of passes over the full training set.
    # Guidelines:
    #   - Small datasets (<5K samples): 10-20 epochs
    #   - Medium datasets (5K-50K): 5-10 epochs
    #   - Large datasets (>50K): 3-5 epochs
    #   - With early stopping enabled, set it higher (e.g. 10-15) and let training stop itself
    EPOCHS = 5
    # BATCH_SIZE: number of samples per training batch.
    # Guidelines: tune to GPU memory
    #   - 8GB VRAM: 8-16
    #   - 16GB VRAM: 16-32
    #   - 24GB+ VRAM: 32-64
    # Note: larger batches use more memory but train more stably; smaller batches
    # step faster but can be noisier.
    BATCH_SIZE = 32
    # LEARNING_RATE: step size for parameter updates.
    # Guidelines:
    #   - Recommended range for BERT-style models: 1e-5 to 5e-5
    #   - Start small when fine-tuning: 2e-5 is a safe default
    #   - If the loss oscillates or goes NaN: lower to 1e-5
    #   - If convergence is too slow: raise to 3e-5 or 5e-5
    LEARNING_RATE = 2e-5
    # WEIGHT_DECAY: L2 regularization coefficient, guards against overfitting.
    # Guidelines:
    #   - Standard value: 0.01 (recommended in the BERT paper)
    #   - Little data / prone to overfitting: 0.01-0.1
    #   - Lots of data / underfitting: 0.001-0.01
    #   - Set to 0 to disable weight decay
    WEIGHT_DECAY = 0.01
    # WARMUP_RATIO: fraction of steps used for learning-rate warmup.
    # Guidelines:
    #   - During warmup the LR rises linearly from 0 to LEARNING_RATE
    #   - Standard value: 0.1-0.2 (i.e. warm up over 10%-20% of steps)
    #   - Little data: 0.1
    #   - Lots of data: 0.15-0.2
    #   - Total warmup steps = total_steps * WARMUP_RATIO
    WARMUP_RATIO = 0.15
    # MAX_GRAD_NORM: gradient-clipping threshold, guards against exploding gradients.
    # Guidelines:
    #   - Standard range for BERT: 1.0-5.0
    #   - Unstable training (oscillating loss): lower to 1.0
    #   - Healthy gradients: can raise to 5.0
    #   - Set to None to disable clipping
    MAX_GRAD_NORM = 3.0
    # ============ Early-stopping parameters ============
    # PATIENCE: stop training after validation loss fails to improve for n consecutive epochs.
    # Guidelines:
    #   - Little data: 3-5 (avoid overfitting)
    #   - Lots of data: 5-10 (give the model more chances to improve)
    #   - Fast training: can be larger
    #   - Slow training: keep it small to save time
    PATIENCE = 3
    # MIN_DELTA: minimum improvement threshold; validation loss must drop by more
    # than this to count as an improvement.
    # Guidelines:
    #   - Standard value: 0.001-0.01
    #   - High-precision requirements: 0.0001-0.001
    #   - Tolerate small fluctuations: 0.01-0.05
    #   - Set to 0 to count any decrease as an improvement
    MIN_DELTA = 0.01
    # ============ Model-saving parameters ============
    # SAVE_BEST_MODEL: whether to save the best model.
    #   - True: automatically save the model with the lowest validation loss to MODEL_SAVE_DIR
    #   - False: don't save; for experiments only
    SAVE_BEST_MODEL = True
    # MODEL_SAVE_DIR: where the best model is saved.
    # Relative or absolute paths both work; the directory is created automatically.
    MODEL_SAVE_DIR = "./saved_models"
    # CHECKPOINT_DIR: where checkpoints are saved (once per epoch).
    # Used to resume training after an interruption.
    CHECKPOINT_DIR = "./checkpoints"
    # ============ Class-weight parameters ============
    # USE_CLASS_WEIGHTS: whether to use class weights (for data imbalance).
    #   - True: weights are computed automatically; rare classes get large weights
    #   - False: all classes treated equally
    # Use when the label distribution is badly skewed (e.g. O covers 90% of tags,
    # entity tags only 10%).
    # Note: the current CRF implementation does not actually apply these weights;
    # this flag is a placeholder.
    USE_CLASS_WEIGHTS = True
    # ============ Dataset-split parameters ============
    # TRAIN_RATIO: fraction of data used for training.
    # Guidelines:
    #   - Standard: 0.7 (70% train, 15% val, 15% test)
    #   - Little data: 0.6-0.7 (keep more for validation and testing)
    #   - Lots of data: 0.7-0.8 (more training data)
    # Note: combine with stratified sampling so every label appears in every split.
    TRAIN_RATIO = 0.7
    # VAL_RATIO: fraction of data used for validation.
    # Guidelines:
    #   - Standard: 0.15
    #   - Test fraction is implicitly 1 - TRAIN_RATIO - VAL_RATIO
    #   - Keep validation and test sets of similar size for a fair evaluation
    VAL_RATIO = 0.15
    # USE_STRATIFIED_SPLIT: whether to use stratified sampling.
    #   - True: use the iterative-stratification library so every label type is
    #     represented in all splits
    #   - False: plain random sampling (some labels may land in only one split)
    # Strongly recommended for imbalanced data or rare labels.
    # Requires: pip install iterative-stratification
    USE_STRATIFIED_SPLIT = True
    # STRATIFIED_MIN_SAMPLES: minimum samples per label per split under stratification.
    # Guidelines:
    #   - Standard: 1-3
    #   - Very rare labels: set to 1 to guarantee at least one sample
    #   - More plentiful rare labels: 2-3 for a more even distribution
    #   - If a label has fewer than STRATIFIED_MIN_SAMPLES * 3 samples in total,
    #     it is distributed proportionally instead
    # Note: currently informational only — split_dataset below does not consult this value.
    STRATIFIED_MIN_SAMPLES = 1
    # RANDOM_SEED: random seed for reproducibility.
    # Keep a fixed value (e.g. 42) so repeated runs give identical results.
    RANDOM_SEED = 42
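# Note (added): RANDOM_SEED above is only applied to the dataset split below; model
# initialization and DataLoader shuffling are not seeded. A minimal helper to seed
# all the relevant RNGs, if full reproducibility is wanted — a sketch, not part of
# the original pipeline:
def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for (mostly) reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# To use it, call set_seed(TrainingConfig.RANDOM_SEED) before building the model.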
# ==================== GPU monitoring module ====================
try:
    import pynvml
    pynvml.nvmlInit()
    GPU_MONITORING_AVAILABLE = True
    print("✅ GPU monitoring enabled")
except ImportError:
    GPU_MONITORING_AVAILABLE = False
    print("⚠️ pynvml is not installed; GPU monitoring disabled")
    print("   Install with: pip install pynvml")
class GPUMonitor:
    """Wrapper around pynvml for GPU stats."""
    def __init__(self):
        self.available = GPU_MONITORING_AVAILABLE
        self.device_count = 0
        self.gpu_handles = []
        if self.available:
            try:
                self.device_count = pynvml.nvmlDeviceGetCount()
                for i in range(self.device_count):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                    self.gpu_handles.append(handle)
            except Exception as e:
                print(f"⚠️ GPU monitor initialization failed: {e}")
                self.available = False
    def get_gpu_info(self, device_id=0):
        """Return memory, utilization, temperature, and power stats for one GPU."""
        if not self.available or device_id >= len(self.gpu_handles):
            return {
                'memory_used_mb': 0,
                'memory_total_mb': 0,
                'memory_util': 0.0,
                'gpu_util': 0.0,
                'temperature': 0,
                'power_watts': 0.0
            }
        try:
            handle = self.gpu_handles[device_id]
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            memory_used_mb = mem_info.used // 1024**2
            memory_total_mb = mem_info.total // 1024**2
            memory_util = (mem_info.used / mem_info.total) * 100
            try:
                util_rates = pynvml.nvmlDeviceGetUtilizationRates(handle)
                gpu_util = util_rates.gpu
            except Exception:
                gpu_util = 0
            try:
                temperature = pynvml.nvmlDeviceGetTemperature(
                    handle, pynvml.NVML_TEMPERATURE_GPU)
            except Exception:
                temperature = 0
            try:
                power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
            except Exception:
                power = 0.0
            return {
                'memory_used_mb': memory_used_mb,
                'memory_total_mb': memory_total_mb,
                'memory_util': memory_util,
                'gpu_util': gpu_util,
                'temperature': temperature,
                'power_watts': power
            }
        except Exception as e:
            return {
                'memory_used_mb': 0,
                'memory_total_mb': 0,
                'memory_util': 0.0,
                'gpu_util': 0.0,
                'temperature': 0,
                'power_watts': 0.0,
                'error': str(e)
            }
    def format_gpu_info(self, gpu_info):
        """Format GPU stats for logging."""
        if gpu_info.get('error'):
            return f"GPU monitoring error: {gpu_info['error']}"
        return (f"VRAM: {gpu_info['memory_used_mb']}/{gpu_info['memory_total_mb']}MB "
                f"({gpu_info['memory_util']:.1f}%) | "
                f"Util: {gpu_info['gpu_util']}% | "
                f"Temp: {gpu_info['temperature']}°C | "
                f"Power: {gpu_info['power_watts']:.1f}W")
# Initialize the GPU monitor
gpu_monitor = GPUMonitor()
# =====================================================
# Device configuration
# =====================================================
if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    print(f"🎮 Training on GPU: {gpu_name}")
    if gpu_monitor.available:
        gpu_info = gpu_monitor.get_gpu_info(0)
        print(f"   VRAM: {gpu_info['memory_total_mb']}MB")
else:
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    print(f"🎯 Training on {device}")
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# =====================================================
# 1. Data loading and splitting
# =====================================================
config = get_config()
INPUT_JSONL = config.legacy_fallback(config.msra_train_annotated, "msra_train_annotated")
ANNOTATED_TRAIN_FILE = config.legacy_fallback(config.annotated_train, "annotated_train")
def convert_annotated_jsonl(path):
    dataset_samples = []
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    if "text" in data and "labels" in data:
                        dataset_samples.append(data)
                except json.JSONDecodeError as e:
                    print(f"⚠ JSON parse error on line {line_num}: {e}")
        return dataset_samples
    except FileNotFoundError:
        print(f"❌ File not found: {path}")
        return []
    except Exception as e:
        print(f"❌ Failed to read file: {e}")
        return []
# Load the data
dataset_samples = []
if os.path.exists(ANNOTATED_TRAIN_FILE):
    print(f"📝 Using directly annotated data: {ANNOTATED_TRAIN_FILE}")
    dataset_samples = convert_annotated_jsonl(ANNOTATED_TRAIN_FILE)
elif os.path.exists(INPUT_JSONL):
    print(f"📝 Using raw MSRA data: {INPUT_JSONL}")
    from ner_converter import convert_jsonl
    dataset_samples = convert_jsonl(INPUT_JSONL)
else:
    print("⚠ No data file found; falling back to built-in sample data.")
    # Labels are per character; each list must match the character count of its text.
    dataset_samples = [
        {"text": "我爱北京天安门", "labels": ["O", "O", "B-LOC", "I-LOC", "I-LOC", "I-LOC", "I-LOC"]},
        {"text": "张三来自中国上海", "labels": ["B-PER", "I-PER", "O", "O", "B-LOC", "I-LOC", "B-LOC", "I-LOC"]},
        {"text": "李明在清华大学工作", "labels": ["B-PER", "I-PER", "O", "B-ORG", "I-ORG", "I-ORG", "I-ORG", "O", "O"]},
    ]
print(f"📊 Raw dataset size: {len(dataset_samples)} samples")
def split_dataset(samples, train_ratio=0.7, val_ratio=0.15, random_seed=42, use_stratified=True):
    """
    Split the dataset into train, validation, and test sets.
    Two strategies are supported:
    1. Stratified sampling (use_stratified=True): uses the iterative-stratification
       library so that every label type is represented in all three splits;
       suited to imbalanced data.
    2. Random sampling (use_stratified=False): shuffle and slice.
    Args:
        samples: list of samples, each with text and labels
        train_ratio: training fraction (default 0.7)
        val_ratio: validation fraction (default 0.15); test fraction = 1 - train_ratio - val_ratio
        random_seed: random seed (default 42)
        use_stratified: whether to use stratified sampling (default True, recommended)
    Returns:
        train_samples, val_samples, test_samples: the three splits
    """
    from collections import Counter
    if use_stratified:
        # ========== Stratified sampling (recommended) ==========
        try:
            from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
            print("🎯 Splitting dataset with stratified sampling...")
            # 1. Collect the label set
            all_labels = sorted(set(l for s in samples for l in s["labels"] if l != "O"))
            print(f"📊 Found {len(all_labels)} entity label types: {all_labels}")
            # 2. Build the multilabel matrix (one label vector per sample)
            label2id = {l: i for i, l in enumerate(all_labels)}
            y = np.zeros((len(samples), len(all_labels)), dtype=int)
            for idx, sample in enumerate(samples):
                for label in sample["labels"]:
                    if label in label2id:
                        y[idx, label2id[label]] = 1
            # 3. Report the initial label distribution
            initial_dist = Counter(l for s in samples for l in s["labels"] if l != "O")
            print(f"📋 Label distribution in the full dataset: {dict(initial_dist)}")
            # 4. Iterative stratified sampling
            # First split: train+val vs. test
            indices = np.arange(len(samples))
            test_ratio = 1 - train_ratio - val_ratio
            # Number of folds for the first split (derived from the test fraction)
            n_splits_1 = max(3, int(1 / test_ratio))  # at least 3 folds
            actual_test_ratio = 1.0 / n_splits_1
            print("\n📐 Stratified split parameters:")
            print(f"   Configured ratios - train: {train_ratio*100:.1f}%, val: {val_ratio*100:.1f}%, test: {test_ratio*100:.1f}%")
            print(f"   First split: {n_splits_1}-fold → effective test fraction: {actual_test_ratio*100:.1f}%")
            # Stratified K-fold; the first fold's holdout becomes the test set
            mskf = MultilabelStratifiedKFold(
                n_splits=n_splits_1,
                shuffle=True,
                random_state=random_seed
            )
            # First split: train+val indices vs. test indices
            folds = list(mskf.split(indices, y))
            train_val_idx, test_idx = folds[0]
            # Second split: carve the validation set out of train+val
            y_train_val = y[train_val_idx]
            train_val_indices = np.arange(len(train_val_idx))
            # Number of folds for the second split
            train_val_ratio = 1 - actual_test_ratio
            val_ratio_from_train_val = val_ratio / train_val_ratio
            n_splits_2 = max(2, int(1 / val_ratio_from_train_val))  # at least 2 folds
            actual_val_ratio = train_val_ratio / n_splits_2
            actual_train_ratio = train_val_ratio - actual_val_ratio
            print(f"   Second split: {n_splits_2}-fold → effective validation fraction: {actual_val_ratio*100:.1f}%")
            print(f"   ✅ Effective ratios - train: {actual_train_ratio*100:.1f}%, val: {actual_val_ratio*100:.1f}%, test: {actual_test_ratio*100:.1f}%")
            # Second split uses a different random seed
            mskf2 = MultilabelStratifiedKFold(
                n_splits=n_splits_2,
                shuffle=True,
                random_state=random_seed + 1
            )
            folds2 = list(mskf2.split(train_val_indices, y_train_val))
            train_idx, val_idx = folds2[0]
            # Map back to the original indices
            train_indices = train_val_idx[train_idx]
            val_indices = train_val_idx[val_idx]
            # 5. Assemble the final splits
            train_samples = [samples[i] for i in train_indices]
            val_samples = [samples[i] for i in val_indices]
            test_samples = [samples[i] for i in test_idx]
            # 6. Verify the label distributions
            def get_label_dist(samples_subset):
                dist = Counter(l for s in samples_subset for l in s["labels"] if l != "O")
                return dict(dist)
            train_dist = get_label_dist(train_samples)
            val_dist = get_label_dist(val_samples)
            test_dist = get_label_dist(test_samples)
            print("\n✅ Stratified split results:")
            print(f"   Train label distribution: {train_dist}")
            print(f"   Val label distribution: {val_dist}")
            print(f"   Test label distribution: {test_dist}")
            # Check for missing labels
            missing_in_train = [l for l in all_labels if l not in train_dist]
            missing_in_val = [l for l in all_labels if l not in val_dist]
            missing_in_test = [l for l in all_labels if l not in test_dist]
            if missing_in_train:
                print(f"⚠️ Warning: labels missing from the training set: {missing_in_train}")
            if missing_in_val:
                print(f"⚠️ Warning: labels missing from the validation set: {missing_in_val}")
            if missing_in_test:
                print(f"⚠️ Warning: labels missing from the test set: {missing_in_test}")
            if not (missing_in_train or missing_in_val or missing_in_test):
                print("✅ Every label is represented in all three splits!")
            return train_samples, val_samples, test_samples
        except ImportError:
            print("⚠️ Warning: the iterative-stratification library is not installed")
            print("💡 Install it with: pip install iterative-stratification")
            print("🔄 Falling back to plain random sampling...")
            use_stratified = False
    if not use_stratified:
        # ========== Plain random sampling (fallback) ==========
        print("🎲 Splitting dataset with plain random sampling...")
        random.seed(random_seed)
        random.shuffle(samples)
        # Split points
        n = len(samples)
        train_end = int(n * train_ratio)
        val_end = int(n * (train_ratio + val_ratio))
        train_samples = samples[:train_end]
        val_samples = samples[train_end:val_end]
        test_samples = samples[val_end:]
        # Label distributions
        def get_label_dist(samples_subset):
            dist = Counter(l for s in samples_subset for l in s["labels"] if l != "O")
            return dict(dist)
        train_dist = get_label_dist(train_samples)
        val_dist = get_label_dist(val_samples)
        test_dist = get_label_dist(test_samples)
        print("\n✅ Random split results:")
        print(f"   Train label distribution: {train_dist}")
        print(f"   Val label distribution: {val_dist}")
        print(f"   Test label distribution: {test_dist}")
        return train_samples, val_samples, test_samples
# Split the dataset: train/val/test = 70%/15%/15%
train_samples, val_samples, test_samples = split_dataset(
    dataset_samples,
    train_ratio=TrainingConfig.TRAIN_RATIO,
    val_ratio=TrainingConfig.VAL_RATIO,
    random_seed=TrainingConfig.RANDOM_SEED,
    use_stratified=TrainingConfig.USE_STRATIFIED_SPLIT
)
print(f"📊 Training set size: {len(train_samples)} samples")
print(f"📊 Validation set size: {len(val_samples)} samples")
print(f"📊 Test set size: {len(test_samples)} samples")
# Build the label maps (collect labels from all splits so the test set
# cannot contain an unknown label)
all_labels = set()
for samples in [train_samples, val_samples, test_samples]:
    all_labels.update(l for s in samples for l in s["labels"])
label_list = ["O"] + sorted(l for l in all_labels if l != "O")
label2id = {lab: i for i, lab in enumerate(label_list)}
id2label = {v: k for k, v in label2id.items()}
print(f"📊 Number of label classes: {len(label_list)}")
print(f"📋 Labels: {label_list}")
# Class weights
def calculate_class_weights(train_samples, label2id, device):
    """Compute class weights for handling label imbalance (inverse frequency)."""
    from collections import Counter
    label_counts = Counter()
    for sample in train_samples:
        for label in sample["labels"]:
            label_counts[label] += 1
    total_labels = sum(label_counts.values())
    num_classes = len(label2id)
    # Per-class counts, in label_list order
    class_counts = np.array([label_counts.get(l, 0) for l in label_list])
    # Guard against division by zero
    class_counts = np.maximum(class_counts, 1)
    # Weights are inversely proportional to class frequency
    weights = total_labels / (num_classes * class_counts)
    # Normalize so the mean weight is 1
    weights = weights / np.mean(weights)
    print("📊 Class counts and weights:")
    for label, count, weight in zip(label_list, class_counts, weights):
        print(f"   {label:12}: {count:6d} samples | weight: {weight:.4f}")
    return torch.FloatTensor(weights).to(device)
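# Worked example (added; illustrative numbers, not from the real data): with three
# labels and counts {O: 900, B-LOC: 80, I-LOC: 20}, total_labels = 1000 and
# num_classes = 3, so the raw inverse-frequency weights are
#   O:     1000 / (3 * 900) ≈ 0.37
#   B-LOC: 1000 / (3 * 80)  ≈ 4.17
#   I-LOC: 1000 / (3 * 20)  ≈ 16.67
# whose mean is ≈ 7.07; after dividing by the mean the final weights are roughly
# O: 0.05, B-LOC: 0.59, I-LOC: 2.36 — rare classes end up weighted far above
# common ones.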
# Compute class weights (only when enabled)
class_weights = calculate_class_weights(train_samples, label2id, device) if TrainingConfig.USE_CLASS_WEIGHTS else None
# =====================================================
# 2. Dataset class
# =====================================================
class NERDataset(Dataset):
    def __init__(self, samples, max_len=128):
        self.samples = samples
        self.max_len = max_len
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        text = self.samples[idx]["text"]
        labels = self.samples[idx]["labels"]
        encoding = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
            add_special_tokens=True
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Label alignment: [CLS] gets O, then one label per character, then O for
        # [SEP]; pad with O up to max_len. This assumes bert-base-chinese
        # tokenizes Chinese text one character per token.
        aligned_labels = [label2id.get("O", 0)]
        chars = list(text)
        if len(labels) != len(chars):
            # Length mismatch between text and labels: align what we can
            min_len = min(len(labels), len(chars))
            for i in range(min_len):
                aligned_labels.append(label2id[labels[i]])
        else:
            for tag in labels:
                aligned_labels.append(label2id[tag])
        aligned_labels.append(label2id.get("O", 0))
        while len(aligned_labels) < self.max_len:
            aligned_labels.append(label2id.get("O", 0))
        label_ids = torch.tensor(aligned_labels[:self.max_len], dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": label_ids
        }
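# Note (added): the character-per-token assumption above holds for pure Chinese
# text under bert-base-chinese, but digits and Latin substrings can be merged into
# multi-character tokens, silently misaligning labels. A more robust alignment
# could use the fast tokenizer's offset mapping — a sketch, assuming `tokenizer`
# is a fast tokenizer (AutoTokenizer returns one for bert-base-chinese):
#
# encoding = tokenizer(text, return_offsets_mapping=True, truncation=True,
#                      max_length=128, padding="max_length")
# aligned = []
# for start, end in encoding["offset_mapping"]:
#     if start == end:                  # special or padding token
#         aligned.append(label2id["O"])
#     else:                             # take the label of the token's first character
#         aligned.append(label2id[labels[start]])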
# Create the tokenizer
LOCAL_MODEL_PATH = "/home/zhangmanman/.cache/huggingface/hub/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea/"
print(f"🔄 Using local pretrained model: {LOCAL_MODEL_PATH}")
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
print(f"✅ Tokenizer loaded; vocabulary size: {len(tokenizer)}")
train_dataset = NERDataset(train_samples, max_len=128)
val_dataset = NERDataset(val_samples, max_len=128)
test_dataset = NERDataset(test_samples, max_len=128)
# Create the DataLoaders
train_loader = DataLoader(
train_dataset,
batch_size=TrainingConfig.BATCH_SIZE,
shuffle=True,
num_workers=2,
pin_memory=True,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=TrainingConfig.BATCH_SIZE,
shuffle=False,
num_workers=2,
pin_memory=True,
drop_last=False
)
test_loader = DataLoader(
test_dataset,
batch_size=TrainingConfig.BATCH_SIZE,
shuffle=False,
num_workers=2,
pin_memory=True,
drop_last=False
)
# =====================================================
# 3. Model definition (class-weight aware)
# =====================================================
class BertCRFNER(nn.Module):
    def __init__(self, model_name, num_labels, class_weights=None):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        self.class_weights = class_weights
    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        emissions = self.classifier(sequence_output)
        if labels is not None:
            # Note: pytorch-crf does not support class weights directly; a truly
            # weighted CRF would need a custom implementation. The weights are
            # therefore ignored here and the standard CRF loss is used. To handle
            # class imbalance, consider resampling or a focal-style loss instead.
            loss = -self.crf(emissions, labels, mask=attention_mask.bool(), reduction="mean")
            return loss
        else:
            return self.crf.decode(emissions, mask=attention_mask.bool())
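# Note (added): one way to actually apply `class_weights`, despite the pytorch-crf
# limitation above, is an auxiliary weighted token-level cross-entropy over the
# emission scores, added to the CRF loss. A minimal sketch under that assumption
# (`aux_weight` would be a new hyperparameter; none of this is wired into the
# training loop below):
import torch.nn.functional as F

def weighted_emission_loss(emissions, labels, attention_mask, class_weights):
    """Class-weighted token-level cross-entropy over CRF emission scores."""
    num_labels = emissions.size(-1)
    per_token = F.cross_entropy(
        emissions.view(-1, num_labels),  # (batch * seq_len, num_labels)
        labels.view(-1),                 # (batch * seq_len,)
        weight=class_weights,
        reduction="none"
    )
    mask = attention_mask.view(-1).float()
    return (per_token * mask).sum() / mask.sum()  # average over real tokens only
# Hypothetical combination inside BertCRFNER.forward:
#   loss = crf_loss + aux_weight * weighted_emission_loss(emissions, labels, attention_mask, self.class_weights)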
# =====================================================
# 4. Early stopping
# =====================================================
class EarlyStopping:
    """Early-stopping helper."""
    def __init__(self, patience=3, min_delta=0.01, restore_best_weights=True):
        """
        Args:
            patience: number of epochs to tolerate without validation-loss improvement
            min_delta: minimum change that counts as an improvement
            restore_best_weights: whether to restore the best weights afterwards
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.counter = 0
        self.best_loss = None
        self.best_model_state = None
        self.early_stop = False
        self.best_epoch = 0
    def __call__(self, current_val_loss, model, epoch):
        """
        Check whether training should stop.
        Returns:
            bool: True if training should stop
        """
        if self.best_loss is None:
            self.best_loss = current_val_loss
            self.best_epoch = epoch
            self.best_model_state = {
                'model_state_dict': model.state_dict(),
                'epoch': epoch
            }
            return False
        # Did the validation loss improve?
        if current_val_loss < self.best_loss - self.min_delta:
            # Improvement
            old_best_loss = self.best_loss
            self.best_loss = current_val_loss
            self.best_epoch = epoch
            self.counter = 0
            self.best_model_state = {
                'model_state_dict': model.state_dict(),
                'epoch': epoch
            }
            print(f"✅ Validation loss improved: {current_val_loss:.4f} < {old_best_loss:.4f}")
            return False
        else:
            # No improvement
            self.counter += 1
            print(f"⚠️ Validation loss did not improve: {current_val_loss:.4f} >= {self.best_loss:.4f} "
                  f"(counter: {self.counter}/{self.patience})")
            if self.counter >= self.patience:
                self.early_stop = True
                print(f"🛑 Early stopping: validation loss failed to improve for {self.patience} consecutive epochs")
                return True
        return False
    def restore_model(self, model, optimizer=None):
        """Restore the best model weights."""
        if self.best_model_state is None:
            return
        if self.restore_best_weights:
            model.load_state_dict(self.best_model_state['model_state_dict'])
            print(f"🔄 Restored best model (epoch {self.best_epoch})")
        # Note: the optimizer state is deliberately not restored here, since that
        # can destabilize training; save it separately in a checkpoint if needed.
    def save_checkpoint(self, filepath, model, optimizer, epoch, train_losses, val_losses):
        """Save a training checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'label2id': label2id,
            'id2label': id2label,
            'best_val_loss': self.best_loss,
            'best_epoch': self.best_epoch,
            'counter': self.counter,
            'best_model_state': self.best_model_state  # best model state so far
        }
        torch.save(checkpoint, filepath)
        print(f"💾 Checkpoint saved: {filepath}")
    def load_checkpoint(self, filepath, model, optimizer=None):
        """Resume training from a checkpoint."""
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Checkpoint file not found: {filepath}")
        checkpoint = torch.load(filepath, map_location=device)
        # Restore model weights
        model.load_state_dict(checkpoint['model_state_dict'])
        # Restore optimizer state if present
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Restore training history
        train_losses = checkpoint.get('train_losses', [])
        val_losses = checkpoint.get('val_losses', [])
        print("📂 Resuming from checkpoint:")
        print(f"   Epoch: {checkpoint['epoch']}")
        print(f"   Best validation loss: {checkpoint.get('best_val_loss', 'N/A')}")
        print(f"   Best epoch: {checkpoint.get('best_epoch', 'N/A')}")
        # Restore early-stopping state
        self.best_loss = checkpoint.get('best_val_loss', None)
        self.best_epoch = checkpoint.get('best_epoch', 0)
        self.counter = checkpoint.get('counter', 0)
        if self.best_loss is not None:
            self.best_model_state = checkpoint.get('best_model_state', None)
        return train_losses, val_losses
# Create the output directories
Path(TrainingConfig.MODEL_SAVE_DIR).mkdir(parents=True, exist_ok=True)
Path(TrainingConfig.CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
# =====================================================
# 5. Evaluation
# =====================================================
def evaluate_model(model, data_loader, device, log_file=None):
    """Evaluate model performance on a dataset."""
    model.eval()
    all_predictions = []
    all_true_labels = []
    print("\n🧪 Evaluating dataset...")
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            predictions = model(input_ids, attention_mask)
            for i in range(len(predictions)):
                pred_ids = predictions[i]
                true_ids = labels[i].cpu().numpy()
                mask = attention_mask[i].cpu().numpy()
                valid_length = int(mask.sum())
                pred_ids = pred_ids[:valid_length]
                true_ids = true_ids[:valid_length]
                # Drop the [CLS] and [SEP] positions
                if len(pred_ids) >= 2:
                    pred_ids = pred_ids[1:-1]
                if len(true_ids) >= 2:
                    true_ids = true_ids[1:-1]
                pred_labels = [id2label.get(pid, "O") for pid in pred_ids]
                true_labels = [id2label.get(tid, "O") for tid in true_ids]
                all_predictions.extend(pred_labels)
                all_true_labels.extend(true_labels)
    # Token-level classification report
    report = classification_report(all_true_labels, all_predictions, zero_division=0)
    if log_file:
        log_message("\n" + "="*60, log_file)
        log_message("📊 Model evaluation results", log_file)
        log_message("="*60, log_file)
        log_message(f"Dataset size: {len(data_loader.dataset)}", log_file)
        log_message(f"Total predicted labels: {len(all_predictions)}", log_file)
        log_message("\n📈 Detailed classification report:", log_file)
        for line in report.split('\n'):
            log_message(line, log_file)
        log_message("="*60, log_file)
    print("\n" + report)
    return report
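# Note (added): classification_report above scores individual BIO tags, which
# tends to overstate performance relative to entity-level metrics. The seqeval
# library (pip install seqeval) computes entity-level precision/recall/F1 — a
# sketch, assuming predictions are kept as one list of labels per sentence
# rather than flattened as above:
#
# from seqeval.metrics import classification_report as seq_report
# # pred_sents / true_sents: List[List[str]], one inner list per sentence
# print(seq_report(true_sents, pred_sents, zero_division=0))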
def log_message(message, log_file):
    """Print a message and append it to the log file with a timestamp."""
    print(message)
    if log_file:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(f"[{timestamp}] {message}\n")
# =====================================================
# 6. Training loop (with early stopping)
# =====================================================
def train_with_early_stopping(
    model, train_loader, val_loader,
    optimizer, scheduler, device,
    config, log_file=None
):
    """Training loop with early stopping."""
    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=config.PATIENCE,
        min_delta=config.MIN_DELTA,
        restore_best_weights=True
    )
    # Training history
    train_losses = []
    val_losses = []
    # Resume from an earlier checkpoint if one exists
    checkpoint_path = Path(TrainingConfig.CHECKPOINT_DIR) / "latest_checkpoint.pt"
    start_epoch = 0
    if checkpoint_path.exists():
        try:
            train_losses, val_losses = early_stopping.load_checkpoint(
                checkpoint_path, model, optimizer
            )
            start_epoch = len(train_losses)
            print(f"📂 Resumed from checkpoint; continuing at epoch {start_epoch}")
        except Exception as e:
            print(f"⚠️ Failed to load checkpoint: {e}")
            print("🔄 Training from scratch")
            train_losses = []
            val_losses = []
            start_epoch = 0
    else:
        print("🚀 Training from scratch")
        train_losses = []
        val_losses = []
    # Total steps, for progress reporting
    total_steps = config.EPOCHS * len(train_loader)
    log_message(f"🎯 Starting training, total steps: {total_steps}", log_file)
    # Training loop
    for epoch in range(start_epoch, config.EPOCHS):
        model.train()
        total_loss = 0.0
        epoch_grad_norms = []
        # GPU monitoring: state at the start of the epoch
        if gpu_monitor.available and torch.cuda.is_available():
            gpu_start = gpu_monitor.get_gpu_info(0)
            log_message(f"\n🖥️ Epoch {epoch+1}/{config.EPOCHS} start GPU state: {gpu_monitor.format_gpu_info(gpu_start)}", log_file)
        # Train one epoch
        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            optimizer.zero_grad()
            try:
                # Forward pass
                loss = model(input_ids, attention_mask, labels)
                # Backward pass
                loss.backward()
                # Gradient clipping (MAX_GRAD_NORM=None disables clipping but
                # still records the gradient norm)
                if config.MAX_GRAD_NORM is not None:
                    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)
                else:
                    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
                # Track the gradient norm
                epoch_grad_norms.append(total_norm.item())
                # Parameter update
                optimizer.step()
                scheduler.step()
                total_loss += loss.item()
                # Periodic progress report (with GPU stats)
                if batch_idx % 20 == 0:
                    current_lr = optimizer.param_groups[0]['lr']
                    current_step = epoch * len(train_loader) + batch_idx
                    progress = (current_step / total_steps) * 100
                    message = (f"Epoch {epoch+1}/{config.EPOCHS} [{progress:.1f}%] "
                               f"Loss: {loss.item():.4f} | "
                               f"LR: {current_lr:.2e} | "
                               f"Grad: {total_norm:.3f}")
                    # Append GPU stats
                    if gpu_monitor.available and torch.cuda.is_available():
                        gpu_info = gpu_monitor.get_gpu_info(0)
                        message += f" | {gpu_monitor.format_gpu_info(gpu_info)}"
                    log_message(message, log_file)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    log_message(f"⚠️ OOM error: epoch {epoch+1}, batch {batch_idx+1}", log_file)
                    log_message("Suggestion: reduce batch_size or use mixed-precision training", log_file)
                    torch.cuda.empty_cache()
                    continue
                else:
                    raise
        # Epoch summary
        avg_train_loss = total_loss / len(train_loader)
        avg_grad_norm = sum(epoch_grad_norms) / len(epoch_grad_norms)
        train_losses.append(avg_train_loss)
        log_message(
            f"\n📊 Epoch {epoch+1}/{config.EPOCHS} finished | "
            f"Mean loss: {avg_train_loss:.4f} | "
            f"Mean grad norm: {avg_grad_norm:.3f} | "
            f"Final LR: {optimizer.param_groups[0]['lr']:.2e}",
            log_file
        )
        # GPU monitoring: state at the end of the epoch
        if gpu_monitor.available and torch.cuda.is_available():
            gpu_end = gpu_monitor.get_gpu_info(0)
            log_message(f"🖥️ Epoch {epoch+1}/{config.EPOCHS} end GPU state: {gpu_monitor.format_gpu_info(gpu_end)}\n", log_file)
        # Free cached GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                loss = model(input_ids, attention_mask, labels)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        # Learning rate
        current_lr = optimizer.param_groups[0]['lr']
        # Log
        best_loss_str = f"{early_stopping.best_loss:.4f}" if early_stopping.best_loss is not None else "N/A"
        log_message(
            f"Epoch {epoch+1}/{config.EPOCHS} | "
            f"Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f} | "
            f"LR: {current_lr:.2e} | "
            f"Best Val Loss: {best_loss_str} | "
            f"Counter: {early_stopping.counter}/{early_stopping.patience}",
            log_file
        )
        # Save checkpoints: one per epoch, plus the rolling "latest" copy that the
        # resume logic above looks for
        epoch_checkpoint = Path(TrainingConfig.CHECKPOINT_DIR) / f"checkpoint_epoch_{epoch+1}.pt"
        early_stopping.save_checkpoint(
            epoch_checkpoint,
            model, optimizer, epoch,
            train_losses, val_losses
        )
        early_stopping.save_checkpoint(
            Path(TrainingConfig.CHECKPOINT_DIR) / "latest_checkpoint.pt",
            model, optimizer, epoch,
            train_losses, val_losses
        )
        # Early-stopping check
        should_stop = early_stopping(
            avg_val_loss, model, epoch
        )
        if should_stop:
            log_message("\n🛑 Early stopping triggered; halting training", log_file)
            break
    # Restore the best model
    early_stopping.restore_model(model, optimizer)
    log_message(f"\n🔄 Restored best model (epoch {early_stopping.best_epoch})", log_file)
    return train_losses, val_losses, early_stopping
# =====================================================
# 7. Main training entry point
# =====================================================
def main():
    """Main entry point."""
    # Log file
    LOG_FILE = "./STRATIFIED_TRAINING_LOG.log"
    # Initialize the log
    log_message("=" * 80, LOG_FILE)
    log_message("🚀 BERT + CRF Chinese NER training (early-stopping version) started", LOG_FILE)
    log_message("=" * 80, LOG_FILE)
    # =====================================================
    # Model initialization
    # =====================================================
    log_message("📋 Training configuration:", LOG_FILE)
    log_message("  - Model: BERT-base-chinese + CRF", LOG_FILE)
    log_message(f"  - Training set size: {len(train_dataset)} samples", LOG_FILE)
    log_message(f"  - Validation set size: {len(val_dataset)} samples", LOG_FILE)
    log_message(f"  - Test set size: {len(test_dataset)} samples", LOG_FILE)
    log_message(f"  - Batch size: {TrainingConfig.BATCH_SIZE}", LOG_FILE)
    log_message(f"  - Epochs: {TrainingConfig.EPOCHS}", LOG_FILE)
    log_message(f"  - Learning rate: {TrainingConfig.LEARNING_RATE:.0e}", LOG_FILE)
    log_message(f"  - Gradient clipping: {TrainingConfig.MAX_GRAD_NORM}", LOG_FILE)
    log_message(f"  - Warmup ratio: {TrainingConfig.WARMUP_RATIO*100:.0f}%", LOG_FILE)
    log_message(f"  - Number of label classes: {len(label2id)}", LOG_FILE)
    log_message(f"  - Class weights: {TrainingConfig.USE_CLASS_WEIGHTS}", LOG_FILE)
    log_message(f"  - Early-stopping patience: {TrainingConfig.PATIENCE}", LOG_FILE)
    log_message(f"  - Min improvement: {TrainingConfig.MIN_DELTA}", LOG_FILE)
    if gpu_monitor.available and torch.cuda.is_available():
        log_message("🔧 GPU monitoring: ✅ enabled", LOG_FILE)
        gpu_info = gpu_monitor.get_gpu_info(0)
        log_message(f"   GPU: {gpu_name} ({gpu_info['memory_total_mb']}MB VRAM)", LOG_FILE)
    else:
        log_message("🔧 GPU monitoring: ❌ disabled (pynvml not installed or no CUDA GPU)", LOG_FILE)
    # =====================================================
    # Build the model
    # =====================================================
    log_message("\n🔄 Initializing model...", LOG_FILE)
    model = BertCRFNER(
        LOCAL_MODEL_PATH,
        num_labels=len(label2id),
        class_weights=class_weights
    )
    model = model.to(device)
    log_message("✅ Model initialized", LOG_FILE)
    # =====================================================
    # Optimizer and scheduler
    # =====================================================
optimizer = torch.optim.AdamW(
model.parameters(),
lr=TrainingConfig.LEARNING_RATE,
weight_decay=TrainingConfig.WEIGHT_DECAY,
betas=(0.9, 0.999),
eps=1e-8
)
total_steps = len(train_loader) * TrainingConfig.EPOCHS
warmup_steps = int(total_steps * TrainingConfig.WARMUP_RATIO)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
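    # Worked example (added; illustrative numbers, not from a real run): with
    # 3,200 training samples, BATCH_SIZE=32 and drop_last=True give 100 batches
    # per epoch, so total_steps = 100 * 5 = 500 and warmup_steps = int(500 * 0.15)
    # = 75; the LR climbs linearly to 2e-5 over the first 75 steps, then decays
    # linearly to 0 over the remaining 425.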
    # =====================================================
    # Start training
    # =====================================================
    log_message("\n🔄 Preparing to train...", LOG_FILE)
    log_message(f"📋 Warmup steps: {warmup_steps}", LOG_FILE)
    train_losses, val_losses, early_stopping = train_with_early_stopping(
        model, train_loader, val_loader,
        optimizer, scheduler, device,
        TrainingConfig, LOG_FILE
    )
    # =====================================================
    # Final evaluation on the test set
    # =====================================================
    log_message("\n🎉 Training finished! Starting final evaluation...", LOG_FILE)
    # No extra loading needed here: train_with_early_stopping has already
    # restored the best recorded weights, whether or not early stopping fired.
    evaluate_model(model, test_loader, device, LOG_FILE)
    # =====================================================
    # Save the best model
    # =====================================================
    if TrainingConfig.SAVE_BEST_MODEL:
        best_model_path = Path(TrainingConfig.MODEL_SAVE_DIR) / "best_bert_crf_ner.pt"
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'label2id': label2id,
            'id2label': id2label,
            'config': {
                'model_name': LOCAL_MODEL_PATH,
                'num_labels': len(label2id),
                'max_len': 128
            },
            'training_info': {
                'best_epoch': early_stopping.best_epoch,
                'best_val_loss': early_stopping.best_loss,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'class_weights': class_weights
            },
            'early_stopping': {
                'patience': TrainingConfig.PATIENCE,
                'min_delta': TrainingConfig.MIN_DELTA,
                'counter': early_stopping.counter
            }
        }, best_model_path)
        log_message(f"💾 Best model saved to: {best_model_path}", LOG_FILE)
    # =====================================================
    # Training summary
    # =====================================================
    log_message("\n" + "=" * 80, LOG_FILE)
    log_message("🎉 Training and evaluation summary", LOG_FILE)
    log_message("=" * 80, LOG_FILE)
    # Training-loss statistics
    if len(train_losses) > 1:
        loss_change = train_losses[-1] - train_losses[0]
        log_message(f"📈 Training loss: {train_losses[0]:.4f} → {train_losses[-1]:.4f} ({loss_change:+.4f})", LOG_FILE)
    # Validation-loss statistics
    if len(val_losses) > 0:
        log_message(f"📊 Validation loss: {val_losses[0]:.4f} → {val_losses[-1]:.4f}", LOG_FILE)
        log_message(f"🏆 Best validation loss: {early_stopping.best_loss:.4f} (epoch {early_stopping.best_epoch+1})", LOG_FILE)
    # Early-stopping info
    log_message(f"🛑 Early stopping: {'✅ triggered' if early_stopping.early_stop else '❌ not triggered'}", LOG_FILE)
    log_message(f"   Best epoch: {early_stopping.best_epoch+1}", LOG_FILE)
    log_message(f"   Early-stop counter: {early_stopping.counter}/{TrainingConfig.PATIENCE}", LOG_FILE)
    # Learning-rate info
    log_message("🔧 Final learning-rate settings:", LOG_FILE)
    log_message(f"   Initial LR: {TrainingConfig.LEARNING_RATE:.0e}", LOG_FILE)
    log_message(f"   Final LR: {optimizer.param_groups[0]['lr']:.2e}", LOG_FILE)
    # Class weights
    if TrainingConfig.USE_CLASS_WEIGHTS and class_weights is not None:
        log_message("⚖️ Class weights:", LOG_FILE)
        weights = class_weights.cpu().numpy()
        for label, weight in zip(label_list, weights):
            log_message(f"   {label:12}: {weight:.4f}", LOG_FILE)
    log_message("✅ Training pipeline complete", LOG_FILE)
    log_message(f"📝 Full log: {LOG_FILE}", LOG_FILE)
    log_message("=" * 80, LOG_FILE)
    print("\n🎉 Early-stopping training run complete!")
if __name__ == "__main__":
    main()