Implementing Assisted Principal Diagnosis Coding for Medical Record Front Sheets in Python: From Data Cleaning to Model Deployment (Part 2)

5.2 A Sliding-Window Aggregation Strategy for Long Texts

```python
import torch
import torch.nn.functional as F


class SlidingWindowClassifier:
    """Sliding-window classifier (for inference)."""

    def __init__(self, model, tokenizer, max_length=256, overlap=0.2):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.overlap = overlap  # reserved: the current splitter does not overlap chunks
    def predict(self, text, top_k=3):
        """Predict over a long text."""
        # Split the text into chunks
        chunks = self._split_text(text)

        if not chunks:
            return []

        # Run the model on each chunk
        all_logits = []
        for chunk in chunks:
            inputs = self.tokenizer(
                chunk,
                truncation=True,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            )

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
                all_logits.append(logits)
        
        # Aggregation strategy: average the logits
        if len(all_logits) > 1:
            avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        else:
            avg_logits = all_logits[0]

        # Take the top-k predictions
        probs = F.softmax(avg_logits, dim=-1)[0]
        top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, len(probs)))

        predictions = []
        for prob, idx in zip(top_k_probs.tolist(), top_k_indices.tolist()):
            predictions.append({
                'icd_code': self.model.config.id2label[idx],
                'confidence': prob,
                'evidence_chunks': chunks[:2]  # return the first two chunks as a rough evidence sample
            })

        return predictions
    
    def _split_text(self, text):
        """Heuristic chunking."""
        # Split by paragraph first
        paragraphs = [p.strip() for p in text.split('\n') if p.strip()]

        chunks = []
        for para in paragraphs:
            if len(para) < self.max_length * 0.8:
                chunks.append(para)
            else:
                # Fall back to sentence-level splitting
                sentences = self._split_sentences(para)
                current_chunk = ""

                for sent in sentences:
                    if len(current_chunk) + len(sent) < self.max_length:
                        current_chunk += sent
                    else:
                        if current_chunk:
                            chunks.append(current_chunk)
                        current_chunk = sent

                if current_chunk:
                    chunks.append(current_chunk)

        return chunks
    
    def _split_sentences(self, text):
        """Simple Chinese sentence splitting."""
        import re
        delimiters = r'[。!?!?]'
        sentences = re.split(delimiters, text)
        # Re-append a Chinese full stop so chunk lengths stay realistic
        return [s.strip() + '。' for s in sentences if s.strip()]
```
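
A minimal usage sketch; the checkpoint path is illustrative, standing in for whatever fine-tuned sequence-classification model you produced in part 1:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical fine-tuned checkpoint from the training step in part 1
tokenizer = AutoTokenizer.from_pretrained("outputs/best_model")
model = AutoModelForSequenceClassification.from_pretrained("outputs/best_model")
model.eval()

clf = SlidingWindowClassifier(model, tokenizer, max_length=256)
results = clf.predict("主诉:反复胸痛3天,加重1天……(长病程记录)", top_k=3)
for r in results:
    print(r['icd_code'], round(r['confidence'], 3))
```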

5.3 Model Ensembling

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class ModelEnsemble:
    """Model ensembler."""

    def __init__(self, model_paths, tokenizer_name="bert-base-chinese"):
        self.models = []
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        for path in model_paths:
            model = AutoModelForSequenceClassification.from_pretrained(path)
            model.eval()
            self.models.append(model)
    
    def predict_proba(self, texts, aggregation='mean'):
        """Predict class probabilities."""
        all_probs = []

        for model in self.models:
            inputs = self.tokenizer(
                texts,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )

            with torch.no_grad():
                outputs = model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)
                all_probs.append(probs)

        # Aggregation strategy
        stacked = torch.stack(all_probs)
        if aggregation == 'mean':
            final_probs = torch.mean(stacked, dim=0)
        elif aggregation == 'max':
            final_probs = torch.max(stacked, dim=0)[0]
        elif aggregation == 'geometric':
            # Clamp to avoid log(0)
            final_probs = torch.exp(torch.mean(torch.log(stacked.clamp_min(1e-12)), dim=0))
        else:
            raise ValueError(f"Unsupported aggregation method: {aggregation}")

        return final_probs
    
    def predict_ensemble(self, texts, top_k=3):
        """Ensemble prediction."""
        probs = self.predict_proba(texts)

        predictions = []
        for prob in probs:
            top_k_probs, top_k_indices = torch.topk(prob, k=min(top_k, len(prob)))
            preds = [
                {'icd_code': self.models[0].config.id2label[idx.item()],
                 'confidence': p.item()}  # renamed to avoid shadowing the outer `prob`
                for p, idx in zip(top_k_probs, top_k_indices)
            ]
            predictions.append(preds)

        return predictions
```
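
Typical usage, assuming several checkpoints trained with different seeds or hyperparameters (the paths below are placeholders):

```python
# Hypothetical checkpoints, e.g. the same architecture trained with different seeds
ensemble = ModelEnsemble([
    "outputs/model_seed42",
    "outputs/model_seed43",
    "outputs/model_seed44",
])

preds = ensemble.predict_ensemble(["患者因急性阑尾炎入院,行腹腔镜阑尾切除术……"], top_k=3)
print(preds[0])
```

Averaging probabilities ('mean') is the safest default; the geometric mean penalizes codes that any single model considers unlikely, which tends to make the ensemble more conservative.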

6. Evaluation and Analysis

6.1 A Comprehensive Evaluation Script

Create scripts/evaluate_model.py:

```python
import json
import os

import numpy as np
import pandas as pd
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    cohen_kappa_score,
    accuracy_score,
    f1_score
)
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

class ModelEvaluator:
    """Model evaluator."""

    def __init__(self, model_path, label_map_path):
        self.model_path = model_path
        self.label_map_path = label_map_path

        # Load the model and label mappings
        self.load_model()

    def load_model(self):
        """Load the model and label mappings."""
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        # Label mappings
        with open(self.label_map_path, 'r', encoding='utf-8') as f:
            self.label_mappings = json.load(f)

        # Model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.eval()
    
    def evaluate(self, test_data, output_dir="./eval_results"):
        """Run the full evaluation."""
        os.makedirs(output_dir, exist_ok=True)

        # Predict (labels plus full probability distributions, for top-k metrics)
        predictions, true_labels, all_probs = self.predict(test_data)

        # Compute all metrics
        metrics = self.compute_all_metrics(true_labels, predictions, all_probs)

        # Generate detailed reports
        self.generate_reports(true_labels, predictions, output_dir)

        # Analyze error cases
        error_analysis = self.analyze_errors(test_data, true_labels, predictions, output_dir)

        return metrics
    
    def predict(self, test_data):
        """Batch prediction."""
        from torch.utils.data import DataLoader
        import torch

        predictions = []
        true_labels = []
        all_probs = []

        dataloader = DataLoader(test_data, batch_size=32)

        for batch in dataloader:
            texts = batch['text']
            labels = batch['label']

            inputs = self.tokenizer(
                texts,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = torch.softmax(outputs.logits, dim=-1)
                preds = torch.argmax(outputs.logits, dim=-1)
                predictions.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
                true_labels.extend(labels.numpy())

        return predictions, true_labels, np.array(all_probs)
    
    def compute_all_metrics(self, y_true, y_pred, y_probs=None):
        """Compute all evaluation metrics."""
        metrics = {}

        # Core metrics
        metrics['accuracy'] = accuracy_score(y_true, y_pred)
        metrics['macro_f1'] = f1_score(y_true, y_pred, average='macro')
        metrics['weighted_f1'] = f1_score(y_true, y_pred, average='weighted')
        metrics['kappa'] = cohen_kappa_score(y_true, y_pred)

        # Top-k accuracy (needs the full probability distributions)
        if y_probs is not None:
            metrics['top_3_accuracy'] = self.compute_top_k_accuracy(y_true, y_probs, k=3)
            metrics['top_5_accuracy'] = self.compute_top_k_accuracy(y_true, y_probs, k=5)

        # Per-class metrics; pass `labels` so target_names always line up
        class_report = classification_report(
            y_true, y_pred,
            labels=list(range(len(self.label_mappings['id2label']))),
            target_names=list(self.label_mappings['id2label'].values()),
            output_dict=True,
            zero_division=0
        )

        # Collect per-class rows
        class_metrics = []
        for class_name, class_metrics_dict in class_report.items():
            if class_name not in ['accuracy', 'macro avg', 'weighted avg']:
                class_metrics.append({
                    'class': class_name,
                    'precision': class_metrics_dict.get('precision', 0),
                    'recall': class_metrics_dict.get('recall', 0),
                    'f1': class_metrics_dict.get('f1-score', 0),
                    'support': class_metrics_dict.get('support', 0)
                })

        metrics['class_metrics'] = class_metrics

        return metrics
    
    def compute_top_k_accuracy(self, y_true, y_probs, k=3):
        """Top-k accuracy: the true label is among the k highest-probability classes."""
        # Indices of the k most probable classes per sample
        top_k_preds = np.argsort(y_probs, axis=1)[:, -k:]
        correct = sum(
            1 for true, candidates in zip(y_true, top_k_preds) if true in candidates
        )
        return correct / len(y_true)
    
    def generate_reports(self, y_true, y_pred, output_dir):
        """Generate evaluation reports."""
        # 1. Confusion-matrix heatmap
        self.plot_confusion_matrix(y_true, y_pred, output_dir)

        # 2. Per-class performance distribution
        self.plot_class_performance(y_true, y_pred, output_dir)

        # 3. Detailed metrics dump
        self.save_detailed_metrics(y_true, y_pred, output_dir)

    def plot_confusion_matrix(self, y_true, y_pred, output_dir):
        """Plot the confusion matrix."""
        cm = confusion_matrix(y_true, y_pred)

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(f'{output_dir}/confusion_matrix.png', dpi=300)
        plt.close()

    def plot_class_performance(self, y_true, y_pred, output_dir):
        """Plot per-class F1 scores."""
        per_class_f1 = f1_score(y_true, y_pred, average=None)
        plt.figure(figsize=(12, 4))
        plt.bar(range(len(per_class_f1)), per_class_f1)
        plt.xlabel('Class ID')
        plt.ylabel('F1')
        plt.tight_layout()
        plt.savefig(f'{output_dir}/class_performance.png', dpi=300)
        plt.close()

    def save_detailed_metrics(self, y_true, y_pred, output_dir):
        """Save the full classification report as JSON."""
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        with open(f'{output_dir}/detailed_metrics.json', 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)
    
    def analyze_errors(self, test_data, y_true, y_pred, output_dir):
        """Error analysis."""
        errors = []

        for i, (true, pred) in enumerate(zip(y_true, y_pred)):
            if true != pred:
                error = {
                    'index': i,
                    'true_label': self.label_mappings['id2label'][str(true)],
                    'pred_label': self.label_mappings['id2label'][str(pred)],
                    'text': test_data[i]['text'][:200] + '...'  # truncate for readability
                }
                errors.append(error)

        # Collect errors into a DataFrame
        error_df = pd.DataFrame(errors)

        # Save the error analysis
        error_df.to_csv(f'{output_dir}/error_analysis.csv', index=False, encoding='utf-8')

        return error_df
```
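
A usage sketch; the paths are illustrative, and test_dataset is assumed to be the torch Dataset built during data preparation (yielding dicts with 'text' and 'label' keys):

```python
# Hypothetical paths from the training run
evaluator = ModelEvaluator(
    model_path="outputs/best_model",
    label_map_path="outputs/best_model/label_mappings.json",
)

metrics = evaluator.evaluate(test_dataset, output_dir="./eval_results")
print(f"accuracy={metrics['accuracy']:.3f}  macro_f1={metrics['macro_f1']:.3f}")
print(f"top-3 accuracy={metrics['top_3_accuracy']:.3f}")
```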

6.2 Business-Friendly Evaluation

```python
import json
import os

from sklearn.metrics import accuracy_score, f1_score


class BusinessMetricsEvaluator:
    """Business-metrics evaluator."""

    def __init__(self, icd_hierarchy_path=None):
        """
        icd_hierarchy_path: path to an ICD hierarchy file,
        used for finer-grained hierarchical analysis (optional).
        """
        self.icd_hierarchy = self.load_icd_hierarchy(icd_hierarchy_path)

    def load_icd_hierarchy(self, path):
        """Load the ICD hierarchy."""
        if path and os.path.exists(path):
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None
    
    def evaluate_hierarchical_accuracy(self, y_true, y_pred, level=1):
        """
        Hierarchical accuracy on ICD code prefixes.
        level=1: chapter letter only (e.g. A, B, C) -> 1-character prefix
        level=2: three-character category (e.g. A00, A01) -> 3-character prefix
        The plain prefix comparison does not require the hierarchy file.
        """
        # Map the semantic level to an actual prefix length
        prefix_len = {1: 1, 2: 3}.get(level, level)

        correct = 0
        for true_code, pred_code in zip(y_true, y_pred):
            if true_code[:prefix_len] == pred_code[:prefix_len]:
                correct += 1

        return correct / len(y_true)
    
    def evaluate_by_department(self, predictions, department_info):
        """Evaluate per department."""
        dept_metrics = {}

        for dept in set(department_info.values()):
            dept_indices = [i for i, dept_name in department_info.items() if dept_name == dept]

            if dept_indices:
                dept_true = [predictions['true_labels'][i] for i in dept_indices]
                dept_pred = [predictions['pred_labels'][i] for i in dept_indices]

                metrics = {
                    'accuracy': accuracy_score(dept_true, dept_pred),
                    'macro_f1': f1_score(dept_true, dept_pred, average='macro'),
                    'sample_count': len(dept_indices)
                }
                dept_metrics[dept] = metrics

        return dept_metrics
    
    def calculate_cost_savings(self, predictions, manual_cost_per_record=10,
                               ai_assist_improvement=0.3):
        """Rough cost-savings estimate (all rates below are assumptions)."""
        baseline_accuracy = predictions.get('baseline_accuracy', 0.7)
        ai_accuracy = predictions.get('ai_accuracy', 0)

        # Reduction in error rate
        error_reduction = ai_accuracy - baseline_accuracy

        # Assumed manual coding cost per record
        total_records = len(predictions['true_labels'])
        cost_saving = total_records * error_reduction * manual_cost_per_record

        # Efficiency gain from AI assistance, assuming 2 minutes saved per record
        time_saving = total_records * ai_assist_improvement * 2

        return {
            'annual_cost_saving': cost_saving * 12,  # annualized, assuming one month of data
            'time_saving_hours': time_saving / 60,
            'error_reduction_rate': error_reduction
        }
```
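
Hierarchical accuracy captures "near misses": predicting I21.0 when the truth is I21.9 is wrong at the full-code level but correct at both chapter and category level. A small worked example (codes are made up):

```python
evaluator = BusinessMetricsEvaluator()

y_true = ["I21.9", "J18.0", "A01.0"]
y_pred = ["I21.0", "J15.9", "A01.0"]

# Chapter level (first letter): I/I, J/J, A/A all match -> 1.0
print(evaluator.evaluate_hierarchical_accuracy(y_true, y_pred, level=1))
# Category level (first three chars): I21/I21 ok, J18/J15 wrong, A01/A01 ok -> 2/3
print(evaluator.evaluate_hierarchical_accuracy(y_true, y_pred, level=2))
```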

7. Deployment Guide

7.1 Wrapping the Inference Service

Create scripts/inference_service.py:

```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import time
from typing import List, Optional

app = FastAPI(title="ICD Coding AI Assistance Service")

class ICDRequest(BaseModel):
    text: str
    top_k: Optional[int] = 3
    return_evidence: Optional[bool] = False

class ICDResponse(BaseModel):
    predictions: List[dict]
    model_version: str
    inference_time: float

class ICDClassifier:
    """ICD classifier service."""

    def __init__(self, model_path: str, device: str = None):
        self.model_path = model_path
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        self.load_model()

    def load_model(self):
        """Load the model."""
        print(f"Loading model from: {self.model_path}")

        # Model config
        with open(f"{self.model_path}/config.json", 'r') as f:
            config = json.load(f)

        # Label mappings
        with open(f"{self.model_path}/label_mappings.json", 'r', encoding='utf-8') as f:
            self.label_mappings = json.load(f)

        # Tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.to(self.device)
        self.model.eval()

        self.model_version = config.get('model_version', '1.0.0')
    
    def preprocess(self, text: str, max_length: int = 256):
        """Preprocess the input text."""
        # Text cleaning / normalization could be added here
        inputs = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt"
        )

        return {k: v.to(self.device) for k, v in inputs.items()}
    
    def predict(self, text: str, top_k: int = 3, return_evidence: bool = False):
        """Predict ICD codes."""
        start_time = time.time()

        # Preprocess
        inputs = self.preprocess(text)

        # Inference
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = F.softmax(logits, dim=-1)[0]

        # Top-k
        top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, len(probs)))

        # Build the response
        predictions = []
        for prob, idx in zip(top_k_probs.tolist(), top_k_indices.tolist()):
            pred = {
                'icd_code': self.label_mappings['id2label'][str(idx)],
                'confidence': float(prob),
                'rank': len(predictions) + 1
            }

            if return_evidence:
                # Extract supporting evidence (simplified)
                pred['evidence'] = self.extract_evidence(text, idx)

            predictions.append(pred)

        inference_time = time.time() - start_time

        return {
            'predictions': predictions,
            'model_version': self.model_version,
            'inference_time': inference_time
        }
    
    def extract_evidence(self, text: str, predicted_label_id: int):
        """Extract prediction evidence (simplified implementation)."""
        # A real implementation might use attention visualization or keyword
        # extraction; here we just return leading sentences as an example
        sentences = [s.strip() for s in text.split('。') if s.strip()]

        # Return the first two sentences as evidence
        return sentences[:2] if len(sentences) >= 2 else sentences

# Global model instance
classifier = None

@app.on_event("startup")
async def startup_event():
    """Load the model at startup."""
    global classifier
    model_path = "outputs/best_model"  # change to your model path
    classifier = ICDClassifier(model_path)

@app.post("/predict", response_model=ICDResponse)
async def predict_icd(request: ICDRequest):
    """Predict ICD codes."""
    try:
        result = classifier.predict(
            text=request.text,
            top_k=request.top_k,
            return_evidence=request.return_evidence
        )
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check."""
    if classifier is None:
        raise HTTPException(status_code=503, detail="model not loaded")
    return {"status": "healthy", "model_version": classifier.model_version}

@app.get("/model_info")
async def model_info():
    """Model information."""
    return {
        "model_version": classifier.model_version,
        "num_classes": len(classifier.label_mappings['id2label']),
        "icd_version": classifier.label_mappings.get('icd_version', 'ICD-10')
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
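
Once the service is up (python scripts/inference_service.py), any HTTP client can call it; a minimal sketch using the requests library:

```python
import requests

resp = requests.post(
    "http://localhost:8000/predict",
    json={"text": "患者因急性心肌梗死入院,行急诊PCI治疗……", "top_k": 3},
    timeout=10,
)
resp.raise_for_status()
for p in resp.json()["predictions"]:
    print(p["rank"], p["icd_code"], f"{p['confidence']:.3f}")
```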

7.2 Batch Inference Script

```python
import json

import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path


class BatchInference:
    """Batch inference."""

    def __init__(self, model_path, batch_size=32):
        self.model_path = model_path
        self.batch_size = batch_size
        self.load_model()

    def load_model(self):
        """Load the model."""
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        # Label mappings
        with open(f"{self.model_path}/label_mappings.json", 'r', encoding='utf-8') as f:
            self.label_mappings = json.load(f)

        # Model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.eval()

        if torch.cuda.is_available():
            self.model.cuda()
    
    def predict_batch(self, texts, top_k=3):
        """Batch prediction."""
        all_predictions = []

        for i in tqdm(range(0, len(texts), self.batch_size), desc="Batch inference"):
            batch_texts = texts[i:i+self.batch_size]

            # Encode
            inputs = self.tokenizer(
                batch_texts,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )

            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}

            # Inference
            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            # Top-k (never more candidates than classes)
            effective_k = min(top_k, probs.shape[1])
            top_k_probs, top_k_indices = torch.topk(probs, k=effective_k, dim=1)

            # Map indices back to ICD codes
            for batch_idx in range(len(batch_texts)):
                predictions = []
                for k in range(effective_k):  # iterate over the actual number of candidates
                    idx = top_k_indices[batch_idx][k].item()
                    prob = top_k_probs[batch_idx][k].item()

                    predictions.append({
                        'icd_code': self.label_mappings['id2label'][str(idx)],
                        'confidence': prob,
                        'rank': k + 1
                    })

                all_predictions.append(predictions)

        return all_predictions
    
    def process_file(self, input_file, output_file, text_column='text'):
        """Process a whole file."""
        # Load the data
        if input_file.endswith('.csv'):
            df = pd.read_csv(input_file, encoding='utf-8')
        elif input_file.endswith('.xlsx'):
            df = pd.read_excel(input_file)
        else:
            raise ValueError("Only CSV and Excel files are supported")

        # Predict
        texts = df[text_column].fillna('').tolist()
        predictions = self.predict_batch(texts)

        # Collect results
        results = []
        for idx, preds in enumerate(predictions):
            result = {
                'original_text': texts[idx][:100] + '...' if len(texts[idx]) > 100 else texts[idx],
                'top_1_prediction': preds[0]['icd_code'],
                'top_1_confidence': preds[0]['confidence'],
                'all_predictions': json.dumps(preds, ensure_ascii=False)  # serialize for Excel
            }
            results.append(result)

        # Save to file
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)
        output_df = pd.DataFrame(results)
        output_df.to_excel(output_file, index=False)

        print(f"Prediction finished; results saved to: {output_file}")
        print(f"Processed {len(df)} records in total")

# Usage example
if __name__ == "__main__":
    inference = BatchInference("outputs/best_model")
    inference.process_file(
        input_file="data/test_data.csv",
        output_file="results/predictions.xlsx"
    )
```

8. Continuous Optimization and Monitoring

8.1 Model Performance Monitoring

```python
import hashlib
import sqlite3
from datetime import datetime

import pandas as pd


class ModelMonitor:
    """Model performance monitor."""

    def __init__(self, db_path="model_monitor.db"):
        self.db_path = db_path
        self.init_database()
    
    def init_database(self):
        """Initialize the monitoring database."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Inference log table
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS inference_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp DATETIME,
            model_version TEXT,
            text_hash TEXT,
            top_1_prediction TEXT,
            top_1_confidence REAL,
            inference_time REAL
        )
        ''')
        
        # Feedback log table
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS feedback_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            inference_id INTEGER,
            correct_icd TEXT,
            user_rating INTEGER,
            feedback TEXT,
            corrected_at DATETIME,
            FOREIGN KEY (inference_id) REFERENCES inference_logs (id)
        )
        ''')
        
        # Performance metrics table
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS performance_metrics (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            date DATE,
            model_version TEXT,
            total_predictions INTEGER,
            accuracy REAL,
            avg_confidence REAL,
            avg_response_time REAL
        )
        ''')
        
        conn.commit()
        conn.close()
    
    def log_inference(self, model_version, text, predictions, inference_time):
        """Log one inference call."""
        # Hash the text (for deduplication and tracing)
        text_hash = hashlib.md5(text.encode()).hexdigest()
        
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
        INSERT INTO inference_logs 
        (timestamp, model_version, text_hash, top_1_prediction, top_1_confidence, inference_time)
        VALUES (?, ?, ?, ?, ?, ?)
        ''', (
            datetime.now(),
            model_version,
            text_hash,
            predictions[0]['icd_code'] if predictions else None,
            predictions[0]['confidence'] if predictions else 0,
            inference_time
        ))
        
        inference_id = cursor.lastrowid
        conn.commit()
        conn.close()
        
        return inference_id
    
    def log_feedback(self, inference_id, correct_icd=None, user_rating=None, feedback=None):
        """Log user feedback."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
        INSERT INTO feedback_logs 
        (inference_id, correct_icd, user_rating, feedback, corrected_at)
        VALUES (?, ?, ?, ?, ?)
        ''', (
            inference_id,
            correct_icd,
            user_rating,
            feedback,
            datetime.now()
        ))
        
        conn.commit()
        conn.close()
    
    def calculate_daily_metrics(self):
        """Compute daily metrics."""
        conn = sqlite3.connect(self.db_path)

        # Daily aggregate statistics
        query = '''
        SELECT 
            DATE(timestamp) as date,
            model_version,
            COUNT(*) as total_predictions,
            AVG(top_1_confidence) as avg_confidence,
            AVG(inference_time) as avg_response_time
        FROM inference_logs
        WHERE DATE(timestamp) = DATE('now', '-1 day')
        GROUP BY DATE(timestamp), model_version
        '''
        
        daily_stats = pd.read_sql_query(query, conn)
        
        # Accuracy (requires feedback data)
        accuracy_query = '''
        SELECT 
            DATE(i.timestamp) as date,
            i.model_version,
            COUNT(*) as total_feedback,
            SUM(CASE WHEN f.correct_icd = i.top_1_prediction THEN 1 ELSE 0 END) as correct_predictions
        FROM inference_logs i
        JOIN feedback_logs f ON i.id = f.inference_id
        WHERE DATE(i.timestamp) = DATE('now', '-1 day')
        GROUP BY DATE(i.timestamp), i.model_version
        '''
        
        accuracy_stats = pd.read_sql_query(accuracy_query, conn)
        
        conn.close()
        
        return daily_stats, accuracy_stats
    
    def detect_performance_drift(self, window_days=30):
        """Detect performance drift."""
        conn = sqlite3.connect(self.db_path)
        
        query = f'''
        SELECT 
            DATE(timestamp) as date,
            AVG(top_1_confidence) as avg_confidence
        FROM inference_logs
        WHERE DATE(timestamp) >= DATE('now', '-{window_days} days')
        GROUP BY DATE(timestamp)
        ORDER BY DATE(timestamp)
        '''
        
        confidence_trend = pd.read_sql_query(query, conn)
        conn.close()
        
        # Naive drift check: recent average confidence versus the preceding history
        if len(confidence_trend) >= 7:
            recent_avg = confidence_trend['avg_confidence'].tail(7).mean()
            historical_avg = confidence_trend['avg_confidence'].head(len(confidence_trend)-7).mean()
            
            if recent_avg < historical_avg * 0.95:  # more than a 5% drop
                return {
                    'drift_detected': True,
                    'confidence_drop': (historical_avg - recent_avg) / historical_avg,
                    'suggestion': 'Consider re-evaluating or retraining the model'
                }
        
        return {'drift_detected': False}
```
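
A sketch of how the monitor plugs into the service from section 7.1; the values below are illustrative stand-ins for what the /predict handler would pass in:

```python
monitor = ModelMonitor("model_monitor.db")

# After each /predict call (values here are illustrative)
inference_id = monitor.log_inference(
    model_version="1.0.0",
    text="患者因急性心肌梗死入院……",
    predictions=[{"icd_code": "I21.9", "confidence": 0.87, "rank": 1}],
    inference_time=0.045,
)

# Later, when a coder confirms or corrects the suggestion
monitor.log_feedback(inference_id, correct_icd="I21.9", user_rating=5)

# Nightly job: aggregate metrics and check for drift
daily_stats, accuracy_stats = monitor.calculate_daily_metrics()
print(monitor.detect_performance_drift(window_days=30))
```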

9. Deployment Case Study and Performance Data

9.1 Example Deployment Architecture

```
Production deployment architecture:

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│ Client requests │───▶│   API gateway   │───▶│ FastAPI service │
│ (EMR/MR system) │    │ (Nginx/Traefik) │    │ (ICD classifier)│
└─────────────────┘    └─────────────────┘    └─────────────────┘
                                                       │
                                                       ▼
┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│    Alerting     │◀───│  Perf. monitor  │◀───│ Model versioning│
│  (Prometheus)   │    │ (ModelMonitor)  │    │  (A/B testing)  │
└─────────────────┘    └─────────────────┘    └─────────────────┘
```

9.2 Performance Benchmarks

Benchmark results on real clinical data (illustrative):

| Metric         | Baseline | +LoRA  | +Sliding window | +Ensemble |
|----------------|----------|--------|-----------------|-----------|
| Accuracy       | 78.3%    | 79.1%  | 81.2%           | 82.5%     |
| Macro-F1       | 65.2%    | 67.8%  | 72.1%           | 74.3%     |
| Top-3 accuracy | 89.5%    | 90.2%  | 92.1%           | 93.4%     |
| Inference time | 45 ms    | 42 ms  | 68 ms           | 120 ms    |
| GPU memory     | 1.2 GB   | 0.8 GB | 1.3 GB          | 2.5 GB    |

9.3 Cost-Benefit Analysis

Assume a tertiary (Grade-A) hospital with 100,000 discharge records per year:

  • Manual coding cost: 100,000 × ¥15 per record = ¥1.5 million per year
  • Efficiency gain from AI assistance: an estimated 30%
  • Annual cost savings: ¥1.5 million × 30% = ¥450,000
  • Payback period for the system: roughly 3-6 months

10. Common Questions and Solutions

Q1: What if there isn't enough data?

Solutions:

  1. Start from a pretrained language model (e.g. 华佗BERT, MedicalBERT)
  2. Apply data augmentation: synonym replacement, sentence reordering, back-translation (a minimal sketch follows this list)
  3. Use few-shot learning techniques: Prompt Tuning, Adapters
  4. Transfer learning: pretrain on large-scale medical text first, then fine-tune
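
A minimal synonym-replacement sketch; the tiny synonym table below is a made-up stand-in for a curated medical thesaurus:

```python
import random

# Hypothetical synonym table; in practice, use a curated medical thesaurus
SYNONYMS = {
    "发热": ["发烧", "体温升高"],
    "腹痛": ["腹部疼痛"],
    "恶心": ["恶心感"],
}

def augment_by_synonym(text: str, p: float = 0.5) -> str:
    """Randomly replace known terms with a synonym with probability p."""
    for term, alts in SYNONYMS.items():
        if term in text and random.random() < p:
            text = text.replace(term, random.choice(alts))
    return text

print(augment_by_synonym("患者发热3天,伴腹痛、恶心"))
```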

Q2: How do we handle newly added ICD codes?

Solutions:

  1. Build an incremental learning pipeline
  2. Use open-set recognition techniques (see the sketch after this list)
  3. Set up a mechanism to stay in sync with official code-set updates
  4. Maintain an explicit "unknown code" handling workflow
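
The simplest open-set baseline is confidence thresholding: instead of forcing a known code, route low-confidence predictions to manual review, where a newly added code may apply. A sketch (the threshold value is an assumption to be tuned on validation data):

```python
def route_prediction(predictions, threshold=0.5):
    """Return the top code if confident enough, otherwise flag for manual review."""
    top = predictions[0]
    if top['confidence'] >= threshold:
        return {'status': 'auto', 'icd_code': top['icd_code']}
    # Possibly a rare or newly added code the model has never seen
    return {'status': 'manual_review', 'candidates': predictions}

print(route_prediction([{'icd_code': 'I21.9', 'confidence': 0.41, 'rank': 1}]))
```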

Q3: How does the model integrate with an existing HIS?

Solutions:

  1. Expose a RESTful API (as in section 7.1)
  2. Support the HL7/FHIR healthcare data standards
  3. Develop middleware adapters for different vendors' systems
  4. Offer both batch and real-time processing modes

Q4: How do we ensure model fairness?

Solutions:

  1. Evaluate stratified by department, disease group, and patient population
  2. Monitor performance gaps across groups
  3. Apply debiasing techniques (e.g. adversarial learning)
  4. Maintain a human review and correction mechanism

Summary

This series walked through the complete process of building an AI-assisted principal diagnosis coding system for medical record front sheets from scratch, covering:

  1. Data preparation: cleaning, de-identification, and standardization
  2. Model training: a Transformer-based classifier with LoRA and Optuna support
  3. Advanced optimization: long-text handling, class imbalance, and model ensembling
  4. Evaluation and analysis: multi-dimensional metrics and error analysis
  5. Deployment: an API service, batch processing, and performance monitoring
  6. Continuous improvement: feedback collection, model updates, and performance monitoring

Key success factors:

  • Data quality: clean, standardized data matters more than a complex model
  • Explainability: help coders understand model decisions to build trust
  • Continuous iteration: establish a complete train-evaluate-feedback loop
  • Human-AI collaboration: the AI recommends, the expert makes the final call

Suggested next steps:

  1. Pilot on simple cases first, then gradually extend to complex ones
  2. Build a feedback loop with coding experts to keep improving the model
  3. Explore multimodal fusion (imaging, lab reports, etc.)
  4. Consider building an ICD coding knowledge graph to improve explainability


With the guidance in this series, you can build an AI-assisted coding system with accuracy above 80% that noticeably improves coding efficiency and consistency. Remember: the goal of the AI is not to replace human coders, but to serve as a powerful assistant so professionals can focus their judgment on complex cases.


Author's note: All code in this article is illustrative; adapt it to your specific business scenario before deploying. Medical AI systems touch patient privacy and data security, so be sure to comply with applicable laws and regulations and your hospital's information-security policies.

Copyright: Non-commercial reposting and sharing are permitted with attribution. For commercial use, please contact the author for permission.
