神经网络入门实战:(十八)Argmax函数的详细介绍,可以用来计算模型训练准确率

Argmax函数介绍

在Python中,argmax 函数通常用于找出给定数组或列表中元素值最大的索引。

(一) Numpy 中的 Argmax 函数:

numpy.argmax 函数用于找出给定轴(axis)上最大值所在的索引。

示例:

python 复制代码
import numpy as np
# 一维数组
arr = np.array([1, 3, 2, 5, 4])
index = np.argmax(arr)
print(index)  

# 二维数组
arr_2d = np.array([[1, 3, 2], [5, 4, 6]])
index_axis_0 = np.argmax(arr_2d, axis=0) # 沿着 列(axis=0) 找出最大值的索引
index_axis_1 = np.argmax(arr_2d, axis=1) # 沿着 行(axis=1) 找出最大值的索引
print(index_axis_0)  # 输出: [1 1 1]
print(index_axis_1)  # 输出: [1 2]
-------------------------------------------------------------------------------------------------------
# 运行结果:
3
[1 1 1]
[1 2]

(二) pytorch 中的 Argmax 函数:

  • 功能:torch.argmax 函数在 PyTorch 中可以用于找到给定张量(Tensor)中每一行或每一列的最大值的索引。

  • 用法:对于一个给定的张量,argmax 会返回沿着指定维度的最大值的索引。如果不指定维度,默认是在最后一个维度上操作

    列就是第一个维度,行就是第二个维度,第三个维度可以理解成厚度,以此类推...

  • 代码:

    python 复制代码
    torch.argmax(input, dim=None, keepdim=False)
    • input:输入的张量。

    • dim:指定在哪个维度上寻找最大值,0为一维,1为二维...如果为 None,则默认在最后一个维度上操作。

    • keepdim:如果为 True,则输出的张量将保持与输入张量相同的维度数,否则,输出张量将减少一个维度,默认是 False

    • 输出结果是 int 型,64位的电脑,输出默认就是 int64 型。

    • 简洁写法:

      python 复制代码
      # 假设有一个多维张量 some_tensor
      output = some_tensor.argmax(dim=1) # 逐行找出最大值

    注意事项:

    1. 数据类型argmax 返回的是 int 类型(64位电脑就是 int64 )的张量,因为索引值必须是整数。
    2. 多维张量dim 参数指定了在哪个维度上寻找最大值。如果 dim 为负值,则从最后一个维度开始计数(例如,dim=-1 等于默认行为)。
    3. 相同最大值 :如果有多个相同的最大值,argmax 将返回 ++第一个++ 出现的最大值的索引。
    4. NaN 值 :如果张量中包含 NaN 值,argmax 会忽略这些值,并在剩余的有效值中寻找最大值。
  • 示例:

    python 复制代码
    import torch
    
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    result1 = torch.argmax(x, dim=0) # 在第一个维度(列)上寻找最大值的索引
    print(result1)
    result2 = torch.argmax(x, dim=1) # 在最后一个维度(行)上寻找最大值的索引
    print(result2)
    result3 = torch.argmax(x, dim=1, keepdim=True) # 指定keepdim=True
    print(result3) 
    -------------------------------------------------------------------------------------------------------
    # 运行结果:
    tensor([1, 1, 1])
    tensor([2, 2])
    tensor([[2],
            [2]])
  • 高级应用:

    可以在神经网络训练中,计算准确率。

    假设有一个 n 分类模型,输出结果为 n 个概率,分别对应输入图片被识别成这 n 个类别的概率(∈[0,1]);

    n=3 (这 3 种类别分别用 0、1、2 表示),现在有两张输入图片,输出为:

    python 复制代码
    outputs=[[0.1, 0.2, 0.8],
            [0.4, 0.6, 0.2]]

    假设两张图片分别为类别 0 和类别 2 (即 target=[2,2])。对 outputs 使用 torch.argmax 函数:

    python 复制代码
    result = torch.argmax(outputs, dim=1) # 在每行上寻找最大值的索引,因为 keepdim 默认为 False ,故输出是一维张量 
    # 输出结果为:
    tensor([2,1])

    因为 target=[2,2] ,而训练结果为 result=[2,1] ,所以第一个图片识别正确,第二个图片识别错误。

    正确率的计算 :

    python 复制代码
    # 前提是 target 和 result 都是张量
    print(target == result) # 直接比较
    A = (target == result).sum() #True默认为1,False默认位0
    print(A)
    accuracy = (A.item() / len(result)) * 100
    accuracy = roun
    d(accuracy, 2) # 2表示保留两位小数(四舍五入)
    print(f"正确率为:{accuracy}%")
    ---------------------------------------------------------------------------------------
    # 运行结果
    tensor([True,False])
    tensor(1)
    正确率为:50%

上一篇 下一篇
神经网络入门实战(十七) 待发布
相关推荐
刀客12310 分钟前
python3+TensorFlow 2.x(四)反向传播
人工智能·python·tensorflow
SpikeKing16 分钟前
LLM - 大模型 ScallingLaws 的设计 100B 预训练方案(PLM) 教程(5)
人工智能·llm·预训练·scalinglaws·100b·deepnorm·egs
小枫@码40 分钟前
免费GPU算力,不花钱部署DeepSeek-R1
人工智能·语言模型
liruiqiang0541 分钟前
机器学习 - 初学者需要弄懂的一些线性代数的概念
人工智能·线性代数·机器学习·线性回归
Icomi_1 小时前
【外文原版书阅读】《机器学习前置知识》1.线性代数的重要性,初识向量以及向量加法
c语言·c++·人工智能·深度学习·神经网络·机器学习·计算机视觉
微学AI1 小时前
GPU算力平台|在GPU算力平台部署可图大模型Kolors的应用实战教程
人工智能·大模型·llm·gpu算力
西猫雷婶1 小时前
python学opencv|读取图像(四十六)使用cv2.bitwise_or()函数实现图像按位或运算
人工智能·opencv·计算机视觉
IT古董1 小时前
【深度学习】常见模型-生成对抗网络(Generative Adversarial Network, GAN)
人工智能·深度学习·生成对抗网络
Jackilina_Stone1 小时前
【论文阅读笔记】“万字”关于深度学习的图像和视频阴影检测、去除和生成的综述笔记 | 2024.9.3
论文阅读·人工智能·笔记·深度学习·ai
梦云澜1 小时前
论文阅读(三):微阵列数据的图形模型和多变量分析
论文阅读·深度学习