torch.bmm功能解读

bmmbatched matrix multiple 的简写,即批量矩阵乘法,矩阵是二维的,加上batch一个维度,因此该函数的输入必须是两个三维的 tensor,三个维度代表的含义分别是:(批量,行,列)。

对于 torch.bmm(tensor_a, tensor_b) 而言,

tensor_ashape为 (a, b, c)

tensor_bshape为 (d, e, f)

要求 a = d, c = e,即批量数相同,在计算时tensor_a 的第 i 个矩阵与 tensor_b 的第 i 个矩阵作乘法,i = 1, 2, 3, ..., a。因此为了矩阵乘法能够进行,c 和 e 必须相同。计算过程如图1所示。
图1. bmm计算过程

测试代码如下:

python 复制代码
import torch

BatchMatrix1 = torch.randn((3,4,3))
BatchMatrix2 = torch.randn((3,3,4))

BatchMatrixMultiple = torch.bmm(BatchMatrix1, BatchMatrix2)

print(BatchMatrixMultiple.shape)

输出为,与图1中绿色矩阵对应。

相关推荐
Mao.O2 小时前
开源项目“AI思维圆桌”的介绍和对于当前AI编程的思考
人工智能
jake don2 小时前
AI 深度学习路线
人工智能·深度学习
夏鹏今天学习了吗3 小时前
【LeetCode热题100(87/100)】最小路径和
算法·leetcode·职场和发展
信创天地3 小时前
信创场景软件兼容性测试实战:适配国产软硬件生态,破解运行故障难题
人工智能·开源·dubbo·运维开发·risc-v
哈哈不让取名字3 小时前
基于C++的爬虫框架
开发语言·c++·算法
幻云20103 小时前
Python深度学习:从筑基到登仙
前端·javascript·vue.js·人工智能·python
bst@微胖子3 小时前
LlamaIndex之核心概念及部署以及入门案例
pytorch·深度学习·机器学习
无风听海3 小时前
CBOW 模型中的输出层
人工智能·机器学习
汇智信科3 小时前
智慧矿山和工业大数据解决方案“智能设备管理系统”
大数据·人工智能·工业大数据·智能矿山·汇智信科·智能设备管理系统
静听松涛1333 小时前
跨语言低资源场景下的零样本迁移
人工智能