日拱一卒之pytorch中的矩阵乘法

日拱一卒之pytorch中的矩阵乘法

新建一个栏目,主要的点在于每天都要有点进步,每天进步千分之一,总归会有强大的一天

view(改变形状)与matmul(矩阵乘法)

在Python的PyTorch 库(深度学习框架)中,view​ 和 matmul 是两个非常核心的张量(Tensor)操作函数。

  • view :用于改变形状(Reshape)。
  • matmul :用于矩阵乘法(Matrix Multiplication)。

1. view 函数

含义

view​ 用于将一个张量(Tensor)重塑为不同的形状,但保持数据元素的总数不变 ,且不改变底层数据的内存(它仅仅是改变了看待数据的方式,因此叫 "view")。

  • 规则:变换前后的元素总个数必须相等(例如:4x4 的矩阵有 16 个元素,可以变成 1x16, 2x8, 8x2 等,但不能变成 3x5)。
  • -1 的用法 :如果你确定了其他维度,想让 PyTorch 自动计算剩下的那个维度,可以使用 -1
举例说明
python 复制代码
import torch

# 创建一个包含 0 到 5 的一维张量 (总共 6 个元素)
x = torch.tensor([0, 1, 2, 3, 4, 5])
print("原始数据:", x.shape)  # torch.Size([6])

# 1. 将其变为 2行3列 的矩阵
y = x.view(2, 3)
print("\n变为 2x3 矩阵:\n", y)
# 输出:
# tensor([[0, 1, 2],
#         [3, 4, 5]])

# 2. 使用 -1 自动推导维度
# 假设我们想要 3 行,列数让 PyTorch 自己算
z = x.view(3, -1)
print("\n变为 3行x自动计算列:\n", z)
# 输出:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]]) (自动变成了 3x2)

# 3. 深度学习常见场景:拉平(Flatten)
# 假设有一个 batch 为 2,通道为 1,高宽为 2x2 的图像数据
img_batch = torch.rand(2, 1, 2, 2) # 形状 [2, 1, 2, 2], 总共 8 个元素
# 我们想把它拉平成 [batch_size, features] 的形式,即 [2, 4]
flattened = img_batch.view(2, -1) 
print(f"\n拉平后的形状: {flattened.shape}") # torch.Size([2, 4])

2. matmul 函数

含义

matmul​ 是 Matrix Multiplication 的缩写,执行的是标准的线性代数中的矩阵乘法(行乘以列求和)。

  • 符号 :在 Python 3.5+ 中,可以使用 @ 符号作为 matmul 的简写。
  • 区别 :它不同于普通的 * 号。* 号在 PyTorch 和 NumPy 中通常代表逐元素相乘(Element-wise multiplication)。
  • 维度规则 :如果矩阵 A 是 (N×M)(N \times M)(N×M),矩阵 B 是 (M×K)(M \times K)(M×K),那么 matmul(A, B) 的结果是 (N×K)(N \times K)(N×K)。中间的维度 MMM 必须对齐。
举例说明
python 复制代码
import torch

# 创建两个矩阵
# A: 2行3列
A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# B: 3行2列
B = torch.tensor([[7, 8],
                  [9, 10],
                  [11, 12]])

# --- 使用 matmul 进行矩阵乘法 ---
# 运算逻辑: (2x3) 乘以 (3x2) -> 结果应该是 (2x2)
# 例如结果左上角的数计算方式: 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
result = torch.matmul(A, B)
print("矩阵乘法结果 (matmul):\n", result)
# 或者使用 @ 符号 (推荐写法)
result_at = A @ B 
print("使用 @ 符号的结果:\n", result_at)

# 输出:
# tensor([[ 58,  64],
#         [139, 154]])

# --- 对比:使用 * 号 (逐元素相乘) ---
# 逐元素相乘要求形状必须相同,或者满足广播机制
C = torch.tensor([[1, 2], [3, 4]])
D = torch.tensor([[1, 0], [0, 1]])

print("\n逐元素相乘 (*):\n", C * D)
# 输出:
# tensor([[1, 0],
#         [0, 4]])
# (对应位置简单的相乘:1*1, 2*0, 3*0, 4*1)

三维向量的矩阵乘法

在PyTorch中,当对两个三维张量 进行矩阵乘法(matmul​ 或 @​)时,执行的操作被称为 批量矩阵乘法 (Batch Matrix Multiplication)

简单来说,PyTorch 会把第一个维度看作"批次(Batch)" ,不动它,然后对剩下的两个维度进行标准的 2D 矩阵乘法。

1. 计算规则与形状变化

假设有两个三维张量 AAA 和 BBB:

  • 张量 A 的形状 :(B,N,M)(B, N, M)(B,N,M)
  • 张量 B 的形状 :(B,M,K)(B, M, K)(B,M,K)

计算逻辑:

  1. 第一个维度 BBB (Batch Size) :必须相等(或者其中一个是1,触发广播机制)。PyTorch 会在这个维度上逐个取出矩阵。
  2. 后两个维度 :执行标准的矩阵乘法 (N×M)(N \times M)(N×M) 乘以 (M×K)(M \times K)(M×K)。中间的维度 MMM 必须对齐。
  3. 结果形状 :(B,N,K)(B, N, K)(B,N,K)。

