Argmax函数介绍
在Python中,
argmax
函数通常用于找出给定数组或列表中元素值最大的索引。
(一) Numpy 中的 Argmax 函数:
numpy.argmax
函数用于找出给定轴(axis)上最大值所在的索引。
示例:
python
import numpy as np
# 一维数组
arr = np.array([1, 3, 2, 5, 4])
index = np.argmax(arr)
print(index)
# 二维数组
arr_2d = np.array([[1, 3, 2], [5, 4, 6]])
index_axis_0 = np.argmax(arr_2d, axis=0) # 沿着 列(axis=0) 找出最大值的索引
index_axis_1 = np.argmax(arr_2d, axis=1) # 沿着 行(axis=1) 找出最大值的索引
print(index_axis_0) # 输出: [1 1 1]
print(index_axis_1) # 输出: [1 2]
-------------------------------------------------------------------------------------------------------
# 运行结果:
3
[1 1 1]
[1 2]
(二) pytorch 中的 Argmax 函数:
-
功能:
torch.argmax
函数在 PyTorch 中可以用于找到给定张量(Tensor)中每一行或每一列的最大值的索引。 -
用法:对于一个给定的张量,
argmax
会返回沿着指定维度的最大值的索引。如果不指定维度,默认是在最后一个维度上操作。列就是第一个维度,行就是第二个维度,第三个维度可以理解成厚度,以此类推...
-
代码:
pythontorch.argmax(input, dim=None, keepdim=False)
-
input
:输入的张量。 -
dim
:指定在哪个维度上寻找最大值,0
为一维,1
为二维...如果为None
,则默认在最后一个维度上操作。 -
keepdim
:如果为True
,则输出的张量将保持与输入张量相同的维度数,否则,输出张量将减少一个维度,默认是False
。 -
输出结果是
int
型,64位的电脑,输出默认就是int64
型。 -
简洁写法:
python# 假设有一个多维张量 some_tensor output = some_tensor.argmax(dim=1) # 逐行找出最大值
注意事项:
- 数据类型 :
argmax
返回的是int
类型(64位电脑就是int64
)的张量,因为索引值必须是整数。 - 多维张量 :
dim
参数指定了在哪个维度上寻找最大值。如果dim
为负值,则从最后一个维度开始计数(例如,dim=-1
等于默认行为)。 - 相同最大值 :如果有多个相同的最大值,
argmax
将返回 ++第一个++ 出现的最大值的索引。 - NaN 值 :如果张量中包含
NaN
值,argmax
会忽略这些值,并在剩余的有效值中寻找最大值。
-
-
示例:
pythonimport torch x = torch.tensor([[1, 2, 3], [4, 5, 6]]) result1 = torch.argmax(x, dim=0) # 在第一个维度(列)上寻找最大值的索引 print(result1) result2 = torch.argmax(x, dim=1) # 在最后一个维度(行)上寻找最大值的索引 print(result2) result3 = torch.argmax(x, dim=1, keepdim=True) # 指定keepdim=True print(result3) ------------------------------------------------------------------------------------------------------- # 运行结果: tensor([1, 1, 1]) tensor([2, 2]) tensor([[2], [2]])
-
高级应用:
可以在神经网络训练中,计算准确率。
假设有一个
n
分类模型,输出结果为n
个概率,分别对应输入图片被识别成这n
个类别的概率(∈[0,1]);令
n=3
(这 3 种类别分别用 0、1、2 表示),现在有两张输入图片,输出为:pythonoutputs=[[0.1, 0.2, 0.8], [0.4, 0.6, 0.2]]
假设两张图片分别为类别 0 和类别 2 (即
target=[2,2]
)。对outputs
使用torch.argmax
函数:pythonresult = torch.argmax(outputs, dim=1) # 在每行上寻找最大值的索引,因为 keepdim 默认为 False ,故输出是一维张量 # 输出结果为: tensor([2,1])
因为
target=[2,2]
,而训练结果为result=[2,1]
,所以第一个图片识别正确,第二个图片识别错误。正确率的计算 :
python# 前提是 target 和 result 都是张量 print(target == result) # 直接比较 A = (target == result).sum() #True默认为1,False默认位0 print(A) accuracy = (A.item() / len(result)) * 100 accuracy = roun d(accuracy, 2) # 2表示保留两位小数(四舍五入) print(f"正确率为:{accuracy}%") --------------------------------------------------------------------------------------- # 运行结果 tensor([True,False]) tensor(1) 正确率为:50%
上一篇 | 下一篇 |
---|---|
神经网络入门实战(十七) | 待发布 |