神经网络入门实战:(十八)Argmax函数的详细介绍,可以用来计算模型训练准确率

Argmax函数介绍

在Python中,argmax 函数通常用于找出给定数组或列表中元素值最大的索引。

(一) Numpy 中的 Argmax 函数:

numpy.argmax 函数用于找出给定轴(axis)上最大值所在的索引。

示例:

python 复制代码
import numpy as np
# 一维数组
arr = np.array([1, 3, 2, 5, 4])
index = np.argmax(arr)
print(index)  

# 二维数组
arr_2d = np.array([[1, 3, 2], [5, 4, 6]])
index_axis_0 = np.argmax(arr_2d, axis=0) # 沿着 列(axis=0) 找出最大值的索引
index_axis_1 = np.argmax(arr_2d, axis=1) # 沿着 行(axis=1) 找出最大值的索引
print(index_axis_0)  # 输出: [1 1 1]
print(index_axis_1)  # 输出: [1 2]
-------------------------------------------------------------------------------------------------------
# 运行结果:
3
[1 1 1]
[1 2]

(二) pytorch 中的 Argmax 函数:

  • 功能:torch.argmax 函数在 PyTorch 中可以用于找到给定张量(Tensor)中每一行或每一列的最大值的索引。

  • 用法:对于一个给定的张量,argmax 会返回沿着指定维度的最大值的索引。如果不指定维度,默认是在最后一个维度上操作

    列就是第一个维度,行就是第二个维度,第三个维度可以理解成厚度,以此类推...

  • 代码:

    python 复制代码
    torch.argmax(input, dim=None, keepdim=False)
    • input:输入的张量。

    • dim:指定在哪个维度上寻找最大值,0为一维,1为二维...如果为 None,则默认在最后一个维度上操作。

    • keepdim:如果为 True,则输出的张量将保持与输入张量相同的维度数,否则,输出张量将减少一个维度,默认是 False

    • 输出结果是 int 型,64位的电脑,输出默认就是 int64 型。

    • 简洁写法:

      python 复制代码
      # 假设有一个多维张量 some_tensor
      output = some_tensor.argmax(dim=1) # 逐行找出最大值

    注意事项:

    1. 数据类型argmax 返回的是 int 类型(64位电脑就是 int64 )的张量,因为索引值必须是整数。
    2. 多维张量dim 参数指定了在哪个维度上寻找最大值。如果 dim 为负值,则从最后一个维度开始计数(例如,dim=-1 等于默认行为)。
    3. 相同最大值 :如果有多个相同的最大值,argmax 将返回 ++第一个++ 出现的最大值的索引。
    4. NaN 值 :如果张量中包含 NaN 值,argmax 会忽略这些值,并在剩余的有效值中寻找最大值。
  • 示例:

    python 复制代码
    import torch
    
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    result1 = torch.argmax(x, dim=0) # 在第一个维度(列)上寻找最大值的索引
    print(result1)
    result2 = torch.argmax(x, dim=1) # 在最后一个维度(行)上寻找最大值的索引
    print(result2)
    result3 = torch.argmax(x, dim=1, keepdim=True) # 指定keepdim=True
    print(result3) 
    -------------------------------------------------------------------------------------------------------
    # 运行结果:
    tensor([1, 1, 1])
    tensor([2, 2])
    tensor([[2],
            [2]])
  • 高级应用:

    可以在神经网络训练中,计算准确率。

    假设有一个 n 分类模型,输出结果为 n 个概率,分别对应输入图片被识别成这 n 个类别的概率(∈[0,1]);

    n=3 (这 3 种类别分别用 0、1、2 表示),现在有两张输入图片,输出为:

    python 复制代码
    outputs=[[0.1, 0.2, 0.8],
            [0.4, 0.6, 0.2]]

    假设两张图片分别为类别 0 和类别 2 (即 target=[2,2])。对 outputs 使用 torch.argmax 函数:

    python 复制代码
    result = torch.argmax(outputs, dim=1) # 在每行上寻找最大值的索引,因为 keepdim 默认为 False ,故输出是一维张量 
    # 输出结果为:
    tensor([2,1])

    因为 target=[2,2] ,而训练结果为 result=[2,1] ,所以第一个图片识别正确,第二个图片识别错误。

    正确率的计算 :

    python 复制代码
    # 前提是 target 和 result 都是张量
    print(target == result) # 直接比较
    A = (target == result).sum() #True默认为1,False默认位0
    print(A)
    accuracy = (A.item() / len(result)) * 100
    accuracy = roun
    d(accuracy, 2) # 2表示保留两位小数(四舍五入)
    print(f"正确率为:{accuracy}%")
    ---------------------------------------------------------------------------------------
    # 运行结果
    tensor([True,False])
    tensor(1)
    正确率为:50%

上一篇 下一篇
神经网络入门实战(十七) 待发布
相关推荐
hanniuniu1315 分钟前
网络安全厂商F5推出AI Gateway,化解大模型应用风险
人工智能·web安全·gateway
Iamccc13_26 分钟前
智能仓储的未来:自动化、AI与数据分析如何重塑物流中心
人工智能·数据分析·自动化
蹦蹦跳跳真可爱5891 小时前
Python----目标检测(使用YOLO 模型进行线程安全推理和流媒体源)
人工智能·python·yolo·目标检测·目标跟踪
思尔芯S2C1 小时前
思尔芯携手Andes晶心科技,加速先进RISC-V 芯片开发
人工智能·科技·fpga开发·risc-v·debugging·prototyping·soc validation
风铃儿~1 小时前
Spring AI 入门:Java 开发者的生成式 AI 实践之路
java·人工智能·spring
晓枫-迷麟1 小时前
【使用conda】安装pytorch
人工智能·pytorch·conda
爱补鱼的猫猫1 小时前
Pytorch知识点2
人工智能·pytorch·python
deephub2 小时前
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
人工智能·pytorch·python·深度学习·机器学习·正则化
小于不是小鱼呀2 小时前
手撕 K-Means
人工智能·算法·机器学习