可以理解为下面的伪代码循环:

python 复制代码
# 假设 Result 是结果张量
for i in range(B):
    # 取出第 i 个切片(是一个 2D 矩阵)进行乘法
    Result[i] = matmul(A[i], B[i]) 

2. 举例说明

示例场景
  • 我们有 2 个批次的数据。
  • 每个数据是一个 2行3列 的矩阵。
  • 我们要乘以一个 3行2列 的矩阵。

形状变换 :(2,2,3)×(2,3,2)→(2,2,2)(2, 2, 3) \times (2, 3, 2) \rightarrow (2, 2, 2)(2,2,3)×(2,3,2)→(2,2,2)

代码演示
python 复制代码
import torch

# --- 1. 定义数据 ---
# batch_size=2, 形状 (2, 3)
input1 = torch.tensor([
    # 第1个矩阵 (Batch 0)
    [[1, 0, 0],
     [0, 1, 0]],
    
    # 第2个矩阵 (Batch 1)
    [[1, 2, 3],
     [4, 5, 6]]
]) 
print(f"Input1 形状: {input1.shape}") # torch.Size([2, 2, 3])

# batch_size=2, 形状 (3, 2)
input2 = torch.tensor([
    # 第1个矩阵 (Batch 0) - 全1矩阵以便观察
    [[1, 1],
     [1, 1],
     [1, 1]],
    
    # 第2个矩阵 (Batch 1) - 单位矩阵类似的结构
    [[1, 0],
     [0, 1],
     [0, 0]]
])
print(f"Input2 形状: {input2.shape}") # torch.Size([2, 3, 2])

# --- 2. 执行矩阵乘法 ---
# 方式一:使用 matmul
output = torch.matmul(input1, input2)
# 方式二:使用 @ 符号
output_at = input1 @ input2
# 方式三:专门针对3D的 bmm (Batch Matrix Matrix product)
output_bmm = torch.bmm(input1, input2)

print("\n--- 结果 ---")
print(f"结果形状: {output.shape}") # torch.Size([2, 2, 2])
print(output)
手动推导验证结果

Batch 0 的计算 (对应结果的 output[0]):

  • 左边 :[100010]\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}[100100] (看起来像截断的单位矩阵)

  • 右边 :[111111]\begin{bmatrix} 1 & 1 \\ 1 & 1 \\ 1 & 1 \end{bmatrix} 111111

  • 计算

    • 第一行 x 第一列:1∗1+0∗1+0∗1=11*1 + 0*1 + 0*1 = 11∗1+0∗1+0∗1=1
    • ...以此类推
  • Batch 0 结果 :[1111]\begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}[1111]

Batch 1 的计算 (对应结果的 output[1]):

  • 左边 :[123456]\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}[142536]

  • 右边 :[100100]\begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix} 100010

  • 计算

    • 左上:1∗1+2∗0+3∗0=11*1 + 2*0 + 3*0 = 11∗1+2∗0+3∗0=1
    • 右上:1∗0+2∗1+3∗0=21*0 + 2*1 + 3*0 = 21∗0+2∗1+3∗0=2
    • 左下:4∗1+5∗0+6∗0=44*1 + 5*0 + 6*0 = 44∗1+5∗0+6∗0=4
    • 右下:4∗0+5∗1+6∗0=54*0 + 5*1 + 6*0 = 54∗0+5∗1+6∗0=5
  • Batch 1 结果 :[1245]\begin{bmatrix} 1 & 2 \\ 4 & 5 \end{bmatrix}[1425]

相关推荐
多恩Stone13 小时前
【3DV 进阶-9】Hunyuan3D2.1 中的 MoE
人工智能·pytorch·python·算法·aigc
shayudiandian13 小时前
TensorFlow vs PyTorch:哪个更适合你?
人工智能·pytorch·tensorflow
Keep_Trying_Go14 小时前
基于Transformer的目标统计方法(CounTR: Transformer-based Generalised Visual Counting)
人工智能·pytorch·python·深度学习·transformer·多模态·目标统计
Aspect of twilight1 天前
PyTorch DDP分布式训练Pytorch代码讲解
人工智能·pytorch·python
tomeasure1 天前
INTERNAL ASSERT FAILED at “/pytorch/c10/cuda/CUDACachingAllocator.cpp“:983
人工智能·pytorch·python·nvidia
一瞬祈望1 天前
【环境配置】Windows 下使用 Anaconda 创建 Python 3.8 环境 + 安装 PyTorch + CUDA(完整教程)
pytorch·windows·python
喜乐boy1 天前
CV系列——Conda + PyTorch + CUDA + cuDNN + Python 环境无脑安装速查笔记[2025.12]
pytorch·python·conda·cuda·cv
过尽漉雪千山1 天前
Anaconda的虚拟环境下使用清华源镜像安装Pytorch
人工智能·pytorch·python·深度学习·机器学习
weixin_457760001 天前
GIOU (Generalized Intersection over Union) 详解
pytorch·python