点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)
- 运算符 @ 用于进行两个矩阵的点乘运算
- torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
- torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
- torch.matmul 对进行点乘运算的两矩阵形状没有限定。
a. 对于输入都是二维的张量相当于 mm 运算
b. 对于输入都是三维的张量相当于 bmm 运算
c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则
代码
python
import torch
import numpy as np
# 使用@运算符
def test01():
# 形状为:3行2列
data1 = torch.tensor([[1,2], [3,4], [5,6]])
# 形状为:2行2列
data2 = torch.tensor([[5,6], [7,8]])
data = data1 @ data2
print(data)
# 使用 mm 函数
def test02():
# 要求输入的张量形状都是二维的
# 形状为:3行2列
data1 = torch.tensor([[1,2], [3,4], [5,6]])
# 形状为:2行2列
data2 = torch.tensor([[5,6], [7,8]])
data = torch.mm(data1, data2)
print(data)
print(data.shape)
# 使用 bmm 函数
def test03():
# 第一个维度:表示批次
# 第二个维度:多少行
# 第三个维度:多少列
data1 = torch.randn(3, 4, 5)
data2 = torch.randn(3, 5, 8)
data = torch.bmm(data1, data2)
print(data.shape)
# 使用 matmul 函数
def test04():
# 对二维进行计算
data1 = torch.randn(4,5)
data2 = torch.randn(5,8)
print(torch.matmul(data1, data2).shape)
# 对三维进行计算
data1 = torch.randn(3, 4, 5)
data2 = torch.randn(3, 5, 8)
print(torch.matmul(data1, data2).shape)
data1 = torch.randn(3, 4, 5)
data2 = torch.randn(5, 8)
print(torch.matmul(data1, data2).shape)
if __name__ == "__main__":
test04()