理解torch.argmax() ,我是错误的

torch.max()

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大值
max_indices = torch.max(b, dim=0)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([11, 12, 13, 17]),

indices=tensor([2, 2, 2, 2]))

分析:张量b是3*4 二维张量,dim=0 得到3个张量分别是【1,2,5,7】,【2,4,6,8】,【11,12,13,17】,最大是谁呢?因此tourch.argmax() 得到indices=2对应第三个。注意啊,是4个2!

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大
max_indices = torch.max(b, dim=1)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([ 7, 8, 17]),

indices=tensor([3, 3, 3]))

分析:张量b是3*4 二维张量,dim=1 得到4个张量分别是【1,2,11】,【3,4,12】,【15,6,13】,【7,8,17】最大是谁呢?因此indices=3对应第4个,因此tourch.argmax() 得到indices=3对应第4个。注意啊,是3个3!

结论:torch.argmax() ,我开始的理解是错误的,通过torch.max() 分析,重新理解argmax() 返回所有元素中的最大值索引!问题来了,索引可能多个,例如indices=tensor([3, 3, 3])),我的疑惑就是索引都是3,能否得到索引indices=tensor([2, 3, 3]))这样的例子呢?如果构造张量b 确保3个索引值是不同的呢?

相关推荐
Java&Develop3 小时前
Aes加密 GCM java
java·开发语言·python
爱笑的眼睛115 小时前
超越MSE与交叉熵:深度解析损失函数的动态本质与高阶设计
java·人工智能·python·ai
Coding茶水间5 小时前
基于深度学习的非机动车头盔检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
Rose sait5 小时前
【环境配置】Linux配置虚拟环境pytorch
linux·人工智能·python
过期动态6 小时前
JDBC高级篇:优化、封装与事务全流程指南
android·java·开发语言·数据库·python·mysql
baby_hua6 小时前
20251024_PyTorch深度学习快速入门教程
人工智能·pytorch·深度学习
一世琉璃白_Y6 小时前
pg配置国内数据源安装
linux·python·postgresql·centos
liwulin05066 小时前
【PYTHON】COCO数据集中的物品ID
开发语言·python
小鸡吃米…6 小时前
Python - XML 处理
xml·开发语言·python·开源
我赵帅的飞起7 小时前
python国密SM4加解密
python·sm4加解密·国密sm4加解密