torch.bmm功能解读

bmmbatched matrix multiple 的简写,即批量矩阵乘法,矩阵是二维的,加上batch一个维度,因此该函数的输入必须是两个三维的 tensor,三个维度代表的含义分别是:(批量,行,列)。

对于 torch.bmm(tensor_a, tensor_b) 而言,

tensor_ashape为 (a, b, c)

tensor_bshape为 (d, e, f)

要求 a = d, c = e,即批量数相同,在计算时tensor_a 的第 i 个矩阵与 tensor_b 的第 i 个矩阵作乘法,i = 1, 2, 3, ..., a。因此为了矩阵乘法能够进行,c 和 e 必须相同。计算过程如图1所示。
图1. bmm计算过程

测试代码如下:

python 复制代码
import torch

BatchMatrix1 = torch.randn((3,4,3))
BatchMatrix2 = torch.randn((3,3,4))

BatchMatrixMultiple = torch.bmm(BatchMatrix1, BatchMatrix2)

print(BatchMatrixMultiple.shape)

输出为,与图1中绿色矩阵对应。

相关推荐
C雨后彩虹几秒前
优雅子数组
java·数据结构·算法·华为·面试
wangmengxxw4 分钟前
SpringAI-mysql
java·数据库·人工智能·mysql·springai
漫随流水7 分钟前
leetcode回溯算法(46.全排列)
数据结构·算法·leetcode·回溯算法
考證寶題庫網7 分钟前
Designing and Implementing a Microsoft Azure AI Solution 微軟Azure AI-102 認證全攻略
人工智能·microsoft·azure
We་ct10 分钟前
LeetCode 68. 文本左右对齐:贪心算法的两种实现与深度解析
前端·算法·leetcode·typescript
努力学算法的蒟蒻13 分钟前
day67(1.26)——leetcode面试经典150
算法·leetcode·面试
iAkuya15 分钟前
(leetcode) 力扣100 52腐烂的橘子(BFS)
算法·leetcode·宽度优先
逄逄不是胖胖15 分钟前
《动手学深度学习》-52文本预处理实现
人工智能·pytorch·python·深度学习
Pyeako18 分钟前
opencv计算机视觉--图形透视(投影)变换&图形拼接
人工智能·python·opencv·计算机视觉·图片拼接·投影变换·图形透视变换
老鼠只爱大米19 分钟前
LeetCode经典算法面试题 #148:排序链表(插入、归并、快速等五种实现方案解析)
算法·leetcode·链表·插入排序·归并排序·快速排序·链表排序