【深度学习】多分类任务评估指标sklearn和torchmetrics对比

【深度学习】多分类任务评估指标sklearn和torchmetrics对比

  • 说明
  • sklearn代码
  • torchmetrics代码
  • 两个MultiClassReport类的对比分析
    • [1. 代码结构与实现方式](#1. 代码结构与实现方式)
    • [2. 数据处理与内存使用](#2. 数据处理与内存使用)
    • [3. 性能与效率](#3. 性能与效率)
  • 二分类任务评估指标
    • [1. 准确率(Accuracy)](#1. 准确率(Accuracy))
    • [2. 精确率(Precision)](#2. 精确率(Precision))
    • [3. 召回率(Recall)](#3. 召回率(Recall))
    • [4. F1值(F1-score)](#4. F1值(F1-score))
  • 多分类评估指标
    • [1. 混淆矩阵(Confusion Matrix)](#1. 混淆矩阵(Confusion Matrix))
    • [2. 准确率(Accuracy)](#2. 准确率(Accuracy))
    • [3. 精确率(Precision)](#3. 精确率(Precision))
    • [4. 召回率(Recall)](#4. 召回率(Recall))
    • [5. F1值(宏平均)](#5. F1值(宏平均))

说明

sklearn和torchmetrics两个metric代码跑模型的输出结果一致,对比他们的区别。评估指标写在下面

sklearn代码

python 复制代码
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

class MultiClassReport():
    """
    Accuracy, F1 Score, Precision and Recall for multi - class classification task.
    """

    def __init__(self, name='MultiClassReport', average='macro'):
        super(MultiClassReport, self).__init__()
        self.average = average
        self._name = name
        self.reset()

    def reset(self):
        """
        Resets all the metric state.
        """
        self.y_prob = []
        self.y_true = []

    def update(self, probs, labels):
        # 将Tensor转换为numpy数组并添加到相应列表中
        if isinstance(probs, torch.Tensor):
            if probs.requires_grad:
                probs = probs.detach()
            probs = probs.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            if labels.requires_grad:
                labels = labels.detach()
            labels = labels.cpu().numpy()
        self.y_prob.extend(probs)
        self.y_true.extend(labels)
        self.y_prob.extend(probs)
        self.y_true.extend(labels)

    def accumulate(self):
        accuracy = accuracy_score(self.y_true, np.argmax(self.y_prob, axis=1))
        f1 = f1_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        precision = precision_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        recall = recall_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        return accuracy, f1, precision, recall

    def name(self):
        """
        Returns metric name
        """
        return self._name

torchmetrics代码

python 复制代码
from torchmetrics import Accuracy, F1Score, Precision, Recall
from model import polarity_classes, device

# 创建评估指标对象
accuracy_metric = Accuracy(task='multiclass', num_classes=polarity_classes).to(device)
f1_metric = F1Score(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
precision_metric = Precision(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
recall_metric = Recall(task='multiclass', num_classes=polarity_classes, average='macro').to(device)

class MultiClassReport():
    """
    Accuracy, F1 Score, Precision and Recall for multi-class classification task.
    average:micro、macro
    """

    def __init__(self, name='MultiClassReport', average='macro'):
        super(MultiClassReport, self).__init__()
        self.average = average
        self._name = name

    def reset(self):
        """
        Resets all the metric state.
        """
        accuracy_metric.reset()
        f1_metric.reset()
        precision_metric.reset()
        recall_metric.reset()

    def update(self, probs, labels):
        accuracy_metric.update(probs, labels)
        f1_metric.update(probs, labels)
        precision_metric.update(probs, labels)
        recall_metric.update(probs, labels)

    def accumulate(self):
        accuracy = accuracy_metric.compute()
        f1 = f1_metric.compute()
        precision = precision_metric.compute()
        recall = recall_metric.compute()
        return accuracy, f1, precision, recall

    def name(self):
        """
        Returns metric name
        """
        return self._name

两个MultiClassReport类的对比分析

1. 代码结构与实现方式

  • sklearn版本
    • 代码逻辑较为清晰直接。在update方法中,将输入的PyTorch张量转换为numpy数组,并存储到y_proby_true列表中。在accumulate方法中,直接使用sklearnaccuracy_scoref1_scoreprecision_scorerecall_score函数基于存储的列表数据计算评估指标。
  • torchmetrics版本
    • 利用torchmetrics库提供的专门的评估指标类(AccuracyF1ScorePrecisionRecall)。在update方法中,直接调用这些类的update方法来处理输入数据,内部有自己的状态管理机制。在accumulate方法中,通过调用相应类的compute方法获取评估指标值。
    • 这种方式与PyTorch的生态系统集成得更好,尤其是在基于PyTorch进行深度学习项目开发时,可以方便地在GPU上进行计算(如果deviceGPU),并且可以利用torchmetrics库的其他特性,如分布式训练支持等。

2. 数据处理与内存使用

  • sklearn版本
    • update方法中不断扩展y_proby_true列表来存储数据。如果处理大量数据,可能会占用较多内存,因为它需要将所有的预测概率和真实标签都保存在内存中。
    • 每次计算评估指标时,都需要对整个存储的数组进行操作,如np.argmax等,这在数据量较大时可能会有一定的计算开销。
  • torchmetrics版本
    • 虽然torchmetrics类内部也需要存储一定的状态信息,但它们的设计可能更高效地利用内存和处理数据更新。例如,它们可能会采用增量计算的方式,而不是像sklearn版本那样一次性处理所有数据。
    • 在处理大规模数据或长时间训练过程中,torchmetrics版本可能在内存管理和计算效率方面更有优势。

3. 性能与效率

  • sklearn版本
    • 在小规模数据和简单场景下,性能表现良好。但随着数据量的增加和模型复杂度的提高,由于数据转换和计算方式的原因,可能会出现性能瓶颈。
  • torchmetrics版本
    • 设计初衷就是为了在PyTorch深度学习环境中高效运行,特别是在利用GPU计算资源时,能够更高效地更新和计算评估指标,更适合大规模数据和复杂模型的评估场景。

二分类任务评估指标

TP(True Positive)是真正例,TN(True Negative)是真反例,FP(False Positive)是假正例,FN(False Negative)是假反例。

1. 准确率(Accuracy)

准确率是指在所有预测样本中,预测正确的样本所占的比例。它衡量的是模型整体预测正确的程度。

2. 精确率(Precision)

精确率是指在所有被预测为正类的样本中,真正为正类的样本所占的比例。

3. 召回率(Recall)

召回率是指在所有实际为正类的样本中,被模型正确预测为正类的样本所占的比例。

4. F1值(F1-score)

F1值是精确率和召回率的调和平均数,它综合考虑了精确率和召回率两个指标,能够更全面地评估模型的性能。

多分类评估指标

1. 混淆矩阵(Confusion Matrix)

它是一个方阵,用来展示分类模型在每个类别上的预测对错情况。行代表真实类别,列代表预测类别,某个位置的值就是实际是某类却被预测成另一类的样本数量,能直观呈现模型对各类别预测的混淆情况。

2. 准确率(Accuracy)

就是模型预测正确的样本数占总样本数的比例,反映整体预测正确程度。

3. 精确率(Precision)

  • 类别精确率 :对于每个类别,是预测为该类且正确的样本数除以预测为该类的样本数,看预测某类时的准确程度。

  • 宏平均精确率(Macro-average Precision) :先算出每个类别的精确率,再求平均,平等看待每个类别。

  • 微平均精确率(Micro-average Precision) :将所有类别预测对的情况汇总除以预测的总数,从整体上看预测的精准情况,对类别不平衡不太敏感。

4. 召回率(Recall)

  • 类别召回率 :对于每个类别,是预测为该类且正确的样本数除以实际是该类的样本数,体现对该类样本的召回能力。

  • 宏平均召回率(Macro- average Recall) :先算出每个类别的召回率,再求平均,衡量对每个类别样本的召回水平。

  • 微平均召回率(Micro-average Recall) :从整体角度,用所有类别预测正确的样本总数除以实际各类别样本总数,综合评估召回情况。

5. F1值(宏平均)

  • 类别F1值 :是类别精确率和召回率的调和平均数,综合二者信息。

  • 宏平均F1值(Macro-average F1) :先计算每个类别的F1值,再平均,更全面地体现模型对各分类的整体性能。

  • 微平均F1值(Micro-average F1) :基于微平均精确率(Precision micro)和微平均召回率(Recall micro)来计算

    微平均

相关推荐
学习BigData4 小时前
【使用PyQt5和YOLOv11开发电脑屏幕区域的实时分类GUI】——选择检测区域
qt·yolo·分类
Leweslyh5 小时前
物理信息神经网络(PINN)八课时教案
人工智能·深度学习·神经网络·物理信息神经网络
love you joyfully5 小时前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
大多_C5 小时前
BERT outputs
人工智能·深度学习·bert
知恩呐1117 小时前
seed_everything 函数
人工智能·深度学习
【建模先锋】8 小时前
故障诊断 | 一个小创新:特征提取+KAN分类
人工智能·分类·数据挖掘
卓琢9 小时前
2024 年 IA 技术大爆发深度解析
深度学习·ai·论文笔记
不如语冰9 小时前
深度学习Python基础(2)
人工智能·python·深度学习·语言模型
七夜星七夜月9 小时前
时间序列预测论文阅读和相关代码库
论文阅读·python·深度学习
红色的山茶花10 小时前
YOLOv9-0.1部分代码阅读笔记-dataloaders.py
笔记·深度学习·yolo