AI-Assisted Principal-Diagnosis Coding for Medical Record Front Sheets in Python: A Complete Engineering Guide from Data Cleaning to Model Deployment

Introduction: Why Do We Need an Intelligent Coding Assistant?
With healthcare informatization advancing rapidly, the accuracy of medical record front sheet data directly affects hospital management, insurance payment, and quality-of-care assessment. Manual ICD coding, however, faces several challenges:
- Text complexity: a single discharge summary can run to thousands of characters, with the key diagnostic information scattered across different sections
- Coding expertise: ICD-10 coding rules are intricate, and new coders need years of experience before they are fluent
- Data inconsistency: data from HIS, EMR, LIS, PACS and other systems arrives in different formats
- Long-tail distribution: rare diseases have few samples, so their coding patterns are hard to learn
This article walks you through building a deployable, extensible, and explainable AI-assisted coding system from scratch, with complete code and engineering practices.
1. Environment Setup: A Reproducible Environment in 5 Minutes
First, create the project directory and set up a virtual environment:
```bash
# Create the project directory
mkdir icd_coding_system && cd icd_coding_system
# Create a virtual environment (Python 3.10+ recommended)
python -m venv venv
# Activate it
# Windows: venv\Scripts\activate
source venv/bin/activate
# Install the core dependencies
pip install -U pip
pip install transformers==4.40.0
pip install datasets==2.18.0
pip install evaluate==0.4.1
pip install scikit-learn==1.4.0
pip install optuna==3.6.0
pip install peft==0.10.0
pip install accelerate==0.27.2
pip install pandas==2.2.1
pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/cu118
# Create the project structure
mkdir -p {data,scripts,outputs,configs,utils}
```
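Before moving on, it can help to confirm that the pinned libraries import cleanly and that the GPU build of PyTorch sees your hardware. A minimal sanity check, assuming the versions installed above, might look like this:

```python
# Quick environment sanity check (illustrative; adjust to your setup)
import torch
import transformers
import datasets
import peft

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("peft:", peft.__version__)
```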
2. Data Preparation and Cleaning Strategy
2.1 Standardizing the Data Format
Your raw data may come from several systems and needs to be unified into the CSV format below:
```python
import pandas as pd

def prepare_dataset(raw_data_path: str, output_path: str) -> pd.DataFrame:
    """
    Convert raw data into the standard format.
    The input must contain at least the text content and the ICD code.
    """
    # Read the raw data (adjust to your actual source)
    if raw_data_path.endswith('.csv'):
        df = pd.read_csv(raw_data_path, encoding='gbk')  # GBK encoding is common in Chinese hospital exports
    elif raw_data_path.endswith('.xlsx'):
        df = pd.read_excel(raw_data_path)
    else:
        raise ValueError("Only CSV or Excel files are supported")
    # Map the key fields (adjust to your actual column names)
    column_mapping = {
        '出院诊断': 'diagnosis',
        '主诊断': 'main_diagnosis',
        'ICD编码': 'icd_code',
        '病历文本': 'medical_record',
        '出院小结': 'discharge_summary'
    }
    df = df.rename(columns=column_mapping)
    # Build the text field: concatenate the key pieces of information
    def build_text(row):
        parts = []
        if 'discharge_summary' in row and pd.notna(row['discharge_summary']):
            parts.append(f"出院小结:{str(row['discharge_summary'])}")
        if 'diagnosis' in row and pd.notna(row['diagnosis']):
            parts.append(f"诊断信息:{str(row['diagnosis'])}")
        if 'main_diagnosis' in row and pd.notna(row['main_diagnosis']):
            parts.append(f"主诊断:{str(row['main_diagnosis'])}")
        return "\n".join(parts)
    df['text'] = df.apply(build_text, axis=1)
    # Make sure the required fields exist
    required_cols = ['text', 'icd_code']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")
    # Save the processed data
    df[required_cols].to_csv(output_path, index=False, encoding='utf-8')
    print(f"Data saved to {output_path}, {len(df)} records in total")
    # ICD code distribution statistics
    icd_stats = df['icd_code'].value_counts()
    print(f"ICD code distribution: {len(icd_stats)} classes")
    print("Top 10 ICD codes:")
    print(icd_stats.head(10))
    return df

# Usage example
# df = prepare_dataset('data/raw_his_data.xlsx', 'data/processed_data.csv')
```
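For reference, the standardized file the rest of the pipeline expects is a plain two-column CSV. The row below is a fabricated example purely for illustration, not real patient data:

```python
# Illustrative example of the standardized CSV layout (fabricated row, not real patient data)
import pandas as pd

sample = pd.DataFrame({
    'text': ["出院小结:患者因胸痛入院……\n诊断信息:急性心肌梗死\n主诊断:急性前壁心肌梗死"],
    'icd_code': ["I21.0"],
})
print(sample.to_csv(index=False))
```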
2.2 Text Cleaning and Normalization
```python
import re
from typing import List, Optional

class MedicalTextCleaner:
    """Utility class for cleaning medical text"""
    def __init__(self):
        # Patterns for information that commonly needs de-identification
        self.patterns = {
            'patient_id': r'病历号[::]\s*\d+',
            'phone': r'1[3-9]\d{9}',
            'id_card': r'\d{17}[\dXx]',
            'name': r'患者[::]\s*[\u4e00-\u9fa5]{2,4}',
            'age': r'年龄[::]\s*\d+'
        }
    def clean_text(self, text: str, replace_with: str = "[已脱敏]") -> str:
        """Clean the text and remove sensitive information"""
        if not isinstance(text, str):
            return ""
        # Basic cleanup
        text = text.strip()
        # De-identification
        for pattern_name, pattern in self.patterns.items():
            text = re.sub(pattern, replace_with, text)
        # Normalize punctuation consistently (full-width to half-width)
        text = text.replace(',', ',').replace('。', '.').replace(';', ';')
        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text)
        # Expand common abbreviations
        abbreviations = {
            'BP': '血压',
            'HR': '心率',
            'WBC': '白细胞',
            'RBC': '红细胞',
            # extend as needed for your data
        }
        for abbr, full in abbreviations.items():
            text = text.replace(abbr, full)
        return text
    def split_long_text(self, text: str, max_length: int = 500,
                        overlap: int = 50) -> List[str]:
        """Split a long text into overlapping segments"""
        if len(text) <= max_length:
            return [text]
        segments = []
        start = 0
        while start < len(text):
            end = start + max_length
            # Prefer to cut at punctuation
            if end < len(text):
                for break_char in ['。', ';', ',', '.', ';', ',', '\n']:
                    break_pos = text.rfind(break_char, start, end)
                    if break_pos > start + max_length * 0.5:  # avoid overly short segments
                        end = break_pos + 1
                        break
            segments.append(text[start:end])
            if end >= len(text):  # reached the end of the text, stop
                break
            start = end - overlap  # keep an overlap between segments
        return segments

# Usage example
cleaner = MedicalTextCleaner()
df['cleaned_text'] = df['text'].apply(cleaner.clean_text)
```
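To see what the cleaner actually does, a quick check on a made-up sentence (fabricated for illustration, not a real record) can be run before applying it to the whole DataFrame:

```python
# Illustrative check of the cleaner on a fabricated sentence
sample = "患者:张某某 病历号:1234567 主诉胸痛3小时,BP 150/90mmHg,HR 96次/分。"
print(cleaner.clean_text(sample))
# Expected effect: the name and record number are masked as [已脱敏],
# BP/HR are expanded to 血压/心率, and punctuation/whitespace are normalized.
```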
3. Label Processing and Dataset Construction
3.1 ICD Code Normalization
```python
from collections import Counter
from typing import List
import json

class ICDLabelProcessor:
    """ICD label processor"""
    def __init__(self, icd_version: str = "ICD-10"):
        self.icd_version = icd_version
        self.label2id = {}
        self.id2label = {}
    def fit(self, icd_codes: List[str]):
        """Learn the mapping from ICD codes to integer IDs"""
        # Count frequencies
        code_counter = Counter(icd_codes)
        # Sort by descending frequency (helps the model learn)
        sorted_codes = [code for code, _ in code_counter.most_common()]
        # Build the mappings
        self.label2id = {code: idx for idx, code in enumerate(sorted_codes)}
        self.id2label = {idx: code for code, idx in self.label2id.items()}
        print(f"Created mappings for {len(self.label2id)} ICD codes")
        print("Sample distribution:")
        for code, count in code_counter.most_common(10):
            print(f"  {code}: {count} records")
    def encode(self, icd_codes: List[str]) -> List[int]:
        """Convert ICD codes to IDs"""
        return [self.label2id[code] for code in icd_codes]
    def decode(self, label_ids: List[int]) -> List[str]:
        """Convert IDs back to ICD codes"""
        return [self.id2label[label_id] for label_id in label_ids]
    def save_mappings(self, output_dir: str):
        """Save the mappings"""
        mappings = {
            'label2id': self.label2id,
            'id2label': self.id2label,
            'icd_version': self.icd_version
        }
        with open(f"{output_dir}/label_mappings.json", 'w', encoding='utf-8') as f:
            json.dump(mappings, f, ensure_ascii=False, indent=2)
    def load_mappings(self, mapping_path: str):
        """Load the mappings"""
        with open(mapping_path, 'r', encoding='utf-8') as f:
            mappings = json.load(f)
        self.label2id = mappings['label2id']
        self.id2label = {int(k): v for k, v in mappings['id2label'].items()}
        self.icd_version = mappings.get('icd_version', 'ICD-10')

# Usage example
processor = ICDLabelProcessor()
processor.fit(df['icd_code'].tolist())
df['label'] = processor.encode(df['icd_code'].tolist())
processor.save_mappings('outputs/')
```
3.2 Dataset Splitting and Loading
```python
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

def create_datasets(df, processor, test_size=0.2, val_size=0.1, random_seed=42):
    """
    Create the train / validation / test splits.
    Note: stratified splitting requires every class to appear more than once;
    extremely rare ICD codes may need to be merged or dropped beforehand.
    """
    # First split: carve out the test set
    train_val_df, test_df = train_test_split(
        df, test_size=test_size, random_state=random_seed,
        stratify=df['label']  # preserve the class distribution
    )
    # Second split: carve out the validation set
    train_df, val_df = train_test_split(
        train_val_df, test_size=val_size / (1 - test_size),
        random_state=random_seed, stratify=train_val_df['label']
    )
    print(f"Train: {len(train_df)} records")
    print(f"Validation: {len(val_df)} records")
    print(f"Test: {len(test_df)} records")
    # Convert to HuggingFace Dataset format
    def create_hf_dataset(sub_df):
        return Dataset.from_dict({
            'text': sub_df['cleaned_text'].tolist(),
            'label': sub_df['label'].tolist(),
            'icd_code': sub_df['icd_code'].tolist()
        })
    dataset_dict = DatasetDict({
        'train': create_hf_dataset(train_df),
        'validation': create_hf_dataset(val_df),
        'test': create_hf_dataset(test_df)
    })
    return dataset_dict

# Build the datasets
datasets = create_datasets(df, processor)
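```

One practical detail worth making explicit: the training script in the next section reads a single CSV and expects `text` and `label` columns, so the cleaned, label-encoded frame has to be written back to disk first. A minimal bridge, assuming the `df` and `processor` objects from above and the `data/processed_data.csv` path used by the launch commands later, could be:

```python
# Persist the cleaned, label-encoded data so the training script can read it.
# The script in section 4 expects columns named 'text' and 'label'.
import pandas as pd

train_ready = pd.DataFrame({
    'text': df['cleaned_text'],
    'label': df['label'],
    'icd_code': df['icd_code'],
})
train_ready.to_csv('data/processed_data.csv', index=False, encoding='utf-8')
print(f"Wrote {len(train_ready)} rows to data/processed_data.csv")
```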
4. The Core Training Script: Trainer + Optuna + LoRA
4.1 The Full Training Script
Create scripts/train_icd_model.py:
```python
#!/usr/bin/env python3
"""
Training script for principal-diagnosis ICD coding on medical record front sheets.
Supports: basic training, LoRA fine-tuning, Optuna hyperparameter optimization.
"""
import argparse
import json
import os
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from datasets import DatasetDict, load_dataset
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score

# Optional components
try:
    import optuna
    OPTUNA_AVAILABLE = True
except ImportError:
    OPTUNA_AVAILABLE = False
    print("Optuna is not installed; for hyperparameter optimization run: pip install optuna")
try:
    from peft import (
        LoraConfig,
        TaskType,
        get_peft_model
    )
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    print("PEFT is not installed; for LoRA fine-tuning run: pip install peft")
class ICDClassificationTrainer:
    """Trainer wrapper for ICD classification"""
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # Load the label mappings
        self.load_label_mappings()
        # Initialize the tokenizer and model
        self.init_model()
    def load_label_mappings(self):
        """Load the label mappings"""
        if self.args.label_map_path and os.path.exists(self.args.label_map_path):
            with open(self.args.label_map_path, 'r', encoding='utf-8') as f:
                mappings = json.load(f)
            self.id2label = {int(k): v for k, v in mappings['id2label'].items()}
            self.label2id = mappings['label2id']
            self.num_labels = len(self.id2label)
        else:
            # Fall back to the command-line argument
            if self.args.num_labels is None:
                raise ValueError("Provide either --label_map_path or --num_labels")
            self.num_labels = self.args.num_labels
            self.id2label = None
            self.label2id = None
    def init_model(self):
        """Initialize the model and tokenizer"""
        print(f"Loading pretrained model: {self.args.model_name}")
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.args.model_name,
            trust_remote_code=True
        )
        # Make sure a padding token exists
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token or '[PAD]'
        # Load the model (only pass id2label/label2id when a mapping is available)
        model_kwargs = {"num_labels": self.num_labels}
        if self.id2label is not None:
            model_kwargs["id2label"] = self.id2label
            model_kwargs["label2id"] = self.label2id
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.args.model_name,
            **model_kwargs,
            trust_remote_code=True
        )
        # Apply the LoRA configuration if enabled
        if self.args.use_lora and PEFT_AVAILABLE:
            print("Applying LoRA configuration...")
            self.model = self.prepare_lora_model()
        self.model.to(self.device)
    def prepare_lora_model(self):
        """Wrap the base model with LoRA adapters"""
        # LoRA configuration
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=self.args.lora_r,
            lora_alpha=self.args.lora_alpha,
            lora_dropout=self.args.lora_dropout,
            target_modules=self.args.lora_target_modules.split(","),
            bias="none",
        )
        # Note: peft's prepare_model_for_kbit_training is only needed when the base
        # model is loaded in 8-bit/4-bit; for a full-precision model we go straight to LoRA.
        model = get_peft_model(self.model, lora_config)
        model.print_trainable_parameters()
        return model
    def preprocess_function(self, examples):
        """Tokenization / preprocessing function"""
        # Long-text handling with the sliding-window strategy, if enabled
        if self.args.sliding_window and len(examples['text'][0]) > self.args.max_length * 3:
            texts = []
            chunk_counts = []
            for text in examples['text']:
                # Simple sliding-window chunking; keep at most 3 chunks per record
                chunks = self.sliding_window_chunk(text, self.args.max_length)[:3]
                texts.extend(chunks)
                chunk_counts.append(len(chunks))
            batch = self.tokenizer(
                texts,
                truncation=True,
                padding=True,
                max_length=self.args.max_length
            )
            # Replicate each label once per chunk of the corresponding record.
            # In practice, a more elaborate aggregation strategy may be needed at inference time.
            if 'label' in examples:
                labels = []
                for label, n_chunks in zip(examples['label'], chunk_counts):
                    labels.extend([label] * n_chunks)
                batch['labels'] = labels
        else:
            batch = self.tokenizer(
                examples['text'],
                truncation=True,
                padding=True,
                max_length=self.args.max_length
            )
            if 'label' in examples:
                batch['labels'] = examples['label']
        return batch
    def sliding_window_chunk(self, text, max_length, overlap=0.2):
        """Sliding-window chunking (simplified)"""
        tokens = self.tokenizer.encode(text, truncation=False)
        chunks = []
        stride = int(max_length * (1 - overlap))
        for i in range(0, len(tokens), stride):
            chunk_tokens = tokens[i:i + max_length]
            chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens=True)
            chunks.append(chunk_text)
            if i + max_length >= len(tokens):
                break
        return chunks
    def compute_metrics(self, eval_pred):
        """Compute evaluation metrics"""
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=1)
        # Basic metrics
        accuracy = accuracy_score(labels, predictions)
        macro_f1 = f1_score(labels, predictions, average='macro')
        weighted_f1 = f1_score(labels, predictions, average='weighted')
        # Cohen's Kappa (agreement with the reference coding)
        kappa = cohen_kappa_score(labels, predictions)
        # Top-k accuracy, computed from the raw logits before argmax
        top_k = 3
        if logits.ndim == 2 and logits.shape[1] >= top_k:
            top_k_preds = np.argsort(logits, axis=1)[:, -top_k:]
            top_k_acc = float(np.mean([labels[i] in top_k_preds[i] for i in range(len(labels))]))
        else:
            top_k_acc = 0.0
        return {
            "accuracy": accuracy,
            "f1_macro": macro_f1,
            "f1_weighted": weighted_f1,
            "kappa": kappa,
            f"top_{top_k}_accuracy": top_k_acc
        }
    def train(self, train_dataset, eval_dataset):
        """Train the model"""
        print("Starting training...")
        # Training arguments
        training_args = TrainingArguments(
            output_dir=self.args.output_dir,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            learning_rate=self.args.learning_rate,
            per_device_train_batch_size=self.args.batch_size,
            per_device_eval_batch_size=self.args.batch_size * 2,
            num_train_epochs=self.args.num_epochs,
            weight_decay=self.args.weight_decay,
            load_best_model_at_end=True,
            metric_for_best_model="f1_macro",
            greater_is_better=True,
            logging_dir=f"{self.args.output_dir}/logs",
            logging_steps=50,
            save_total_limit=2,
            fp16=self.args.fp16,
            report_to="tensorboard" if self.args.log_to_tb else "none",
        )
        # Create the Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] if self.args.early_stopping else None,
        )
        # Train
        train_result = trainer.train()
        # Save the model and trainer state
        trainer.save_model()
        trainer.save_state()
        # Evaluate
        eval_results = trainer.evaluate()
        print(f"Evaluation results: {eval_results}")
        return trainer, eval_results
    def hpo_objective(self, trial, train_dataset, eval_dataset):
        """Optuna objective for hyperparameter optimization"""
        if not OPTUNA_AVAILABLE:
            raise ImportError("Optuna is required: pip install optuna")
        # Suggested hyperparameters
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True)
        batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
        weight_decay = trial.suggest_float("weight_decay", 0.0, 0.1)
        # Update the arguments
        self.args.learning_rate = learning_rate
        self.args.batch_size = batch_size
        self.args.weight_decay = weight_decay
        # Re-initialize the model so every trial starts from the pretrained weights
        self.init_model()
        # Train
        _, eval_results = self.train(train_dataset, eval_dataset)
        # Return the optimization target (Macro-F1)
        return eval_results["eval_f1_macro"]
    def run_hyperparameter_optimization(self, train_dataset, eval_dataset, n_trials=20):
        """Run the hyperparameter search"""
        print(f"Starting hyperparameter optimization, {n_trials} trials...")
        study = optuna.create_study(
            direction="maximize",
            study_name="icd_classification_hpo",
            storage=f"sqlite:///{self.args.output_dir}/optuna.db",
            load_if_exists=True
        )
        study.optimize(
            lambda trial: self.hpo_objective(trial, train_dataset, eval_dataset),
            n_trials=n_trials,
            show_progress_bar=True
        )
        print("Best hyperparameters:")
        for key, value in study.best_params.items():
            print(f"  {key}: {value}")
        print(f"Best Macro-F1: {study.best_value:.4f}")
        # Save the optimization results
        with open(f"{self.args.output_dir}/hpo_results.json", 'w') as f:
            json.dump({
                "best_params": study.best_params,
                "best_value": study.best_value,
                "trials": len(study.trials)
            }, f, indent=2)
        return study.best_params
def main():
    parser = argparse.ArgumentParser(description="Train an ICD coding classification model")
    # Data arguments
    parser.add_argument("--data_path", type=str, required=True, help="Path to the training CSV")
    parser.add_argument("--label_map_path", type=str, help="Path to the label mapping file")
    parser.add_argument("--num_labels", type=int, default=None, help="Number of labels (if no label map is given)")
    # Model arguments
    parser.add_argument("--model_name", type=str,
                        default="bert-base-chinese",
                        help="Pretrained model name")
    parser.add_argument("--output_dir", type=str,
                        default="./outputs/icd_model",
                        help="Output directory")
    # Training arguments
    parser.add_argument("--max_length", type=int, default=256, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--fp16", action="store_true", help="Use mixed-precision training")
    parser.add_argument("--early_stopping", action="store_true", help="Enable early stopping")
    # Advanced features
    parser.add_argument("--use_lora", action="store_true", help="Use LoRA")
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout")
    parser.add_argument("--lora_target_modules", type=str,
                        default="query,key,value",
                        help="LoRA target modules")
    parser.add_argument("--do_hpo", action="store_true", help="Run hyperparameter optimization")
    parser.add_argument("--hpo_trials", type=int, default=20, help="Number of HPO trials")
    parser.add_argument("--sliding_window", action="store_true", help="Use a sliding window for long texts")
    parser.add_argument("--log_to_tb", action="store_true", help="Log to TensorBoard")
    args = parser.parse_args()
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)
    # Save the configuration
    with open(f"{args.output_dir}/config.json", 'w') as f:
        json.dump(vars(args), f, indent=2)
    # Initialize the trainer
    trainer = ICDClassificationTrainer(args)
    # Load the dataset (a CSV with 'text' and 'label' columns)
    dataset = load_dataset('csv', data_files={'train': args.data_path})
    # Split off a validation set if one was not provided
    if 'validation' not in dataset:
        dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
        dataset = DatasetDict({
            'train': dataset['train'],
            'validation': dataset['test']
        })
    # Tokenize the dataset
    encoded_dataset = dataset.map(
        trainer.preprocess_function,
        batched=True,
        remove_columns=dataset['train'].column_names
    )
    # Train, or run hyperparameter optimization
    if args.do_hpo and OPTUNA_AVAILABLE:
        best_params = trainer.run_hyperparameter_optimization(
            encoded_dataset['train'],
            encoded_dataset['validation'],
            n_trials=args.hpo_trials
        )
        print(f"Best parameters: {best_params}")
    else:
        trainer.train(encoded_dataset['train'], encoded_dataset['validation'])
    print("Training finished!")

if __name__ == "__main__":
    main()
```
4.2 Launch Commands for the Different Training Modes
```bash
# Basic training
python scripts/train_icd_model.py \
    --data_path data/processed_data.csv \
    --label_map_path outputs/label_mappings.json \
    --output_dir outputs/baseline_model \
    --model_name hfl/chinese-roberta-wwm-ext \
    --max_length 256 \
    --batch_size 16 \
    --num_epochs 10

# Enable LoRA (saves GPU memory, good for fast iteration)
python scripts/train_icd_model.py \
    --data_path data/processed_data.csv \
    --label_map_path outputs/label_mappings.json \
    --output_dir outputs/lora_model \
    --use_lora \
    --lora_r 8 \
    --lora_alpha 32 \
    --batch_size 32  # LoRA allows a larger batch size

# Hyperparameter optimization
python scripts/train_icd_model.py \
    --data_path data/processed_data.csv \
    --label_map_path outputs/label_mappings.json \
    --output_dir outputs/hpo_model \
    --do_hpo \
    --hpo_trials 30

# Long-text handling (enable the sliding window)
python scripts/train_icd_model.py \
    --data_path data/processed_data.csv \
    --label_map_path outputs/label_mappings.json \
    --output_dir outputs/sliding_model \
    --sliding_window \
    --max_length 128  # length of each chunk
```
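Since the title promises taking the model all the way to deployment, here is a minimal, illustrative inference sketch showing how a trained checkpoint plus the saved label mappings could produce top-3 code suggestions for a new record. The `outputs/baseline_model` path refers to the directory written by the first command above; treat this as a starting point rather than a production service:

```python
# Minimal inference sketch (illustrative): load a trained checkpoint and suggest top-3 ICD codes
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_dir = "outputs/baseline_model"  # directory written by the training script
with open("outputs/label_mappings.json", encoding="utf-8") as f:
    id2label = {int(k): v for k, v in json.load(f)["id2label"].items()}

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
model.eval()

def suggest_codes(text: str, top_k: int = 3):
    """Return the top-k ICD code suggestions with their probabilities."""
    inputs = tokenizer(text, truncation=True, max_length=256, return_tensors="pt")
    with torch.no_grad():
        probs = torch.softmax(model(**inputs).logits, dim=-1)[0]
    top = torch.topk(probs, k=min(top_k, probs.numel()))
    return [(id2label[i.item()], p.item()) for p, i in zip(top.values, top.indices)]

# print(suggest_codes("出院小结:……"))  # pass a cleaned record text here
```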
5. Advanced Optimization Tricks
5.1 Loss Functions for Class Imbalance
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance"""
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            if isinstance(self.alpha, (float, int)):
                alpha_t = self.alpha
            else:
                alpha_t = self.alpha[targets]
            ce_loss = alpha_t * ce_loss
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class WeightedCrossEntropyLoss(nn.Module):
    """Weighted cross-entropy loss"""
    def __init__(self, class_weights=None):
        super().__init__()
        self.class_weights = class_weights
    def forward(self, logits, targets):
        if self.class_weights is not None:
            weights = torch.as_tensor(self.class_weights, dtype=torch.float32, device=logits.device)
            loss = F.cross_entropy(logits, targets, weight=weights)
        else:
            loss = F.cross_entropy(logits, targets)
        return loss

# Compute per-class weights
def compute_class_weights(labels, beta=0.999):
    """Class weights based on the effective-number-of-samples scheme"""
    from collections import Counter
    import numpy as np
    counter = Counter(labels)
    num_classes = len(counter)
    # Effective number of samples: w_i = (1 - beta) / (1 - beta^n_i)
    weights = []
    for i in range(num_classes):
        count = counter.get(i, 0)
        effective_num = 1.0 - np.power(beta, count)
        weight = (1.0 - beta) / max(effective_num, 1e-8)  # guard against empty classes
        weights.append(weight)
    # Normalize so the weights sum to the number of classes
    weights = np.array(weights)
    weights = weights / np.sum(weights) * num_classes
    return torch.tensor(weights, dtype=torch.float32)
```
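These losses are defined above, but the Trainer in section 4 still uses the default cross-entropy. One common way to wire them in, sketched here under the assumption that FocalLoss and compute_class_weights from the block above are in scope, is to subclass Trainer and override compute_loss:

```python
# Sketch: use FocalLoss (or WeightedCrossEntropyLoss) inside the HuggingFace Trainer
from transformers import Trainer

class ImbalanceAwareTrainer(Trainer):
    """Trainer subclass that replaces the default loss with FocalLoss."""
    def __init__(self, *args, focal_gamma=2.0, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        alpha = class_weights.to(self.args.device) if class_weights is not None else None
        self.loss_fn = FocalLoss(alpha=alpha, gamma=focal_gamma)

    def compute_loss(self, model, inputs, return_outputs=False):
        # Pull the labels out, run the forward pass, and apply the custom loss
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = self.loss_fn(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss

# Usage (assuming the datasets, training_args, and compute_metrics from section 4):
# class_weights = compute_class_weights(df['label'].tolist())
# trainer = ImbalanceAwareTrainer(model=model, args=training_args,
#                                 train_dataset=train_ds, eval_dataset=eval_ds,
#                                 class_weights=class_weights)
```

Swapping the loss this way leaves the rest of the training loop, metrics, and early stopping untouched, which makes it easy to A/B the weighted losses against the baseline cross-entropy.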