pytorch中torch.einsum函数的详细计算过程图解

第一次见到 rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)这行代码时,属实是懵了,网上找了很多博主的介绍,但都没有详细的说明函数内部的计算过程,看得我是一头雾水,只知道计算结果的维度是如何变化的,却不明白函数内部是如何计算的。话不多说,直接上示例代码

示例代码

python 复制代码
import torch
r_q = torch.tensor([[[[1, 2, 3, 4, 5],
                      [6, 7, 8, 9, 10],
                      [11, 12, 13, 14, 15],
                      [16, 17, 18, 19, 20]],
                     [[21, 22, 23, 24, 25],
                      [26, 27, 28, 29, 30],
                      [31, 32, 33, 34, 35],
                      [36, 37, 38, 39, 40]],
                     [[41, 42, 43, 44, 45],
                      [46, 47, 48, 49, 50],
                      [51, 52, 53, 54, 55],
                      [56, 57, 58, 59, 60]]]])

Rh = torch.tensor([[[1, 2, 3, 4, 5,],
                      [7, 8, 9, 10, 11, ],
                      [13, 14, 15, 16, 17, ],
                      [19, 20, 21, 22, 23, ],
                        [1, 2, 3, 4, 5,],
                    [1, 2, 3, 4, 5,],],
                     [[25, 26, 27, 28, 29, ],
                      [31, 32, 33, 34, 35, ],
                      [37, 38, 39, 40, 41, ],
                      [43, 44, 45, 46, 47, ],
                      [1, 2, 3, 4, 5,],
                      [1, 2, 3, 4, 5,],],
                     [[49, 50, 51, 52, 53, ],
                      [55, 56, 57, 58, 59, ],
                      [61, 62, 63, 64, 65, ],
                      [67, 68, 69, 70, 71, ],
                      [1, 2, 3, 4, 5,],
                      [1, 2, 3, 4, 5,],]])

rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
print(rel_h)

输出结果:

结果解释

文字很难解释清楚,直接上图。r_q的维度为(1, 3, 4, 5), Rh的维度为(3, 6, 5),函数torch.einsum("bhwc,hkc->bhwk", r_q, Rh)中b=1, h=3, w=4, c=5。所以最终结果Rel_h的维度为bhwk,即(1, 3, 4, 5)。具体计算过程如下图。

这回看懂了吧。还不理解的或者讲的不对的地方,欢迎在评论区留言。创作不易,喜欢的话点个关注吧

相关推荐
Liue612312319 分钟前
基于YOLO11-C3k2-Faster-CGLU的路面落叶检测与识别系统实现
python
blackicexs28 分钟前
第四周第七天
算法
硅谷秋水36 分钟前
RoboBrain 2.5:视野中的深度,思维中的时间
深度学习·机器学习·计算机视觉·语言模型·机器人
zhangfeng113341 分钟前
Warmup Scheduler深度学习训练中,在训练初期使用较低学习率进行预热(Warmup),然后再按照预定策略(如余弦退火、阶梯下降等)衰减学习率的方法
人工智能·深度学习·学习
Faker66363aaa42 分钟前
城市地标建筑与车辆检测 - 基于YOLOv10n的高效目标检测模型训练与应用
人工智能·yolo·目标检测
~央千澈~1 小时前
抖音弹幕游戏开发之第8集:pyautogui基础 - 模拟键盘操作·优雅草云桧·卓伊凡
网络·python·websocket·网络协议
沃达德软件1 小时前
电信诈骗预警平台功能解析
大数据·数据仓库·人工智能·深度学习·机器学习·数据库开发
期末考复习中,蓝桥杯都没时间学了1 小时前
力扣刷题19
算法·leetcode·职场和发展
占疏1 小时前
列表分成指定的份数
python
Hy行者勇哥1 小时前
Seedance 全面解析:定义、使用指南、同类软件与完整攻略
人工智能·学习方法·视频