理解torch.argmax() ,我是错误的

torch.max()

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大值
max_indices = torch.max(b, dim=0)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([11, 12, 13, 17]),

indices=tensor([2, 2, 2, 2]))

分析:张量b是3*4 二维张量,dim=0 得到3个张量分别是【1,2,5,7】,【2,4,6,8】,【11,12,13,17】,最大是谁呢?因此tourch.argmax() 得到indices=2对应第三个。注意啊,是4个2!

python 复制代码
import torch

# 定义张量 b
b = torch.tensor([[1, 3, 5, 7],
                  [2, 4, 6, 8],
                  [11, 12, 13, 17]])

# 使用 torch.max() 找到最大
max_indices = torch.max(b, dim=1)

print(max_indices)

输出:>>> print(max_indices)

torch.return_types.max(

values=tensor([ 7, 8, 17]),

indices=tensor([3, 3, 3]))

分析:张量b是3*4 二维张量,dim=1 得到4个张量分别是【1,2,11】,【3,4,12】,【15,6,13】,【7,8,17】最大是谁呢?因此indices=3对应第4个,因此tourch.argmax() 得到indices=3对应第4个。注意啊,是3个3!

结论:torch.argmax() ,我开始的理解是错误的,通过torch.max() 分析,重新理解argmax() 返回所有元素中的最大值索引!问题来了,索引可能多个,例如indices=tensor([3, 3, 3])),我的疑惑就是索引都是3,能否得到索引indices=tensor([2, 3, 3]))这样的例子呢?如果构造张量b 确保3个索引值是不同的呢?

相关推荐
一人の梅雨6 分钟前
1688 店铺商品全量采集与智能分析:从接口调用到供应链数据挖掘
开发语言·python·php
apocalypsx30 分钟前
深度学习-Kaggle实战1(房价预测)
人工智能·深度学习
春末的南方城市34 分钟前
开放指令编辑创新突破!小米开源 Lego-Edit 登顶 SOTA:用强化学习为 MLLM 编辑开辟全新赛道!
人工智能·深度学习·机器学习·计算机视觉·aigc
Terio_my37 分钟前
Python制作12306查票工具:从零构建铁路购票信息查询系统
开发语言·python·microsoft
万粉变现经纪人1 小时前
如何解决 pip install -r requirements.txt 约束文件 constraints.txt 仅允许固定版本(未锁定报错)问题
开发语言·python·r语言·django·beautifulsoup·pandas·pip
站大爷IP1 小时前
Python定时任务实战:APScheduler从入门到精通
python
Fairy_sevenseven1 小时前
[1]python爬虫入门,爬取豆瓣电影top250实践
开发语言·爬虫·python
ThisIsMirror1 小时前
CompletableFuture并行任务超时处理模板
java·windows·python
java1234_小锋2 小时前
TensorFlow2 Python深度学习 - TensorFlow2框架入门 - 计算图和 tf.function 简介
python·深度学习·tensorflow·tensorflow2
程序员晚枫2 小时前
Python 3.14新特性:Zstandard压缩库正式加入标准库,性能提升30%
python