Matrix Multiplication in PyTorch
flyfish

| Inputs | Operation |
|---|---|
| 1D × 1D | dot product |
| 2D × 2D | matrix multiply |
| 1D × 2D | vector × matrix |
| 2D × 1D | matrix × vector |
| ND × ND | batched matrix multiply (with broadcasting) |
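These shape rules can be checked directly. A quick sketch (the tensor values here are arbitrary placeholders; only the shapes matter):

```python
import torch

v = torch.ones(3)          # 1D vector
m = torch.ones(3, 3)       # 2D matrix
b = torch.ones(2, 3, 3)    # 3D batch of matrices

print((v @ v).shape)       # 1D x 1D -> scalar: torch.Size([])
print((m @ m).shape)       # 2D x 2D -> torch.Size([3, 3])
print((v @ m).shape)       # 1D x 2D -> torch.Size([3])
print((m @ v).shape)       # 2D x 1D -> torch.Size([3])
print((b @ b).shape)       # ND x ND -> torch.Size([2, 3, 3])
```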
Differences between torch.mm, torch.matmul, and torch.bmm
| Function | Purpose |
|---|---|
| torch.mm | 2D matrix multiplication only |
| torch.matmul | general-purpose multiplication (handles all the cases above) |
| torch.bmm | batch matrix multiplication (3D only) |
mm = matrix multiply
bmm = batch matrix multiply
matmul = matrix multiply
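A minimal sketch of the difference in practice (shapes chosen arbitrarily): torch.mm accepts only 2D inputs and torch.bmm only 3D inputs, while torch.matmul covers both.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)
print(torch.mm(a, b).shape)      # torch.Size([2, 4])
print(torch.matmul(a, b).shape)  # same result for 2D inputs

x = torch.randn(5, 2, 3)
y = torch.randn(5, 3, 4)
print(torch.bmm(x, y).shape)     # torch.Size([5, 2, 4])
print(torch.matmul(x, y).shape)  # matmul handles the batch case too

# torch.mm rejects anything that is not 2D
try:
    torch.mm(x, y)
except RuntimeError:
    print("torch.mm on 3D tensors raises RuntimeError")
```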
Matrix multiplication is, at its core, row × column. For example, given a row vector and a column vector:

A = [a1 a2 a3]
B = [b1
     b2
     b3]

the result is:

a1*b1 + a2*b2 + a3*b3

which is exactly a dot product. Matrix multiplication is a collection of dot products between the rows of the first matrix and the columns of the second.
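This view can be checked in code. A small sketch confirming that each entry of A @ B is the dot product of a row of A with a column of B:

```python
import torch

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
C = A @ B

# every C[i, j] equals the dot product of row i of A with column j of B
for i in range(2):
    for j in range(2):
        assert C[i, j] == torch.dot(A[i, :], B[:, j])
print(C)  # tensor([[19., 22.], [43., 50.]])
```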
| Concept | Relation to matrix multiplication |
|---|---|
| dot product | the basic building block of matrix multiplication |
| cross product | not a matrix multiplication |
| A @ B | equivalent to torch.matmul(A, B) |
| torch.mm | 2D matrices only |
| torch.bmm | batched 3D matrices |
| torch.matmul | general-purpose |
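To make the dot/cross distinction concrete, a small sketch (using torch.linalg.cross, available in recent PyTorch versions):

```python
import torch

u = torch.tensor([1., 0., 0.])
v = torch.tensor([0., 1., 0.])

print(torch.dot(u, v))           # dot product -> scalar: tensor(0.)
print(u @ v)                     # @ on 1D tensors is the dot product
print(torch.linalg.cross(u, v))  # cross product -> tensor([0., 0., 1.]),
                                 # a vector, not a matrix multiplication
```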
```python
import torch

# ==============================================
# Rule 1: 1D @ 1D -> dot product (scalar)
# Shows: elementwise multiply -> sum -> final result
# ==============================================
print("=" * 60)
print("Rule 1: 1D vector @ 1D vector = dot product")
a1 = torch.tensor([1, 2, 3])
b1 = torch.tensor([4, 5, 6])
print(f"vector a1: {a1}")
print(f"vector b1: {b1}")

# Manual computation
print("[Elementwise multiply, then sum]")
print("1 x 4 = 4")
print("2 x 5 = 10")
print("3 x 6 = 18")
print("sum: 4 + 10 + 18 = 32")

# PyTorch computation
res_at = a1 @ b1
res_eq = torch.dot(a1, b1)
print(f"@ result: {res_at}")
print(f"equivalent torch.dot result: {res_eq}")
print(f"result shape: {res_at.shape}\n")

# ==============================================
# Rule 2: 2D @ 2D -> standard matrix multiplication
# Shows: each row x each column -> multiply -> sum -> one entry
# ==============================================
print("=" * 60)
print("Rule 2: 2D matrix @ 2D matrix = matrix multiplication")
a2 = torch.tensor([[1, 2], [3, 4]])
b2 = torch.tensor([[5, 6], [7, 8]])
print(f"matrix a2:\n{a2}")
print(f"matrix b2:\n{b2}")

# Manual computation (entry by entry)
print("[Entry-by-entry: row x column, multiply then sum]")
print("row 1, col 1: 1x5 + 2x7 = 5 + 14 = 19")
print("row 1, col 2: 1x6 + 2x8 = 6 + 16 = 22")
print("row 2, col 1: 3x5 + 4x7 = 15 + 28 = 43")
print("row 2, col 2: 3x6 + 4x8 = 18 + 32 = 50")

# PyTorch computation
res_at = a2 @ b2
res_eq = torch.mm(a2, b2)
print(f"@ result:\n{res_at}")
print(f"equivalent torch.mm result:\n{res_eq}")
print(f"result shape: {res_at.shape}\n")

# ==============================================
# Rule 3: 1D @ 2D -> automatic promotion
# Shows: promote the 1D vector -> row x column -> drop the added dim
# ==============================================
print("=" * 60)
print("Rule 3: 1D vector @ 2D matrix = automatic promotion")
a3 = torch.tensor([1, 2])
b3 = torch.tensor([[1, 0], [0, 1]])
print(f"vector a3: {a3}")
print(f"matrix b3:\n{b3}")

# Manual computation
print("[a3 is promoted to shape (1, 2), then row x column as usual]")
print("col 1: 1x1 + 2x0 = 1 + 0 = 1")
print("col 2: 1x0 + 2x1 = 0 + 2 = 2")

# PyTorch computation
res_at = a3 @ b3
print(f"@ result: {res_at}")
print(f"result shape: {res_at.shape}\n")

# ==============================================
# Rule 4: 2D @ 1D -> matrix-vector multiplication
# Shows: each matrix row x the vector -> multiply then sum -> output vector
# ==============================================
print("=" * 60)
print("Rule 4: 2D matrix @ 1D vector = matrix-vector multiplication")
a4 = torch.tensor([[1, 2], [3, 4]])
b4 = torch.tensor([1, 1])
print(f"matrix a4:\n{a4}")
print(f"vector b4: {b4}")

# Manual computation
print("[Each row x the vector, multiply then sum]")
print("entry 1: 1x1 + 2x1 = 1 + 2 = 3")
print("entry 2: 3x1 + 4x1 = 3 + 4 = 7")

# PyTorch computation
res_at = a4 @ b4
print(f"@ result: {res_at}")
print(f"result shape: {res_at.shape}\n")

# ==============================================
# Rule 5: 3D @ 3D -> batch matrix multiplication
# Shows: each batch element is an independent matrix multiplication
# ==============================================
print("=" * 60)
print("Rule 5: 3D batch @ 3D batch = batch matrix multiplication")
a5 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b5 = torch.tensor([[[1, 0], [0, 1]], [[1, 0], [0, 1]]])
print(f"batch a5 shape: {a5.shape}")
print(f"batch b5 shape: {b5.shape}")

# Manual computation (batch by batch)
print("[Batch 1:]")
print("1x1 + 2x0 = 1 | 1x0 + 2x1 = 2")
print("3x1 + 4x0 = 3 | 3x0 + 4x1 = 4")
print("[Batch 2:]")
print("5x1 + 6x0 = 5 | 5x0 + 6x1 = 6")
print("7x1 + 8x0 = 7 | 7x0 + 8x1 = 8")

# PyTorch computation
res_at = a5 @ b5
res_eq = torch.bmm(a5, b5)
print(f"@ result:\n{res_at}")
print(f"equivalent torch.bmm result:\n{res_eq}")
print(f"result shape: {res_at.shape}")
```
Output:

```
============================================================
Rule 1: 1D vector @ 1D vector = dot product
vector a1: tensor([1, 2, 3])
vector b1: tensor([4, 5, 6])
[Elementwise multiply, then sum]
1 x 4 = 4
2 x 5 = 10
3 x 6 = 18
sum: 4 + 10 + 18 = 32
@ result: 32
equivalent torch.dot result: 32
result shape: torch.Size([])

============================================================
Rule 2: 2D matrix @ 2D matrix = matrix multiplication
matrix a2:
tensor([[1, 2],
        [3, 4]])
matrix b2:
tensor([[5, 6],
        [7, 8]])
[Entry-by-entry: row x column, multiply then sum]
row 1, col 1: 1x5 + 2x7 = 5 + 14 = 19
row 1, col 2: 1x6 + 2x8 = 6 + 16 = 22
row 2, col 1: 3x5 + 4x7 = 15 + 28 = 43
row 2, col 2: 3x6 + 4x8 = 18 + 32 = 50
@ result:
tensor([[19, 22],
        [43, 50]])
equivalent torch.mm result:
tensor([[19, 22],
        [43, 50]])
result shape: torch.Size([2, 2])

============================================================
Rule 3: 1D vector @ 2D matrix = automatic promotion
vector a3: tensor([1, 2])
matrix b3:
tensor([[1, 0],
        [0, 1]])
[a3 is promoted to shape (1, 2), then row x column as usual]
col 1: 1x1 + 2x0 = 1 + 0 = 1
col 2: 1x0 + 2x1 = 0 + 2 = 2
@ result: tensor([1, 2])
result shape: torch.Size([2])

============================================================
Rule 4: 2D matrix @ 1D vector = matrix-vector multiplication
matrix a4:
tensor([[1, 2],
        [3, 4]])
vector b4: tensor([1, 1])
[Each row x the vector, multiply then sum]
entry 1: 1x1 + 2x1 = 1 + 2 = 3
entry 2: 3x1 + 4x1 = 3 + 4 = 7
@ result: tensor([3, 7])
result shape: torch.Size([2])

============================================================
Rule 5: 3D batch @ 3D batch = batch matrix multiplication
batch a5 shape: torch.Size([2, 2, 2])
batch b5 shape: torch.Size([2, 2, 2])
[Batch 1:]
1x1 + 2x0 = 1 | 1x0 + 2x1 = 2
3x1 + 4x0 = 3 | 3x0 + 4x1 = 4
[Batch 2:]
5x1 + 6x0 = 5 | 5x0 + 6x1 = 6
7x1 + 8x0 = 7 | 7x0 + 8x1 = 8
@ result:
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
equivalent torch.bmm result:
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
result shape: torch.Size([2, 2, 2])
```
The code above was tested with PyTorch version 2.8.0+cu128.
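One case the five rules above do not show: torch.matmul also broadcasts batch dimensions, which torch.bmm cannot do. A small sketch (shapes chosen arbitrarily):

```python
import torch

# matmul broadcasts batch dimensions; bmm does not
a = torch.randn(4, 2, 3)   # batch of 4 matrices
w = torch.randn(3, 5)      # a single shared weight matrix

out = torch.matmul(a, w)   # w is broadcast across the batch
print(out.shape)           # torch.Size([4, 2, 5])

# the equivalent bmm call requires expanding w by hand
out2 = torch.bmm(a, w.expand(4, 3, 5))
print(torch.allclose(out, out2))  # True
```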
Vocabulary
| English phrase | Part of speech | Meaning | Key idea |
|---|---|---|---|
| matrix multiply | verb phrase | the action of multiplying matrices | "do the multiplication" |
| matrix multiplication | noun phrase | the mathematical operation of multiplying matrices | "the operation itself" |
| matrix product | noun phrase | the result of a matrix multiplication | "the product" |
Examples:

In PyTorch, semi-structured sparsity is implemented via a Tensor subclass. By subclassing, we can override `__torch_dispatch__`, allowing us to use faster sparse kernels when performing matrix multiplication. We can also store the tensor in its compressed form inside the subclass to reduce memory overhead.

Matrix product of two tensors.

If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
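The promotion rule quoted above can be reproduced by hand. A small sketch using unsqueeze/squeeze:

```python
import torch

v = torch.tensor([1., 2.])
M = torch.tensor([[3., 4.], [5., 6.]])

# matmul treats v as shape (1, 2), multiplies, then drops the prepended dim
direct = v @ M
manual = (v.unsqueeze(0) @ M).squeeze(0)
print(direct)                       # tensor([13., 16.])
print(torch.equal(direct, manual))  # True
```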