理解torch.argmax() ,我是错误的

torch.max()

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大值
max_indices = torch.max(b, dim=0)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor(11, 12, 13, 17),

indices=tensor(2, 2, 2, 2))

分析:张量b是3*4 二维张量,dim=0 得到3个张量分别是【1,2,5,7】,【2,4,6,8】,【11,12,13,17】,最大是谁呢?因此tourch.argmax() 得到indices=2对应第三个。注意啊,是4个2!

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大
max_indices = torch.max(b, dim=1)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor( 7, 8, 17),

indices=tensor(3, 3, 3))

分析:张量b是3*4 二维张量,dim=1 得到4个张量分别是【1,2,11】,【3,4,12】,【15,6,13】,【7,8,17】最大是谁呢?因此indices=3对应第4个,因此tourch.argmax() 得到indices=3对应第4个。注意啊,是3个3!

结论:torch.argmax() ,我开始的理解是错误的,通过torch.max() 分析,重新理解argmax() 返回所有元素中的最大值索引!问题来了,索引可能多个,例如indices=tensor(3, 3, 3)),我的疑惑就是索引都是3,能否得到索引indices=tensor(2, 3, 3))这样的例子呢?如果构造张量b 确保3个索引值是不同的呢?

相关推荐
AI人工智能+3 分钟前
融合图像处理与模式识别算法的智能银行卡识别系统,为金融行业带来了革命性的效率提升
人工智能·深度学习·ocr·银行卡识别
小桥流水---人工智能18 分钟前
【已解决】ImportError: cannot import name ‘AdamW‘ from ‘transformers.optimization‘
python
芝麻开门GEO26 分钟前
泰安GEO优化服务,真的能提升效果吗?
人工智能·python
颜酱35 分钟前
选读:工业级调用 LangChain:从 Demo 到企业级应用
python
颜酱1 小时前
LangChain 调用大模型实战:从跑通到服务商与模型选型
python·langchain
唐装鼠2 小时前
Nginx + Gunicorn + Python Web 应用 架构(Claude)
python·nginx·gunicorn
梦想三三2 小时前
【PYthon词频统计与文本向量化】苏宁易购评论分析实战
开发语言·python
zhangfeng11332 小时前
Mamba transformer的颠覆者 论文技术解读与应用实践深度报告,
人工智能·深度学习·transformer
biter down3 小时前
9:JSONSchema
python
日晨难再3 小时前
C语言&Python&Bash&Tcl:全局变量和局部变量
c语言·python·bash·tcl