神经网络入门实战:(十八)Argmax函数的详细介绍,可以用来计算模型训练准确率

Argmax函数介绍

在Python中,argmax 函数通常用于找出给定数组或列表中元素值最大的索引。

(一) Numpy 中的 Argmax 函数:

numpy.argmax 函数用于找出给定轴(axis)上最大值所在的索引。

示例:

python 复制代码
import numpy as np
# 一维数组
arr = np.array([1, 3, 2, 5, 4])
index = np.argmax(arr)
print(index)  

# 二维数组
arr_2d = np.array([[1, 3, 2], [5, 4, 6]])
index_axis_0 = np.argmax(arr_2d, axis=0) # 沿着 列(axis=0) 找出最大值的索引
index_axis_1 = np.argmax(arr_2d, axis=1) # 沿着 行(axis=1) 找出最大值的索引
print(index_axis_0)  # 输出: [1 1 1]
print(index_axis_1)  # 输出: [1 2]
-------------------------------------------------------------------------------------------------------
# 运行结果:
3
[1 1 1]
[1 2]

(二) pytorch 中的 Argmax 函数:

  • 功能:torch.argmax 函数在 PyTorch 中可以用于找到给定张量(Tensor)中每一行或每一列的最大值的索引。

  • 用法:对于一个给定的张量,argmax 会返回沿着指定维度的最大值的索引。如果不指定维度,默认是在最后一个维度上操作

    列就是第一个维度,行就是第二个维度,第三个维度可以理解成厚度,以此类推...

  • 代码:

    python 复制代码
    torch.argmax(input, dim=None, keepdim=False)
    • input:输入的张量。

    • dim:指定在哪个维度上寻找最大值,0为一维,1为二维...如果为 None,则默认在最后一个维度上操作。

    • keepdim:如果为 True,则输出的张量将保持与输入张量相同的维度数,否则,输出张量将减少一个维度,默认是 False

    • 输出结果是 int 型,64位的电脑,输出默认就是 int64 型。

    • 简洁写法:

      python 复制代码
      # 假设有一个多维张量 some_tensor
      output = some_tensor.argmax(dim=1) # 逐行找出最大值

    注意事项:

    1. 数据类型argmax 返回的是 int 类型(64位电脑就是 int64 )的张量,因为索引值必须是整数。
    2. 多维张量dim 参数指定了在哪个维度上寻找最大值。如果 dim 为负值,则从最后一个维度开始计数(例如,dim=-1 等于默认行为)。
    3. 相同最大值 :如果有多个相同的最大值,argmax 将返回 ++第一个++ 出现的最大值的索引。
    4. NaN 值 :如果张量中包含 NaN 值,argmax 会忽略这些值,并在剩余的有效值中寻找最大值。
  • 示例:

    python 复制代码
    import torch
    
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    result1 = torch.argmax(x, dim=0) # 在第一个维度(列)上寻找最大值的索引
    print(result1)
    result2 = torch.argmax(x, dim=1) # 在最后一个维度(行)上寻找最大值的索引
    print(result2)
    result3 = torch.argmax(x, dim=1, keepdim=True) # 指定keepdim=True
    print(result3) 
    -------------------------------------------------------------------------------------------------------
    # 运行结果:
    tensor([1, 1, 1])
    tensor([2, 2])
    tensor([[2],
            [2]])
  • 高级应用:

    可以在神经网络训练中,计算准确率。

    假设有一个 n 分类模型,输出结果为 n 个概率,分别对应输入图片被识别成这 n 个类别的概率(∈[0,1]);

    n=3 (这 3 种类别分别用 0、1、2 表示),现在有两张输入图片,输出为:

    python 复制代码
    outputs=[[0.1, 0.2, 0.8],
            [0.4, 0.6, 0.2]]

    假设两张图片分别为类别 0 和类别 2 (即 target=[2,2])。对 outputs 使用 torch.argmax 函数:

    python 复制代码
    result = torch.argmax(outputs, dim=1) # 在每行上寻找最大值的索引,因为 keepdim 默认为 False ,故输出是一维张量 
    # 输出结果为:
    tensor([2,1])

    因为 target=[2,2] ,而训练结果为 result=[2,1] ,所以第一个图片识别正确,第二个图片识别错误。

    正确率的计算 :

    python 复制代码
    # 前提是 target 和 result 都是张量
    print(target == result) # 直接比较
    A = (target == result).sum() #True默认为1,False默认位0
    print(A)
    accuracy = (A.item() / len(result)) * 100
    accuracy = roun
    d(accuracy, 2) # 2表示保留两位小数(四舍五入)
    print(f"正确率为:{accuracy}%")
    ---------------------------------------------------------------------------------------
    # 运行结果
    tensor([True,False])
    tensor(1)
    正确率为:50%

上一篇 下一篇
神经网络入门实战(十七) 待发布
相关推荐
黄焖鸡能干四碗5 分钟前
信息系统安全保护措施文件方案
大数据·开发语言·人工智能·web安全·制造
hallo1288 分钟前
学习机器学习能看哪些书籍
人工智能·深度学习·机器学习
中國龍在廣州18 分钟前
哈工大提出空间机器人复合框架,突破高精度轨迹跟踪
人工智能·深度学习·机器学习·计算机视觉·机器人
cetcht888819 分钟前
安徽某能源企业积极推进运维智能化转型,引入高压配电房机器人巡检系统
运维·人工智能·物联网·机器人·能源
健康有益科技28 分钟前
AI驱动健康升级:新零售企业从“卖产品”到“卖健康”的转型路径
大数据·人工智能·健康医疗·零售
文心快码 Baidu Comate37 分钟前
AI界的“超能力”MCP,到底是个啥?
人工智能·程序员·ai编程·文心快码·comate zulu
石氏是时试38 分钟前
拉格朗日多项式
人工智能·算法·机器学习
大模型真好玩1 小时前
大模型工程面试经典(五)—大模型专业领域微调数据集如何构建?
人工智能·python·面试
码界奇点1 小时前
豆包新模型矩阵与PromptPilot构建企业级AI开发的体系化解决方案
人工智能·线性代数·ai·语言模型·矩阵·硬件工程
YangYang9YangYan1 小时前
2025年跨领域管理能力提升认证路径分析
大数据·人工智能