PyTorch 里的矩阵乘法

Pyorch 里的矩阵乘法

flyfish

输入 运算
1D × 1D dot product
2D × 2D matrix multiply
1D × 2D vector × matrix
2D × 1D matrix × vector
ND × ND batch matrix multiply

torch.mm / matmul / bmm 区别

函数 作用
torch.mm 只支持 2D矩阵乘法
torch.matmul 通用乘法
torch.bmm batch矩阵乘法

mm = matrix multiply

bmm = batch matrix multiply

matmul = matrix multiply

矩阵乘法其实是:

复制代码
行 × 列

例如

复制代码
A = [a1 a2 a3]
B = [b1
     b2
     b3]

结果:

复制代码
a1*b1 + a2*b2 + a3*b3

就是 dot product

矩阵乘法 = 行向量和列向量的 点乘集合dot product

概念 是否矩阵乘法
dot product 矩阵乘法的基本单元
cross product 不是矩阵乘法
A @ B 等价 torch.matmul
torch.mm 2D矩阵
torch.bmm batch矩阵
torch.matmul 通用
cpp 复制代码
import torch

# ==============================================
# 规则1:1D @ 1D → 点积(标量)
# 展示:逐元素相乘 → 全部相加 → 最终结果
# ==============================================
print("="*60)
print("规则1:1维向量 @ 1维向量 = 点积")
a1 = torch.tensor([1, 2, 3])
b1 = torch.tensor([4, 5, 6])
print(f"向量a1: {a1}")
print(f"向量b1: {b1}")
# 手动计算过程
print("【逐元素相乘 + 求和过程】")
print(f"1 × 4 = 4")
print(f"2 × 5 = 10")
print(f"3 × 6 = 18")
print(f"总和:4 + 10 + 18 = 32")
# 代码运算
res_at = a1 @ b1
res_eq = torch.dot(a1, b1)
print(f"@ 运算结果: {res_at}")
print(f"等价torch.dot结果: {res_eq}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则2:2D @ 2D → 标准矩阵乘法
# 展示:每行 × 每列 → 对应元素相乘 → 相加 → 矩阵元素
# ==============================================
print("="*60)
print("规则2:2维矩阵 @ 2维矩阵 = 矩阵乘法")
a2 = torch.tensor([[1, 2], [3, 4]])
b2 = torch.tensor([[5, 6], [7, 8]])
print(f"矩阵a2:\n{a2}")
print(f"矩阵b2:\n{b2}")
# 手动计算过程(逐个元素计算)
print("【逐元素计算过程:行 × 列 相乘后相加】")
print("第1行第1列:1×5 + 2×7 = 5 + 14 = 19")
print("第1行第2列:1×6 + 2×8 = 6 + 16 = 22")
print("第2行第1列:3×5 + 4×7 = 15 + 28 = 43")
print("第2行第2列:3×6 + 4×8 = 18 + 32 = 50")
# 代码运算
res_at = a2 @ b2
res_eq = torch.mm(a2, b2)
print(f"@ 运算结果:\n{res_at}")
print(f"等价torch.mm结果:\n{res_eq}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则3:1D @ 2D → 自动升维计算
# 展示:1维升维 → 行×列相乘相加 → 降维输出
# ==============================================
print("="*60)
print("规则3:1维向量 @ 2维矩阵 = 自动升维计算")
a3 = torch.tensor([1, 2])
b3 = torch.tensor([[1, 0], [0, 1]])
print(f"向量a3: {a3}")
print(f"矩阵b3:\n{b3}")
# 手动计算过程
print("【计算过程:a3升维为(1,2),再行×列相乘相加】")
print("第1列:1×1 + 2×0 = 1 + 0 = 1")
print("第2列:1×0 + 2×1 = 0 + 2 = 2")
# 代码运算
res_at = a3 @ b3
print(f"@ 运算结果: {res_at}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则4:2D @ 1D → 矩阵-向量乘法
# 展示:矩阵每行 × 向量 → 相乘相加 → 输出向量
# ==============================================
print("="*60)
print("规则4:2维矩阵 @ 1维向量 = 矩阵-向量乘法")
a4 = torch.tensor([[1, 2], [3, 4]])
b4 = torch.tensor([1, 1])
print(f"矩阵a4:\n{a4}")
print(f"向量b4: {b4}")
# 手动计算过程
print("【逐行 × 向量 相乘相加】")
print("第1个结果:1×1 + 2×1 = 1 + 2 = 3")
print("第2个结果:3×1 + 4×1 = 3 + 4 = 7")
# 代码运算
res_at = a4 @ b4
print(f"@ 运算结果: {res_at}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则5:3D @ 3D → 批量矩阵乘法
# 展示:每个批次独立计算矩阵乘法
# ==============================================
print("="*60)
print("规则5:3维批量矩阵 @ 3维批量矩阵 = 批量矩阵乘法")
a5 = torch.tensor([[[1,2],[3,4]], [[5,6],[7,8]]])
b5 = torch.tensor([[[1,0],[0,1]], [[1,0],[0,1]]])
print(f"批量矩阵a5 形状: {a5.shape}")
print(f"批量矩阵b5 形状: {b5.shape}")
# 手动计算过程(分批次计算)
print("【第1个批次计算:】")
print("1×1 + 2×0 = 1 | 1×0 + 2×1 = 2")
print("3×1 + 4×0 = 3 | 3×0 + 4×1 = 4")
print("【第2个批次计算:】")
print("5×1 + 6×0 = 5 | 5×0 + 6×1 = 6")
print("7×1 + 8×0 = 7 | 7×0 + 8×1 = 8")
# 代码运算
res_at = a5 @ b5
res_eq = torch.bmm(a5, b5)
print(f"@ 运算结果:\n{res_at}")
print(f"等价torch.bmm结果:\n{res_eq}")
print(f"结果形状: {res_at.shape}")

输出

复制代码
============================================================
规则1:1维向量 @ 1维向量 = 点积
向量a1: tensor([1, 2, 3])
向量b1: tensor([4, 5, 6])
【逐元素相乘 + 求和过程】
1 × 4 = 4
2 × 5 = 10
3 × 6 = 18
总和:4 + 10 + 18 = 32
@ 运算结果: 32
等价torch.dot结果: 32
结果形状: torch.Size([])

============================================================
规则2:2维矩阵 @ 2维矩阵 = 矩阵乘法
矩阵a2:
tensor([[1, 2],
        [3, 4]])
矩阵b2:
tensor([[5, 6],
        [7, 8]])
【逐元素计算过程:行 × 列 相乘后相加】
第1行第1列:1×5 + 2×7 = 5 + 14 = 19
第1行第2列:1×6 + 2×8 = 6 + 16 = 22
第2行第1列:3×5 + 4×7 = 15 + 28 = 43
第2行第2列:3×6 + 4×8 = 18 + 32 = 50
@ 运算结果:
tensor([[19, 22],
        [43, 50]])
等价torch.mm结果:
tensor([[19, 22],
        [43, 50]])
结果形状: torch.Size([2, 2])

============================================================
规则3:1维向量 @ 2维矩阵 = 自动升维计算
向量a3: tensor([1, 2])
矩阵b3:
tensor([[1, 0],
        [0, 1]])
【计算过程:a3升维为(1,2),再行×列相乘相加】
第1列:1×1 + 2×0 = 1 + 0 = 1
第2列:1×0 + 2×1 = 0 + 2 = 2
@ 运算结果: tensor([1, 2])
结果形状: torch.Size([2])

============================================================
规则4:2维矩阵 @ 1维向量 = 矩阵-向量乘法
矩阵a4:
tensor([[1, 2],
        [3, 4]])
向量b4: tensor([1, 1])
【逐行 × 向量 相乘相加】
第1个结果:1×1 + 2×1 = 1 + 2 = 3
第2个结果:3×1 + 4×1 = 3 + 4 = 7
@ 运算结果: tensor([3, 7])
结果形状: torch.Size([2])

============================================================
规则5:3维批量矩阵 @ 3维批量矩阵 = 批量矩阵乘法
批量矩阵a5 形状: torch.Size([2, 2, 2])
批量矩阵b5 形状: torch.Size([2, 2, 2])
【第1个批次计算:】
1×1 + 2×0 = 1 | 1×0 + 2×1 = 2
3×1 + 4×0 = 3 | 3×0 + 4×1 = 4
【第2个批次计算:】
5×1 + 6×0 = 5 | 5×0 + 6×1 = 6
7×1 + 8×0 = 7 | 7×0 + 8×1 = 8
@ 运算结果:
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
等价torch.bmm结果:
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
结果形状: torch.Size([2, 2, 2])

以上代码测试环境 PyTorch 版本: 2.8.0+cu128

单词

英文短语 词性 含义 关键词
matrix multiply 动词短语(verb phrase) 执行矩阵相乘的动作 「做乘法」
matrix multiplication 名词短语(noun phrase) 矩阵乘法这一数学运算的概念 「乘法运算本身」
matrix product 名词短语(noun phrase) 矩阵乘法的计算结果 「矩阵积」

示例

In PyTorch, semi-structured sparsity is implemented via a Tensor subclass. By subclassing, we can override __torch_dispatch__ , allowing us to use faster sparse kernels when performing matrix multiplication. We can also store the tensor in it's compressed form inside the subclass to reduce memory overhead.

Matrix product of two tensors.

If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply . After the matrix multiply, the prepended dimension is removed.

相关推荐
_深海凉_6 小时前
LeetCode热题100-搜索二维矩阵
算法·leetcode·矩阵
盼小辉丶10 小时前
PyTorch强化学习实战(6)——交叉熵方法详解与实现
人工智能·pytorch·python·强化学习
ZhengEnCi10 小时前
06-多头注意力机制 🎯
人工智能·pytorch·python
赵优秀一一14 小时前
AI入门学习
人工智能·pytorch·深度学习
盼小辉丶15 小时前
PyTorch强化学习实战(5)——PyTorch Ignite 事件驱动机制与实践
人工智能·pytorch·python·强化学习
ZhengEnCi1 天前
05-自注意力机制详解 🧠
人工智能·pytorch·深度学习
im_AMBER1 天前
手撕hot100之矩阵!看完这篇就AC~
javascript·数据结构·线性代数·算法·leetcode·矩阵
Wadli1 天前
hot100|矩阵
线性代数·矩阵
xier_ran1 天前
【BUG问题】5060Ti显卡Windows配置Anaconda中的CUDA及Pytorch,sm_120问题
人工智能·pytorch·windows
呃呃本2 天前
算法题(矩阵)
线性代数·算法·矩阵