深度学习训练八股

一、模型中的函数的定义

1.torchmetrics.AUROC

(1).binary
复制代码
>>> from torch import tensor
>>> preds = tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = tensor([0, 0, 1, 1, 1])
>>> auroc = AUROC(task="binary")
>>> auroc(preds, target)
tensor(0.5000)
(2).multiclass
复制代码
>>> preds = tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = tensor([0, 1, 1, 2, 2])
>>> auroc = AUROC(task="multiclass", num_classes=3)
>>> auroc(preds, target)
tensor(0.7778)

注意函数中average参数的默认值为"macro"。

二、test_k_fold_test_copy.py---.logs_k_fold/result_draw

复制代码
# Test script
def test(model, test_loader, writer, device,criterion,roc_path,fold):
    model.eval()
    accuracy = Accuracy(task='multiclass', num_classes=2).to(device)
    precision = Precision(task='multiclass', average='macro', num_classes=2).to(device)
    recall = Recall(task='multiclass', average='macro', num_classes=2).to(device)
    auroc = AUROC(task='multiclass',num_classes=2).to(device)
    f1 = F1Score(num_classes=2, task='multiclass', average='macro').to(device)
    specificity=Specificity(num_classes=2, task='multiclass', average='macro').to(device)

    pred_scores = [] 
    true_labels = []
    pred_labels = []
    

    fold_results={}
    with torch.no_grad():
        for images, coords, labels, _, _  in test_loader:
            images = images.to(device)
            labels = labels.to(device) 
            outputs = model(images,coords)
            _, predicted = torch.max(outputs.data, 1)
                    
            accuracy(predicted, labels.data)
            precision(predicted, labels.data)
            recall(predicted, labels.data)
            f1(predicted, labels.data)
            #auroc(predicted, labels.data)
            specificity(predicted, labels.data)
            auroc(outputs, labels.data)
            pred_labels.extend(predicted.cpu().numpy())
            pred_scores.extend(outputs.cpu().numpy()) 
            true_labels.extend(labels.cpu().numpy())

    acc = accuracy.compute().item() 
    prec = precision.compute().item() 
    rec = recall.compute().item() 
    f1_score = f1.compute().item()
    auroc_score = auroc.compute().item()
    spec=specificity.compute().item()

    fold_results['fold']=fold
    fold_results['accuracy'] = acc
    fold_results['precision'] = prec
    fold_results['recall'] = rec
    fold_results['f1_score'] = f1_score
    fold_results['auroc_score'] = auroc_score
    fold_results['specificity'] = spec
    
    logging.info(f"Test Accuracy: {acc:.4f}, Test precision: {prec:.4f}, Test recall: {rec:.4f}, Test f1: {f1_score:.4f}, Test auroc: {auroc_score:4f},Test specificity:{spec:.4f}")
    logging.error("This is a fatal log!")   
    
    roc = MulticlassROC(num_classes=2, thresholds=None)
    pred_scores = torch.Tensor(pred_scores).to(device)
    true_labels = torch.Tensor(true_labels).int().to(device)
    fpr, tpr, thresholds = roc(pred_scores, true_labels)
    
    draw_fold_path = Path(os.path.join(fprs_tprs_path, f'fold_{fold}'))
    draw_fold_path.mkdir(parents=True, exist_ok=True)
    torch.save(tpr,os.path.join(draw_fold_path,"tpr.pt"))
    torch.save(fpr,os.path.join(draw_fold_path,"fpr.pt"))
   
    return fold_results, fpr, tpr
相关推荐
leo__5208 小时前
基于A星算法的MATLAB路径规划实现
人工智能·算法·matlab
AAD555888998 小时前
基于YOLO11的自然景观多类别目标检测系统 山脉海洋湖泊森林建筑物桥梁道路农田沙漠海滩等多种景观元素检测识别
人工智能·目标检测·计算机视觉
数据分享者8 小时前
新闻文本智能识别数据集:40587条高质量标注数据推动自然语言处理技术发展-新闻信息提取、舆情分析、媒体内容理解-机器学习模型训练-智能分类系统
人工智能·自然语言处理·数据挖掘·easyui·新闻文本
___波子 Pro Max.8 小时前
LLM大语言模型定义与核心特征解析
人工智能·语言模型·自然语言处理
LDG_AGI8 小时前
【机器学习】深度学习推荐系统(三十):X 推荐算法Phoenix rerank机制
人工智能·分布式·深度学习·算法·机器学习·推荐算法
厦门小杨9 小时前
汽车内饰的面料究竟如何依靠AI验布机实现检测创新
大数据·人工智能·深度学习·汽车·制造·ai视觉验布机·纺织
devnullcoffee9 小时前
2026年Amazon Listing优化完全指南:COSMO算法与Rufus AI技术解析
人工智能·python·算法·亚马逊运营·amazon listing·cosmo算法·rufus ai技术
python机器学习ML9 小时前
机器学习——16种模型(基础+集成学习)+多角度SHAP高级可视化+Streamlit交互式应用+RFE特征选择+Optuna+完整项目
人工智能·python·机器学习·分类·数据挖掘·scikit-learn·集成学习
OLOLOadsd1239 小时前
激光设备目标检测 - 基于YOLOv5-HGNetV2的高精度检测模型实现_1
人工智能·yolo·目标检测
喜欢吃豆9 小时前
PostgreSQL 高维向量存储架构深度解析:架构限制、核心原理与行业解决方案
数据库·人工智能·postgresql·架构·2025博客之星