文章目录
- [1. torch_bmm](#1. torch_bmm)
- [2. pytorch源码](#2. pytorch源码)
1. torch_bmm
torch.bmm的作用是基于batch_size的矩阵乘法,torch.bmm的作用是对应batch位置的矩阵相乘,比如,
- mat1的第
1
个位置和mat2的第1
个位置进行矩阵相乘得到mat3的第1
个位置 - mat1的第
2
个位置和mat2的第2
个位置进行矩阵相乘得到mat3的第2
个位置
2. pytorch源码
python
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(precision=3, sci_mode=False)
if __name__ == "__main__":
run_code = 0
batch_size = 2
mat1_h = 3
mat1_w = 4
mat1_total = batch_size * mat1_w * mat1_h
mat2_h = 4
mat2_w = 5
mat2_total = batch_size * mat2_w * mat2_h
mat1 = torch.arange(mat1_total).reshape((batch_size, mat1_h, mat1_w))
mat2 = torch.arange(mat2_total).reshape((batch_size, mat2_h, mat2_w))
mat3 = torch.bmm(mat1, mat2)
print(f"mat1=\n{mat1}")
print(f"mat2=\n{mat2}")
print(f"mat3=\n{mat3}")
- 结果:
python
mat1=
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
mat2=
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]]])
mat3=
tensor([[[ 70, 76, 82, 88, 94],
[ 190, 212, 234, 256, 278],
[ 310, 348, 386, 424, 462]],
[[1510, 1564, 1618, 1672, 1726],
[1950, 2020, 2090, 2160, 2230],
[2390, 2476, 2562, 2648, 2734]]])