PyTorch 里的矩阵乘法

Pyorch 里的矩阵乘法

flyfish

输入 运算
1D × 1D dot product
2D × 2D matrix multiply
1D × 2D vector × matrix
2D × 1D matrix × vector
ND × ND batch matrix multiply

torch.mm / matmul / bmm 区别

函数 作用
torch.mm 只支持 2D矩阵乘法
torch.matmul 通用乘法
torch.bmm batch矩阵乘法

mm = matrix multiply

bmm = batch matrix multiply

matmul = matrix multiply

矩阵乘法其实是:

复制代码
行 × 列

例如

复制代码
A = [a1 a2 a3]
B = [b1
     b2
     b3]

结果:

复制代码
a1*b1 + a2*b2 + a3*b3

就是 dot product

矩阵乘法 = 行向量和列向量的 点乘集合dot product

概念 是否矩阵乘法
dot product 矩阵乘法的基本单元
cross product 不是矩阵乘法
A @ B 等价 torch.matmul
torch.mm 2D矩阵
torch.bmm batch矩阵
torch.matmul 通用
cpp 复制代码
import torch

# ==============================================
# 规则1:1D @ 1D → 点积(标量)
# 展示:逐元素相乘 → 全部相加 → 最终结果
# ==============================================
print("="*60)
print("规则1:1维向量 @ 1维向量 = 点积")
a1 = torch.tensor([1, 2, 3])
b1 = torch.tensor([4, 5, 6])
print(f"向量a1: {a1}")
print(f"向量b1: {b1}")
# 手动计算过程
print("【逐元素相乘 + 求和过程】")
print(f"1 × 4 = 4")
print(f"2 × 5 = 10")
print(f"3 × 6 = 18")
print(f"总和:4 + 10 + 18 = 32")
# 代码运算
res_at = a1 @ b1
res_eq = torch.dot(a1, b1)
print(f"@ 运算结果: {res_at}")
print(f"等价torch.dot结果: {res_eq}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则2:2D @ 2D → 标准矩阵乘法
# 展示:每行 × 每列 → 对应元素相乘 → 相加 → 矩阵元素
# ==============================================
print("="*60)
print("规则2:2维矩阵 @ 2维矩阵 = 矩阵乘法")
a2 = torch.tensor([[1, 2], [3, 4]])
b2 = torch.tensor([[5, 6], [7, 8]])
print(f"矩阵a2:\n{a2}")
print(f"矩阵b2:\n{b2}")
# 手动计算过程(逐个元素计算)
print("【逐元素计算过程:行 × 列 相乘后相加】")
print("第1行第1列:1×5 + 2×7 = 5 + 14 = 19")
print("第1行第2列:1×6 + 2×8 = 6 + 16 = 22")
print("第2行第1列:3×5 + 4×7 = 15 + 28 = 43")
print("第2行第2列:3×6 + 4×8 = 18 + 32 = 50")
# 代码运算
res_at = a2 @ b2
res_eq = torch.mm(a2, b2)
print(f"@ 运算结果:\n{res_at}")
print(f"等价torch.mm结果:\n{res_eq}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则3:1D @ 2D → 自动升维计算
# 展示:1维升维 → 行×列相乘相加 → 降维输出
# ==============================================
print("="*60)
print("规则3:1维向量 @ 2维矩阵 = 自动升维计算")
a3 = torch.tensor([1, 2])
b3 = torch.tensor([[1, 0], [0, 1]])
print(f"向量a3: {a3}")
print(f"矩阵b3:\n{b3}")
# 手动计算过程
print("【计算过程:a3升维为(1,2),再行×列相乘相加】")
print("第1列:1×1 + 2×0 = 1 + 0 = 1")
print("第2列:1×0 + 2×1 = 0 + 2 = 2")
# 代码运算
res_at = a3 @ b3
print(f"@ 运算结果: {res_at}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则4:2D @ 1D → 矩阵-向量乘法
# 展示:矩阵每行 × 向量 → 相乘相加 → 输出向量
# ==============================================
print("="*60)
print("规则4:2维矩阵 @ 1维向量 = 矩阵-向量乘法")
a4 = torch.tensor([[1, 2], [3, 4]])
b4 = torch.tensor([1, 1])
print(f"矩阵a4:\n{a4}")
print(f"向量b4: {b4}")
# 手动计算过程
print("【逐行 × 向量 相乘相加】")
print("第1个结果:1×1 + 2×1 = 1 + 2 = 3")
print("第2个结果:3×1 + 4×1 = 3 + 4 = 7")
# 代码运算
res_at = a4 @ b4
print(f"@ 运算结果: {res_at}")
print(f"结果形状: {res_at.shape}\n")

# ==============================================
# 规则5:3D @ 3D → 批量矩阵乘法
# 展示:每个批次独立计算矩阵乘法
# ==============================================
print("="*60)
print("规则5:3维批量矩阵 @ 3维批量矩阵 = 批量矩阵乘法")
a5 = torch.tensor([[[1,2],[3,4]], [[5,6],[7,8]]])
b5 = torch.tensor([[[1,0],[0,1]], [[1,0],[0,1]]])
print(f"批量矩阵a5 形状: {a5.shape}")
print(f"批量矩阵b5 形状: {b5.shape}")
# 手动计算过程(分批次计算)
print("【第1个批次计算:】")
print("1×1 + 2×0 = 1 | 1×0 + 2×1 = 2")
print("3×1 + 4×0 = 3 | 3×0 + 4×1 = 4")
print("【第2个批次计算:】")
print("5×1 + 6×0 = 5 | 5×0 + 6×1 = 6")
print("7×1 + 8×0 = 7 | 7×0 + 8×1 = 8")
# 代码运算
res_at = a5 @ b5
res_eq = torch.bmm(a5, b5)
print(f"@ 运算结果:\n{res_at}")
print(f"等价torch.bmm结果:\n{res_eq}")
print(f"结果形状: {res_at.shape}")

输出

复制代码
============================================================
规则1:1维向量 @ 1维向量 = 点积
向量a1: tensor([1, 2, 3])
向量b1: tensor([4, 5, 6])
【逐元素相乘 + 求和过程】
1 × 4 = 4
2 × 5 = 10
3 × 6 = 18
总和:4 + 10 + 18 = 32
@ 运算结果: 32
等价torch.dot结果: 32
结果形状: torch.Size([])

============================================================
规则2:2维矩阵 @ 2维矩阵 = 矩阵乘法
矩阵a2:
tensor([[1, 2],
        [3, 4]])
矩阵b2:
tensor([[5, 6],
        [7, 8]])
【逐元素计算过程:行 × 列 相乘后相加】
第1行第1列:1×5 + 2×7 = 5 + 14 = 19
第1行第2列:1×6 + 2×8 = 6 + 16 = 22
第2行第1列:3×5 + 4×7 = 15 + 28 = 43
第2行第2列:3×6 + 4×8 = 18 + 32 = 50
@ 运算结果:
tensor([[19, 22],
        [43, 50]])
等价torch.mm结果:
tensor([[19, 22],
        [43, 50]])
结果形状: torch.Size([2, 2])

============================================================
规则3:1维向量 @ 2维矩阵 = 自动升维计算
向量a3: tensor([1, 2])
矩阵b3:
tensor([[1, 0],
        [0, 1]])
【计算过程:a3升维为(1,2),再行×列相乘相加】
第1列:1×1 + 2×0 = 1 + 0 = 1
第2列:1×0 + 2×1 = 0 + 2 = 2
@ 运算结果: tensor([1, 2])
结果形状: torch.Size([2])

============================================================
规则4:2维矩阵 @ 1维向量 = 矩阵-向量乘法
矩阵a4:
tensor([[1, 2],
        [3, 4]])
向量b4: tensor([1, 1])
【逐行 × 向量 相乘相加】
第1个结果:1×1 + 2×1 = 1 + 2 = 3
第2个结果:3×1 + 4×1 = 3 + 4 = 7
@ 运算结果: tensor([3, 7])
结果形状: torch.Size([2])

============================================================
规则5:3维批量矩阵 @ 3维批量矩阵 = 批量矩阵乘法
批量矩阵a5 形状: torch.Size([2, 2, 2])
批量矩阵b5 形状: torch.Size([2, 2, 2])
【第1个批次计算:】
1×1 + 2×0 = 1 | 1×0 + 2×1 = 2
3×1 + 4×0 = 3 | 3×0 + 4×1 = 4
【第2个批次计算:】
5×1 + 6×0 = 5 | 5×0 + 6×1 = 6
7×1 + 8×0 = 7 | 7×0 + 8×1 = 8
@ 运算结果:
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
等价torch.bmm结果:
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
结果形状: torch.Size([2, 2, 2])

以上代码测试环境 PyTorch 版本: 2.8.0+cu128

单词

英文短语 词性 含义 关键词
matrix multiply 动词短语(verb phrase) 执行矩阵相乘的动作 「做乘法」
matrix multiplication 名词短语(noun phrase) 矩阵乘法这一数学运算的概念 「乘法运算本身」
matrix product 名词短语(noun phrase) 矩阵乘法的计算结果 「矩阵积」

示例

In PyTorch, semi-structured sparsity is implemented via a Tensor subclass. By subclassing, we can override __torch_dispatch__ , allowing us to use faster sparse kernels when performing matrix multiplication. We can also store the tensor in it's compressed form inside the subclass to reduce memory overhead.

Matrix product of two tensors.

If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply . After the matrix multiply, the prepended dimension is removed.

相关推荐
DeepLearningYolo3 小时前
UNet架构训练输电线路、输电杆塔、水泥杆和输电线路木头杆塔的语义分割模型检测输电线路分割
pytorch·深度学习·yolo·目标检测
kishu_iOS&AI4 小时前
深度学习 —— Pytorch
人工智能·pytorch·深度学习
脱氧核糖核酸__5 小时前
LeetCode热题100——73.矩阵置零(题目+题解+答案)
c++·算法·leetcode·矩阵
llm大模型算法工程师weng6 小时前
模型训练与知识蒸馏:从大模型到轻量级情绪分析系统
pytorch·深度学习·机器学习
郝学胜-神的一滴6 小时前
ReLU激活函数全解析:从原理到实战,解锁深度学习核心激活单元
人工智能·pytorch·python·深度学习·算法
脱氧核糖核酸__6 小时前
LeetCode热题100——54.螺旋矩阵(题解+答案+要点)
c++·算法·leetcode·矩阵
努力学习_小白20 小时前
ResNet-50——pytorch版
人工智能·pytorch·python
小辉同志1 天前
74. 搜索二维矩阵
c++·leetcode·矩阵·二分查找
mailangduoduo1 天前
实战对比PyTorch VS PyTorch Lighting以MNIST为例
人工智能·pytorch·python·深度学习·图像分类·全连接网络