Python Study Day 41

早停策略和模型权重的保存

@疏锦行

clike 复制代码
import os
import random
from dataclasses import dataclass
from typing import Dict, Tuple

import numpy as np
import pandas as pd

import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix


# =========================
# 0) Config
# =========================
@dataclass
class Config:
    """Hyper-parameters and file locations for the two-stage training script.

    Every field has a default, so ``Config()`` produces a ready-to-use setup.
    """

    # Spreadsheet read with ``pd.read_excel`` in ``load_and_prepare``.
    data_path: str = "data.xlsx"
    # Label column; cast to int and removed from the feature frame.
    target_col: str = "Credit Default"
    # Columns removed from the frame before splitting, when present.
    drop_cols: Tuple[str, ...] = ("Id",)

    # Seed for the RNGs (``set_seed``) and for the train/val/test split.
    seed: int = 42
    # Batch size shared by the train, validation and test loaders.
    batch_size: int = 256

    # Stage 1 training: upper bound on epochs for the first run.
    stage1_max_epochs: int = 30

    # Stage 2 training (resume): upper bound on additional epochs.
    stage2_max_epochs: int = 50

    # Early stopping: stop after ``patience`` consecutive epochs in which
    # validation AUC did not beat the best value by more than ``min_delta``.
    patience: int = 8
    min_delta: float = 1e-4

    # Optimization: learning rate and weight decay passed to AdamW.
    lr: float = 1e-3
    weight_decay: float = 1e-4

    # Artifacts: output directory and the file names written inside it.
    artifacts_dir: str = "artifacts"
    preprocess_file: str = "preprocess.joblib"
    checkpoint_file: str = "checkpoint_best.pt"
    final_weights_file: str = "model_final.pt"


