理解torch.argmax() ,我是错误的

torch.max()

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大值
max_indices = torch.max(b, dim=0)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([11, 12, 13, 17]),

indices=tensor([2, 2, 2, 2]))

分析:张量b是3*4 二维张量,dim=0 得到3个张量分别是【1,2,5,7】,【2,4,6,8】,【11,12,13,17】,最大是谁呢?因此tourch.argmax() 得到indices=2对应第三个。注意啊,是4个2!

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大
max_indices = torch.max(b, dim=1)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([ 7, 8, 17]),

indices=tensor([3, 3, 3]))

分析:张量b是3*4 二维张量,dim=1 得到4个张量分别是【1,2,11】,【3,4,12】,【15,6,13】,【7,8,17】最大是谁呢?因此indices=3对应第4个,因此tourch.argmax() 得到indices=3对应第4个。注意啊,是3个3!

结论:torch.argmax() ,我开始的理解是错误的,通过torch.max() 分析,重新理解argmax() 返回所有元素中的最大值索引!问题来了,索引可能多个,例如indices=tensor([3, 3, 3])),我的疑惑就是索引都是3,能否得到索引indices=tensor([2, 3, 3]))这样的例子呢?如果构造张量b 确保3个索引值是不同的呢?

相关推荐
生信大表哥6 小时前
单细胞测序分析(五)降维聚类&数据整合
linux·python·聚类·数信院生信服务器
seeyoutlb7 小时前
微服务全局日志处理
java·python·微服务
ada7_7 小时前
LeetCode(python)——148.排序链表
python·算法·leetcode·链表
岁月宁静8 小时前
LangChain + LangGraph 实战:构建生产级多模态 WorkflowAgent 的完整指南
人工智能·python·agent
梯度下降不了班8 小时前
【mmodel/xDit】Cross-Attention 深度解析:文生图/文生视频的核心桥梁
人工智能·深度学习·ai作画·stable diffusion·音视频·transformer
第二只羽毛9 小时前
主题爬虫采集主题新闻信息
大数据·爬虫·python·网络爬虫
plmm烟酒僧9 小时前
TensorRT 推理 YOLO Demo 分享 (Python)
开发语言·python·yolo·tensorrt·runtime·推理
天才测试猿9 小时前
Postman中变量的使用详解
自动化测试·软件测试·python·测试工具·职场和发展·接口测试·postman
帕巴啦9 小时前
Arcgis计算面要素的面积、周长、宽度、长度及最大直径
python·arcgis
AI小云9 小时前
【数据操作与可视化】Matplotlib绘图-生成其他图表类型
开发语言·python·matplotlib