语义分割训练精度计算

语义分割训练的output结果一般是[batch_size, num_classes, width, height]这样的形式,而label的结果一般是[batch_size, width, height],类似如下形状,outputs:[4,6,480,320],而真值label:[4,480,320]。由于维度不同,无法直接比较,所以这两者要比较就要采取一点方法。

output里面每个类型都有一个值,要取最大的值作为得到的类别结果,所以要用到torch.max()函数。

python 复制代码
output = torch.max(input, dim)
# input是softmax函数输出的一个tensor
# dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值

dim就是维度,我这里应该取1

该函数的输出是:

函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。

一般我们不需要tensor每行的最大值,而需要的是索引,也就是结果是哪一类。因为这里的num_classes维度为1,所以dim=1,也就是取出这个维度上的最大值。

python 复制代码
torch.max(output, dim=1)

其次,我们需要比较output和label的值,那就是说对长*宽的所有像素,要比较类别,如果类别一致就加到正确结果中,用正确结果的数量去除以所有像素的总数量,就是精度。这里有两个问题需要解决:

  1. 计算两个tensor之间有多少值相等

在PyTorch中,要计算两个tensor之间有多少值相等,可以使用torch.eq()函数来生成一个布尔型tensor,其中每个元素表示对应位置的元素是否相等(相等为True,不相等为False)。然后,可以使用torch.sum()函数来计算这个布尔型tensor中True的总数,即有多少值相等。但是,需要注意的是,torch.sum()默认计算的是所有元素的和,对于布尔型tensor,True会被当作1处理,False会被当作0处理。

举例如下:

python 复制代码
import torch  
  
# 定义两个tensor  
tensor1 = torch.tensor([1, 2, 3, 4, 5])  
tensor2 = torch.tensor([1, 3, 3, 4, 6])  
  
# 使用torch.eq()比较两个tensor  
equal_mask = torch.eq(tensor1, tensor2)  
  
# 计算相等的值的数量  
equal_count = torch.sum(equal_mask)  
  
print("相等的值的数量:", equal_count)
# 输出:相等的值的数量: tensor(3)

在这个例子中,tensor1和tensor2在位置0、2、3的值是相等的(即1,3和4),所以输出会是相等的值的数量: tensor(3, dtype=torch.int64),表示有3个值相等。

注意,torch.sum()的默认数据类型是torch.int64,但你可以通过指定dtype参数来改变结果的数据类型,例如torch.sum(equal_mask, dtype=torch.float32)。不过,在这个场景下,通常使用默认的torch.int64就足够了。

  1. 获取一个tensor的所有元素总数

在PyTorch中,要获取一个tensor的元素个数(即tensor的总大小),可以使用.numel()方法。这个方法会返回tensor中所有元素的数量,不论tensor的维度如何。

下面是一个简单的例子:

python 复制代码
import torch  
  
# 定义一个tensor  
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  
  
# 获取tensor的元素个数  
num_elements = tensor.numel()  
  
print("Tensor的元素个数:", num_elements)
# 输出:Tensor的元素个数: 6

在这个例子中,tensor是一个2x3的二维tensor,包含6个元素。因此,输出将会是Tensor的元素个数: 6。

.numel()方法非常适合于快速计算tensor中元素的总数,无需手动计算各维度的乘积。这是处理不同形状和大小的tensor时的一个非常有用的功能。

综上所述,完整的计算模型训练精度的函数如下:

python 复制代码
def get_acc(y_pred, y_true, num_classes):
    #print("y_pred.shape: ", y_pred.shape)
    #print("y_true.shape: ", y_true.shape)
    temp = torch.max(y_pred, dim=1)
    #print(temp[1])
    #print(temp[1].shape)
    # 比较y_true和temp[1]
    equal_mask = torch.eq(temp[1], y_true) # temp[1]就是返回的dim=1的最大值索引,也即是类别号
    equal_count = torch.sum(equal_mask)
    num_elements = y_true.numel()
    #print("equal_count: ", equal_count)
    #print("num_elements: ", num_elements)
    acc = equal_count/num_elements
    #print("acc: ", acc)
    return acc
相关推荐
不惑_1 天前
通俗理解经典CNN架构:VGGNet
人工智能·神经网络·cnn
没学上了1 天前
MNIST
人工智能
audyxiao0011 天前
人工智能顶级期刊PR论文解读|HCRT:基于相关性感知区域的混合网络,用于DCE-MRI图像中的乳腺肿瘤分割
网络·人工智能·智慧医疗·肿瘤分割
零售ERP菜鸟1 天前
IT价值证明:从“成本中心”到“增长引擎”的确定性度量
大数据·人工智能·职场和发展·创业创新·学习方法·业界资讯
叫我:松哥1 天前
基于大数据和深度学习的智能空气质量监测与预测平台,采用Spark数据预处理,利用TensorFlow构建LSTM深度学习模型
大数据·python·深度学习·机器学习·spark·flask·lstm
童话名剑1 天前
目标检测(吴恩达深度学习笔记)
人工智能·目标检测·滑动窗口·目标定位·yolo算法·特征点检测
木卫四科技1 天前
【木卫四 CES 2026】观察:融合智能体与联邦数据湖的安全数据运营成为趋势
人工智能·安全·汽车
珠海西格电力1 天前
零碳园区有哪些政策支持?
大数据·数据库·人工智能·物联网·能源
じ☆冷颜〃1 天前
黎曼几何驱动的算法与系统设计:理论、实践与跨领域应用
笔记·python·深度学习·网络协议·算法·机器学习
启途AI1 天前
2026免费好用的AIPPT工具榜:智能演示文稿制作新纪元
人工智能·powerpoint·ppt