日拱一卒之pytorch中的矩阵乘法
新建一个栏目,主要的点在于每天都要有点进步,每天进步千分之一,总归会有强大的一天
view(改变形状)与matmul(矩阵乘法)
在Python的PyTorch 库(深度学习框架)中,view 和 matmul 是两个非常核心的张量(Tensor)操作函数。
-
view :用于改变形状(Reshape)。 -
matmul :用于矩阵乘法(Matrix Multiplication)。
1. view 函数
含义
view 用于将一个张量(Tensor)重塑为不同的形状,但保持数据元素的总数不变 ,且不改变底层数据的内存(它仅仅是改变了看待数据的方式,因此叫 "view")。
- 规则:变换前后的元素总个数必须相等(例如:4x4 的矩阵有 16 个元素,可以变成 1x16, 2x8, 8x2 等,但不能变成 3x5)。
-
-1 的用法 :如果你确定了其他维度,想让 PyTorch 自动计算剩下的那个维度,可以使用-1。
举例说明
python
import torch
# 创建一个包含 0 到 5 的一维张量 (总共 6 个元素)
x = torch.tensor([0, 1, 2, 3, 4, 5])
print("原始数据:", x.shape) # torch.Size([6])
# 1. 将其变为 2行3列 的矩阵
y = x.view(2, 3)
print("\n变为 2x3 矩阵:\n", y)
# 输出:
# tensor([[0, 1, 2],
# [3, 4, 5]])
# 2. 使用 -1 自动推导维度
# 假设我们想要 3 行,列数让 PyTorch 自己算
z = x.view(3, -1)
print("\n变为 3行x自动计算列:\n", z)
# 输出:
# tensor([[0, 1],
# [2, 3],
# [4, 5]]) (自动变成了 3x2)
# 3. 深度学习常见场景:拉平(Flatten)
# 假设有一个 batch 为 2,通道为 1,高宽为 2x2 的图像数据
img_batch = torch.rand(2, 1, 2, 2) # 形状 [2, 1, 2, 2], 总共 8 个元素
# 我们想把它拉平成 [batch_size, features] 的形式,即 [2, 4]
flattened = img_batch.view(2, -1)
print(f"\n拉平后的形状: {flattened.shape}") # torch.Size([2, 4])
2. matmul 函数
含义
matmul 是 Matrix Multiplication 的缩写,执行的是标准的线性代数中的矩阵乘法(行乘以列求和)。
- 符号 :在 Python 3.5+ 中,可以使用
@符号作为matmul的简写。 - 区别 :它不同于普通的
*号。*号在 PyTorch 和 NumPy 中通常代表逐元素相乘(Element-wise multiplication)。 - 维度规则 :如果矩阵 A 是 (N×M)(N \times M)(N×M),矩阵 B 是 (M×K)(M \times K)(M×K),那么
matmul(A, B)的结果是 (N×K)(N \times K)(N×K)。中间的维度 MMM 必须对齐。
举例说明
python
import torch
# 创建两个矩阵
# A: 2行3列
A = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# B: 3行2列
B = torch.tensor([[7, 8],
[9, 10],
[11, 12]])
# --- 使用 matmul 进行矩阵乘法 ---
# 运算逻辑: (2x3) 乘以 (3x2) -> 结果应该是 (2x2)
# 例如结果左上角的数计算方式: 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
result = torch.matmul(A, B)
print("矩阵乘法结果 (matmul):\n", result)
# 或者使用 @ 符号 (推荐写法)
result_at = A @ B
print("使用 @ 符号的结果:\n", result_at)
# 输出:
# tensor([[ 58, 64],
# [139, 154]])
# --- 对比:使用 * 号 (逐元素相乘) ---
# 逐元素相乘要求形状必须相同,或者满足广播机制
C = torch.tensor([[1, 2], [3, 4]])
D = torch.tensor([[1, 0], [0, 1]])
print("\n逐元素相乘 (*):\n", C * D)
# 输出:
# tensor([[1, 0],
# [0, 4]])
# (对应位置简单的相乘:1*1, 2*0, 3*0, 4*1)
三维向量的矩阵乘法
在PyTorch中,当对两个三维张量 进行矩阵乘法(matmul 或 @)时,执行的操作被称为 批量矩阵乘法 (Batch Matrix Multiplication) 。
简单来说,PyTorch 会把第一个维度看作"批次(Batch)" ,不动它,然后对剩下的两个维度进行标准的 2D 矩阵乘法。
1. 计算规则与形状变化
假设有两个三维张量 AAA 和 BBB:
- 张量 A 的形状 :(B,N,M)(B, N, M)(B,N,M)
- 张量 B 的形状 :(B,M,K)(B, M, K)(B,M,K)
计算逻辑:
- 第一个维度 BBB (Batch Size) :必须相等(或者其中一个是1,触发广播机制)。PyTorch 会在这个维度上逐个取出矩阵。
- 后两个维度 :执行标准的矩阵乘法 (N×M)(N \times M)(N×M) 乘以 (M×K)(M \times K)(M×K)。中间的维度 MMM 必须对齐。
- 结果形状 :(B,N,K)(B, N, K)(B,N,K)。
可以理解为下面的伪代码循环:
python
# 假设 Result 是结果张量
for i in range(B):
# 取出第 i 个切片(是一个 2D 矩阵)进行乘法
Result[i] = matmul(A[i], B[i])
2. 举例说明
示例场景
- 我们有
2个批次的数据。 - 每个数据是一个
2行3列的矩阵。 - 我们要乘以一个
3行2列的矩阵。
形状变换 :(2,2,3)×(2,3,2)→(2,2,2)(2, 2, 3) \times (2, 3, 2) \rightarrow (2, 2, 2)(2,2,3)×(2,3,2)→(2,2,2)
代码演示
python
import torch
# --- 1. 定义数据 ---
# batch_size=2, 形状 (2, 3)
input1 = torch.tensor([
# 第1个矩阵 (Batch 0)
[[1, 0, 0],
[0, 1, 0]],
# 第2个矩阵 (Batch 1)
[[1, 2, 3],
[4, 5, 6]]
])
print(f"Input1 形状: {input1.shape}") # torch.Size([2, 2, 3])
# batch_size=2, 形状 (3, 2)
input2 = torch.tensor([
# 第1个矩阵 (Batch 0) - 全1矩阵以便观察
[[1, 1],
[1, 1],
[1, 1]],
# 第2个矩阵 (Batch 1) - 单位矩阵类似的结构
[[1, 0],
[0, 1],
[0, 0]]
])
print(f"Input2 形状: {input2.shape}") # torch.Size([2, 3, 2])
# --- 2. 执行矩阵乘法 ---
# 方式一:使用 matmul
output = torch.matmul(input1, input2)
# 方式二:使用 @ 符号
output_at = input1 @ input2
# 方式三:专门针对3D的 bmm (Batch Matrix Matrix product)
output_bmm = torch.bmm(input1, input2)
print("\n--- 结果 ---")
print(f"结果形状: {output.shape}") # torch.Size([2, 2, 2])
print(output)
手动推导验证结果
Batch 0 的计算 (对应结果的 output[0]):
-
左边 :[100010]\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}[100100] (看起来像截断的单位矩阵)
-
右边 :[111111]\begin{bmatrix} 1 & 1 \\ 1 & 1 \\ 1 & 1 \end{bmatrix} 111111
-
计算:
- 第一行 x 第一列:1∗1+0∗1+0∗1=11*1 + 0*1 + 0*1 = 11∗1+0∗1+0∗1=1
- ...以此类推
-
Batch 0 结果 :[1111]\begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}[1111]
Batch 1 的计算 (对应结果的 output[1]):
-
左边 :[123456]\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}[142536]
-
右边 :[100100]\begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix} 100010
-
计算:
- 左上:1∗1+2∗0+3∗0=11*1 + 2*0 + 3*0 = 11∗1+2∗0+3∗0=1
- 右上:1∗0+2∗1+3∗0=21*0 + 2*1 + 3*0 = 21∗0+2∗1+3∗0=2
- 左下:4∗1+5∗0+6∗0=44*1 + 5*0 + 6*0 = 44∗1+5∗0+6∗0=4
- 右下:4∗0+5∗1+6∗0=54*0 + 5*1 + 6*0 = 54∗0+5∗1+6∗0=5
-
Batch 1 结果 :[1245]\begin{bmatrix} 1 & 2 \\ 4 & 5 \end{bmatrix}[1425]