# Module-wide configuration instance, built from the defaults above.
CFG = Config()
# Device string used for the model and for every batch: GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# =========================
# 1) Utils
# =========================
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def ensure_dir(path: str) -> None:
    """Create ``path`` and any missing parent directories.

    An already existing directory is left untouched. An empty ``path`` is
    treated as "the current working directory" (the same meaning it has in
    ``os.path.join``), so nothing is created instead of letting
    ``os.makedirs("")`` raise ``FileNotFoundError``.

    Args:
        path: Directory to create; may be empty.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)


def to_dense(x):
    """Return ``x`` as a dense array.

    Objects exposing a ``toarray`` method (e.g. sparse matrices) are converted
    through it; anything else goes through ``np.asarray``.
    """
    if hasattr(x, "toarray"):
        return x.toarray()
    return np.asarray(x)


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, y_true: np.ndarray, threshold: float = 0.5) -> Dict:
    """Score ``model`` on ``loader`` and compare against ``y_true``.

    The model is switched to eval mode, each batch is moved to ``DEVICE``,
    and the logits are passed through a sigmoid to obtain probabilities.
    Labels yielded by the loader are ignored; ``y_true`` is used instead, so
    it must be in the same order as the loader's samples.

    Args:
        model: Network returning one logit per sample.
        loader: Yields ``(features, labels)`` batches.
        y_true: Ground-truth labels aligned with the loader's order.
        threshold: Probability at or above which a sample is predicted 1.

    Returns:
        Dict with ``"acc"``, ``"f1"``, ``"auc"`` (floats) and ``"cm"``
        (the confusion matrix).
    """
    model.eval()

    batch_probs = [
        torch.sigmoid(model(features.to(DEVICE))).detach().cpu().numpy()
        for features, _ in loader
    ]
    probs = np.concatenate(batch_probs, axis=0)
    pred = (probs >= threshold).astype(int)

    accuracy = float(accuracy_score(y_true, pred))
    f1 = float(f1_score(y_true, pred))
    # AUC is threshold-free, so it is computed from the raw probabilities.
    auc = float(roc_auc_score(y_true, probs))
    matrix = confusion_matrix(y_true, pred)

    return {"acc": accuracy, "f1": f1, "auc": auc, "cm": matrix}


def save_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer,
                    epoch: int, best_auc: float, patience_counter: int) -> None:
    """Atomically write a training checkpoint to ``path``.

    The payload holds the model and optimizer state dicts plus the epoch,
    best validation AUC and early-stopping counter. It is first written to
    ``path + ".tmp"`` and then moved over ``path`` with ``os.replace``, so an
    interrupted save cannot leave a truncated file in place of a previously
    good checkpoint.

    Args:
        path: Destination file.
        model: Model whose ``state_dict()`` is stored under ``"model_state"``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``"optimizer_state"``.
        epoch: Epoch number the checkpoint belongs to.
        best_auc: Best validation AUC seen so far.
        patience_counter: Current early-stopping counter.
    """
    payload = {
        # Cast to built-in types so NumPy scalars never end up in the pickle.
        "epoch": int(epoch),
        "best_auc": float(best_auc),
        "patience_counter": int(patience_counter),
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }

    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present when saving or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer = None):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint file.

    Args:
        path: Checkpoint file to read.
        model: Receives the weights stored under ``"model_state"``.
        optimizer: When given and the checkpoint contains
            ``"optimizer_state"``, its state is restored as well.

    Returns:
        The full checkpoint dict as loaded from disk.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model_state"])

    if optimizer is None or "optimizer_state" not in state:
        return state

    optimizer.load_state_dict(state["optimizer_state"])
    return state


# =========================
# 2) Data
# =========================
class NPDataset(Dataset):
    """In-memory dataset pairing a feature array with a label array.

    Both arrays are copied to float32 and wrapped as tensors once, at
    construction time; indexing returns ``(features, label)`` tensors.
    """

    def __init__(self, X_np: np.ndarray, y_np: np.ndarray):
        features, labels = (
            torch.from_numpy(arr.astype(np.float32)) for arr in (X_np, y_np)
        )
        self.X = features
        self.y = labels

    def __len__(self):
        # Number of samples is the size of the leading feature dimension.
        return self.X.size(0)

    def __getitem__(self, i):
        sample = self.X[i]
        label = self.y[i]
        return sample, label


def build_preprocess(X: pd.DataFrame) -> ColumnTransformer:
    """Build an (unfitted) preprocessing transformer for the columns of ``X``.

    Columns with ``object`` dtype are treated as categorical: missing values
    are filled with the most frequent value, then one-hot encoded (unknown
    categories at transform time are ignored). All remaining columns are
    treated as numeric: median-imputed, then standardized.

    Args:
        X: Frame whose dtypes decide the numeric/categorical column split.

    Returns:
        A ``ColumnTransformer`` with a ``"num"`` and a ``"cat"`` branch;
        columns belonging to neither are dropped.
    """
    cat_cols = X.select_dtypes(include=["object"]).columns.tolist()
    num_cols = [name for name in X.columns if name not in cat_cols]

    numeric_steps = [
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ]
    categorical_steps = [
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ]

    branches = [
        ("num", Pipeline(numeric_steps), num_cols),
        ("cat", Pipeline(categorical_steps), cat_cols),
    ]
    return ColumnTransformer(transformers=branches, remainder="drop")


def load_and_prepare(cfg: Config):
    """Load the spreadsheet, clean it, split it 70/15/15 and build loaders.

    Steps:
      1. Read ``cfg.data_path`` with ``pd.read_excel``.
      2. Turn the ``99999999`` placeholder in "Current Loan Amount" into NaN.
      3. Drop ``cfg.drop_cols`` (missing ones are ignored).
      4. Stratified split into train / validation / test.
      5. Fit the preprocessing transformer on the training split only and
         apply it to all three splits.

    Args:
        cfg: Supplies the data path, target column, columns to drop, seed
            and batch size.

    Returns:
        ``(preprocess, train_pack, val_pack, test_pack)`` where each pack is
        ``(X_float32_array, y_int_array, DataLoader)``.

    Raises:
        ValueError: If ``cfg.target_col`` is not a column of the data.
    """
    df = pd.read_excel(cfg.data_path)

    # Common placeholder outlier -> missing value. ``mask`` returns a new
    # (upcast if needed) column, avoiding the in-place NaN assignment into an
    # integer column that pandas has deprecated.
    loan_col = "Current Loan Amount"
    if loan_col in df.columns:
        df[loan_col] = df[loan_col].mask(df[loan_col] == 99999999)

    df = df.drop(columns=list(cfg.drop_cols), errors="ignore")

    if cfg.target_col not in df.columns:
        raise ValueError(f"Target column '{cfg.target_col}' not found. Columns: {list(df.columns)}")

    y = df[cfg.target_col].astype(int).values
    X = df.drop(columns=[cfg.target_col])

    # Stratified 70/15/15 split: 30% is held out, then halved into val/test.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=0.30, random_state=cfg.seed, stratify=y
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.50, random_state=cfg.seed, stratify=y_temp
    )

    # Fit on the training split only so no validation/test statistics leak in.
    preprocess = build_preprocess(X_train)
    X_train_t = to_dense(preprocess.fit_transform(X_train)).astype(np.float32)
    X_val_t = to_dense(preprocess.transform(X_val)).astype(np.float32)
    X_test_t = to_dense(preprocess.transform(X_test)).astype(np.float32)

    # BatchNorm1d cannot run in training mode on a single-sample batch, so
    # the trailing training batch is dropped only when it would hold exactly
    # one row (and at least one full batch remains).
    n_train = len(X_train_t)
    drop_last = n_train > cfg.batch_size and n_train % cfg.batch_size == 1

    train_loader = DataLoader(
        NPDataset(X_train_t, y_train),
        batch_size=cfg.batch_size,
        shuffle=True,
        drop_last=drop_last,
    )
    val_loader = DataLoader(NPDataset(X_val_t, y_val), batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(NPDataset(X_test_t, y_test), batch_size=cfg.batch_size, shuffle=False)

    return (
        preprocess,
        (X_train_t, y_train, train_loader),
        (X_val_t, y_val, val_loader),
        (X_test_t, y_test, test_loader),
    )


# =========================
# 3) Model
# =========================
class MLP(nn.Module):
    """Feed-forward binary classifier producing one logit per sample.

    Three hidden stages (256 -> 128 -> 64 units), each Linear -> [BatchNorm]
    -> ReLU -> Dropout, followed by a single-unit output layer. Batch norm is
    used on the first two stages only.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        # (width, use batch norm, dropout probability) for each hidden stage.
        hidden_spec = (
            (256, True, 0.25),
            (128, True, 0.20),
            (64, False, 0.10),
        )

        layers = []
        in_features = input_dim
        for width, use_bn, drop_p in hidden_spec:
            layers.append(nn.Linear(in_features, width))
            if use_bn:
                layers.append(nn.BatchNorm1d(width))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(drop_p))
            in_features = width
        layers.append(nn.Linear(in_features, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Drop the trailing unit dimension: (batch, 1) -> (batch,).
        logits = self.net(x)
        return logits.squeeze(1)


def make_loss(y_train: np.ndarray) -> nn.Module:
    """Build a class-weighted binary cross-entropy-with-logits loss.

    The positive class is weighted by ``#negatives / #positives`` counted in
    ``y_train``; the positive count is floored at 1 to avoid dividing by zero.

    Args:
        y_train: Training labels containing 0s and 1s.

    Returns:
        ``nn.BCEWithLogitsLoss`` whose ``pos_weight`` lives on ``DEVICE``.
    """
    n_pos = (y_train == 1).sum()
    n_neg = (y_train == 0).sum()
    ratio = n_neg / max(n_pos, 1)
    weight = torch.tensor([ratio], dtype=torch.float32, device=DEVICE)
    return nn.BCEWithLogitsLoss(pos_weight=weight)


# =========================
# 4) Train loop with early stopping + checkpoint
# =========================
def train_with_early_stopping(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    y_val: np.ndarray,
    max_epochs: int,
    patience: int,
    min_delta: float,
    ckpt_path: str,
    start_epoch: int = 0,
    best_auc_init: float = -1.0,
    patience_counter_init: int = 0,
) -> Dict:
    """Train ``model`` with early stopping on validation AUC.

    Runs up to ``max_epochs`` epochs numbered from ``start_epoch + 1``. After
    each epoch the model is evaluated on the validation set; whenever the AUC
    beats the best value so far by more than ``min_delta`` a checkpoint is
    written to ``ckpt_path`` and the patience counter resets. Training stops
    once ``patience`` consecutive epochs bring no such improvement.

    Args:
        model: Network to train (already on ``DEVICE``).
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss applied to ``(logits, targets)``.
        train_loader: Yields training batches.
        val_loader: Yields validation batches.
        y_val: Validation labels aligned with ``val_loader``.
        max_epochs: Maximum number of epochs to run in this call.
        patience: Allowed number of consecutive non-improving epochs.
        min_delta: Minimum AUC gain that counts as an improvement.
        ckpt_path: Where the best checkpoint is saved.
        start_epoch: Epoch count already completed (for resumed runs).
        best_auc_init: Best validation AUC carried over from earlier runs.
        patience_counter_init: Patience counter carried over.

    Returns:
        Dict with the final ``"best_auc"`` and ``"patience_counter"``.
    """
    best_auc = best_auc_init
    patience_counter = patience_counter_init

    first_epoch = start_epoch + 1
    stop_epoch = first_epoch + max_epochs  # exclusive upper bound

    for epoch in range(first_epoch, stop_epoch):
        model.train()
        losses = []

        for features, targets in train_loader:
            features, targets = features.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(model(features), targets)
            loss.backward()
            # Cap the gradient norm to keep single updates bounded.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            losses.append(loss.item())

        val_metrics = evaluate(model, val_loader, y_val, threshold=0.5)
        tr_loss = float(np.mean(losses))

        print(
            f"Epoch {epoch:03d} | loss={tr_loss:.4f} | "
            f"val_auc={val_metrics['auc']:.4f} val_f1={val_metrics['f1']:.4f} val_acc={val_metrics['acc']:.4f}"
        )

        if val_metrics["auc"] > (best_auc + min_delta):
            # New best: remember it, reset patience and persist the weights.
            best_auc = val_metrics["auc"]
            patience_counter = 0
            save_checkpoint(ckpt_path, model, optimizer, epoch, best_auc, patience_counter)
            continue

        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}. Best val AUC = {best_auc:.4f}")
            break

    return {"best_auc": best_auc, "patience_counter": patience_counter}


def main():
    """Run the full two-stage training pipeline and evaluate on the test set.

    Stage 1 trains a fresh ``MLP`` with early stopping and keeps the best
    checkpoint. Stage 2 reloads that checkpoint (model and a new optimizer)
    and continues training with a reset patience counter. Finally the best
    checkpoint is loaded again, test metrics are printed, and the weights are
    saved to the final weights file. The fitted preprocessing transformer is
    saved with joblib before training starts.
    """
    print(f"Using device: {DEVICE}")
    set_seed(CFG.seed)
    ensure_dir(CFG.artifacts_dir)

    preprocess, train_pack, val_pack, test_pack = load_and_prepare(CFG)
    _, y_train, train_loader = train_pack
    _, y_val, val_loader = val_pack
    _, y_test, test_loader = test_pack

    # Save preprocess (must!): inference needs the exact fitted transformer.
    preprocess_path = os.path.join(CFG.artifacts_dir, CFG.preprocess_file)
    joblib.dump(preprocess, preprocess_path)
    print(f"Saved preprocess: {preprocess_path}")

    # Input width is the number of columns produced by the preprocessing.
    input_dim = train_pack[0].shape[1]
    model = MLP(input_dim).to(DEVICE)

    criterion = make_loss(y_train)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

    ckpt_path = os.path.join(CFG.artifacts_dir, CFG.checkpoint_file)

    # -------------------------
    # Stage 1: Train then save best checkpoint
    # -------------------------
    print("\n===== Stage 1: Train & Save Weights =====")
    stage1_state = train_with_early_stopping(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        y_val=y_val,
        max_epochs=CFG.stage1_max_epochs,
        patience=CFG.patience,
        min_delta=CFG.min_delta,
        ckpt_path=ckpt_path,
        start_epoch=0,
        best_auc_init=-1.0,
        patience_counter_init=0,
    )

    # -------------------------
    # Stage 2: Load weights and continue training up to 50 epochs with early stopping
    # -------------------------
    print("\n===== Stage 2: Resume from Checkpoint & Continue 50 Epochs (Early Stop) =====")
    # Build a fresh optimizer (the old one could also be reused), then restore it from the checkpoint.
    optimizer2 = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

    if os.path.exists(ckpt_path):
        ckpt = load_checkpoint(ckpt_path, model, optimizer2)
        start_epoch = int(ckpt.get("epoch", 0))
        best_auc = float(ckpt.get("best_auc", -1.0))
        # Stage 2 restarts the early-stopping count, which matches
        # "continue for up to 50 more epochs with early stopping".
        patience_counter = 0
        print(f"Loaded checkpoint from epoch={start_epoch}, best_val_auc={best_auc:.4f}")
    else:
        # No checkpoint on disk: keep training from the current in-memory weights.
        start_epoch = 0
        best_auc = stage1_state["best_auc"]
        patience_counter = 0
        print("Checkpoint not found, continuing from current weights.")

    _ = train_with_early_stopping(
        model=model,
        optimizer=optimizer2,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        y_val=y_val,
        max_epochs=CFG.stage2_max_epochs,   # key point: at most stage2_max_epochs (50 by default) more epochs
        patience=CFG.patience,
        min_delta=CFG.min_delta,
        ckpt_path=ckpt_path,               # keeps overwriting the best checkpoint
        start_epoch=start_epoch,
        best_auc_init=best_auc,
        patience_counter_init=patience_counter,
    )

    # Final evaluation: reload the best checkpoint before testing, so that
    # weights degraded by the last few epochs are not the ones evaluated.
    if os.path.exists(ckpt_path):
        _ = load_checkpoint(ckpt_path, model, optimizer=None)

    test_metrics = evaluate(model, test_loader, y_test, threshold=0.5)
    print("\n=== Test Metrics (Best Checkpoint) ===")
    print(f"Accuracy: {test_metrics['acc']:.4f}")
    print(f"F1      : {test_metrics['f1']:.4f}")
    print(f"ROC-AUC  : {test_metrics['auc']:.4f}")
    print("Confusion Matrix:\n", test_metrics["cm"])

    # Save the final weights (optional: the best checkpoint alone is already sufficient).
    final_path = os.path.join(CFG.artifacts_dir, CFG.final_weights_file)
    torch.save(model.state_dict(), final_path)
    print(f"\nSaved final weights: {final_path}")
    print(f"Best checkpoint: {ckpt_path}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
相关推荐
秦ぅ时2 分钟前
保姆级教程|OpenAI tts-1-hd模型调用全流程(Python+curl+懒人用法)
开发语言·python
Muyuan19984 分钟前
25.Paper RAG Agent 优化记录:上传反馈、计算器安全与 Chunk 参数调整
python·安全·django·sqlite·fastapi
Java面试题总结17 分钟前
使用 Python 设置 Excel 数据验证
开发语言·python·excel
小郑加油26 分钟前
python学习Day10天:列表进阶 + 内置函数 + 代码简化
开发语言·python·学习
数据智能老司机26 分钟前
学习 AutoML——理解 AutoML 流水线
机器学习
时空系1 小时前
第13篇:综合实战——制作我的小游戏 python中文编程
开发语言·python·ai编程
Li emily1 小时前
港股api接入指南:实时行情与历史数据获取
python·api·fastapi
AI技术增长1 小时前
Pytorch图像去噪实战(十三):DDIM加速扩散模型采样,让去噪从1000步降到50步
人工智能·pytorch·python
刀法如飞1 小时前
Python列表去重:从新手三连到高阶特技,20种解法全收录
python·算法·编程语言
小糖学代码1 小时前
LLM系列:1.python入门:16.正则表达式与文本处理 (re)
人工智能·pytorch·python·深度学习·神经网络·正则表达式