神经网络入门实战:(十八)Argmax函数的详细介绍,可以用来计算模型训练准确率

Argmax函数介绍

在Python中,argmax 函数通常用于找出给定数组或列表中元素值最大的索引。

(一) Numpy 中的 Argmax 函数:

numpy.argmax 函数用于找出给定轴(axis)上最大值所在的索引。

示例:

python 复制代码
import numpy as np
# 一维数组
arr = np.array([1, 3, 2, 5, 4])
index = np.argmax(arr)
print(index)  

# 二维数组
arr_2d = np.array([[1, 3, 2], [5, 4, 6]])
index_axis_0 = np.argmax(arr_2d, axis=0) # 沿着 列(axis=0) 找出最大值的索引
index_axis_1 = np.argmax(arr_2d, axis=1) # 沿着 行(axis=1) 找出最大值的索引
print(index_axis_0)  # 输出: [1 1 1]
print(index_axis_1)  # 输出: [1 2]
-------------------------------------------------------------------------------------------------------
# 运行结果:
3
[1 1 1]
[1 2]

(二) pytorch 中的 Argmax 函数:

  • 功能:torch.argmax 函数在 PyTorch 中可以用于找到给定张量(Tensor)中每一行或每一列的最大值的索引。

  • 用法:对于一个给定的张量,argmax 会返回沿着指定维度的最大值的索引。如果不指定维度,默认是在最后一个维度上操作

    列就是第一个维度,行就是第二个维度,第三个维度可以理解成厚度,以此类推...

  • 代码:

    python 复制代码
    torch.argmax(input, dim=None, keepdim=False)
    • input:输入的张量。

    • dim:指定在哪个维度上寻找最大值,0为一维,1为二维...如果为 None,则默认在最后一个维度上操作。

    • keepdim:如果为 True,则输出的张量将保持与输入张量相同的维度数,否则,输出张量将减少一个维度,默认是 False

    • 输出结果是 int 型,64位的电脑,输出默认就是 int64 型。

    • 简洁写法:

      python 复制代码
      # 假设有一个多维张量 some_tensor
      output = some_tensor.argmax(dim=1) # 逐行找出最大值

    注意事项:

    1. 数据类型argmax 返回的是 int 类型(64位电脑就是 int64 )的张量,因为索引值必须是整数。
    2. 多维张量dim 参数指定了在哪个维度上寻找最大值。如果 dim 为负值,则从最后一个维度开始计数(例如,dim=-1 等于默认行为)。
    3. 相同最大值 :如果有多个相同的最大值,argmax 将返回 ++第一个++ 出现的最大值的索引。
    4. NaN 值 :如果张量中包含 NaN 值,argmax 会忽略这些值,并在剩余的有效值中寻找最大值。
  • 示例:

    python 复制代码
    import torch
    
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    result1 = torch.argmax(x, dim=0) # 在第一个维度(列)上寻找最大值的索引
    print(result1)
    result2 = torch.argmax(x, dim=1) # 在最后一个维度(行)上寻找最大值的索引
    print(result2)
    result3 = torch.argmax(x, dim=1, keepdim=True) # 指定keepdim=True
    print(result3) 
    -------------------------------------------------------------------------------------------------------
    # 运行结果:
    tensor([1, 1, 1])
    tensor([2, 2])
    tensor([[2],
            [2]])
  • 高级应用:

    可以在神经网络训练中,计算准确率。

    假设有一个 n 分类模型,输出结果为 n 个概率,分别对应输入图片被识别成这 n 个类别的概率(∈[0,1]);

    n=3 (这 3 种类别分别用 0、1、2 表示),现在有两张输入图片,输出为:

    python 复制代码
    outputs=[[0.1, 0.2, 0.8],
            [0.4, 0.6, 0.2]]

    假设两张图片分别为类别 0 和类别 2 (即 target=[2,2])。对 outputs 使用 torch.argmax 函数:

    python 复制代码
    result = torch.argmax(outputs, dim=1) # 在每行上寻找最大值的索引,因为 keepdim 默认为 False ,故输出是一维张量 
    # 输出结果为:
    tensor([2,1])

    因为 target=[2,2] ,而训练结果为 result=[2,1] ,所以第一个图片识别正确,第二个图片识别错误。

    正确率的计算 :

    python 复制代码
    # 前提是 target 和 result 都是张量
    print(target == result) # 直接比较
    A = (target == result).sum() #True默认为1,False默认位0
    print(A)
    accuracy = (A.item() / len(result)) * 100
    accuracy = roun
    d(accuracy, 2) # 2表示保留两位小数(四舍五入)
    print(f"正确率为:{accuracy}%")
    ---------------------------------------------------------------------------------------
    # 运行结果
    tensor([True,False])
    tensor(1)
    正确率为:50%

上一篇 下一篇
神经网络入门实战(十七) 待发布
相关推荐
AngelPP3 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年3 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼3 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS3 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区4 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈5 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang5 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx
shengjk16 小时前
NanoClaw 深度剖析:一个"AI 原生"架构的个人助手是如何运转的?
人工智能
西门老铁8 小时前
🦞OpenClaw 让 MacMini 脱销了,而我拿出了6年陈的安卓机
人工智能