torch.bmm功能解读

bmmbatched matrix multiple 的简写,即批量矩阵乘法,矩阵是二维的,加上batch一个维度,因此该函数的输入必须是两个三维的 tensor,三个维度代表的含义分别是:(批量,行,列)。

对于 torch.bmm(tensor_a, tensor_b) 而言,

tensor_ashape为 (a, b, c)

tensor_bshape为 (d, e, f)

要求 a = d, c = e,即批量数相同,在计算时tensor_a 的第 i 个矩阵与 tensor_b 的第 i 个矩阵作乘法,i = 1, 2, 3, ..., a。因此为了矩阵乘法能够进行,c 和 e 必须相同。计算过程如图1所示。
图1. bmm计算过程

测试代码如下:

python 复制代码
import torch

BatchMatrix1 = torch.randn((3,4,3))
BatchMatrix2 = torch.randn((3,3,4))

BatchMatrixMultiple = torch.bmm(BatchMatrix1, BatchMatrix2)

print(BatchMatrixMultiple.shape)

输出为,与图1中绿色矩阵对应。

相关推荐
海边夕阳20061 分钟前
【每天一个AI小知识】:什么是逻辑回归?
人工智能·算法·逻辑回归
普通网友6 分钟前
C++编译期数据结构
开发语言·c++·算法
Gorgous—l7 分钟前
数据结构算法学习:LeetCode热题100-图论篇(岛屿数量、腐烂的橘子、课程表、实现 Trie (前缀树))
数据结构·学习·算法
im_AMBER15 分钟前
算法笔记 13 BFS | 图
笔记·学习·算法·广度优先
非著名架构师24 分钟前
团雾、结冰、大风——高速公路的“隐形杀手”:智慧气象预警如何为您的路网安全保驾护航
人工智能·新能源风光提高精度·疾风气象大模型4.0·疾风气象大模型·风光功率预测
IT_陈寒32 分钟前
Redis深度优化:10个让你的QPS提升50%的关键配置解析
前端·人工智能·后端
2501_9411429334 分钟前
5G与边缘计算结合在智能物流系统中的高效调度与实时监控应用研究
人工智能
普通网友35 分钟前
嵌入式C++安全编码
开发语言·c++·算法
2501_9411444239 分钟前
边缘计算与人工智能在智能制造生产线优化与故障预测中的应用研究
人工智能·边缘计算·制造
三寸3371 小时前
硬刚GPT 5.1,Grok 4.1来了,所有用户免费使用!
人工智能·ai·ai编程