理解torch.argmax() ,我是错误的

torch.max()

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大值
max_indices = torch.max(b, dim=0)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([11, 12, 13, 17]),

indices=tensor([2, 2, 2, 2]))

分析:张量b是3*4 二维张量,dim=0 得到3个张量分别是【1,2,5,7】,【2,4,6,8】,【11,12,13,17】,最大是谁呢?因此tourch.argmax() 得到indices=2对应第三个。注意啊,是4个2!

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大
max_indices = torch.max(b, dim=1)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([ 7, 8, 17]),

indices=tensor([3, 3, 3]))

分析:张量b是3*4 二维张量,dim=1 得到4个张量分别是【1,2,11】,【3,4,12】,【15,6,13】,【7,8,17】最大是谁呢?因此indices=3对应第4个,因此tourch.argmax() 得到indices=3对应第4个。注意啊,是3个3!

结论:torch.argmax() ,我开始的理解是错误的,通过torch.max() 分析,重新理解argmax() 返回所有元素中的最大值索引!问题来了,索引可能多个,例如indices=tensor([3, 3, 3])),我的疑惑就是索引都是3,能否得到索引indices=tensor([2, 3, 3]))这样的例子呢?如果构造张量b 确保3个索引值是不同的呢?

相关推荐
西柚小萌新5 分钟前
【深入浅出PyTorch】--7.2.PyTorch可视化2
人工智能·pytorch·python
java1234_小锋18 分钟前
TensorFlow2 Python深度学习 - 使用TensorBoard可视化数据
python·深度学习·tensorflow·tensorflow2
源来是大数据的菜鸟20 分钟前
基于Multi-Agent开发的SmartCare系统自动化运维管家
python·运维开发
该用户已不存在32 分钟前
我的Python工具箱,不用加班的秘密
前端·后端·python
星期天要睡觉43 分钟前
计算机视觉(opencv)——实时颜色检测
人工智能·python·opencv·计算机视觉
aerror1 小时前
json转excel xlsx文件
开发语言·python·json
查士丁尼·绵2 小时前
笔试-士兵过河
python
weixin_46682 小时前
编程之python基础
开发语言·python
打酱油的;2 小时前
【无标题】
爬虫·python·php
幸福清风3 小时前
【Python】基于Tkinter库实现文件夹拖拽与选择功能
windows·python·microsoft·tkinter