详解 @符号在 PyTorch 中的矩阵乘法规则

详解 @ 符号在 PyTorch 中的矩阵乘法规则

在 PyTorch 和 NumPy 中,@ 符号被用作矩阵乘法运算符,它本质上等价于 torch.matmul()numpy.matmul(),用于执行张量之间的矩阵乘法。

在本篇博客中,我们将深入探讨:

  • @ 运算符的基本概念
  • @ 在不同维度张量上的计算规则
  • @(d, k) @ (d, 1) 这种情况下的运算细节
  • PyTorch 自动广播机制
  • 代码示例与直观理解

1. 什么是 @

在 Python 3.5 之后,@ 被引入作为 矩阵乘法运算符 ,它在 NumPyPyTorch 中与 matmul() 等价。例如:

python 复制代码
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5], [6]])

C = A @ B  # 矩阵乘法
print(C)

输出:

[[17]
 [39]]

等价于

python 复制代码
C = np.matmul(A, B)

PyTorch 中,@ 也适用于张量计算:

python 复制代码
import torch
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
B = torch.tensor([[5], [6]], dtype=torch.float32)

C = A @ B  # PyTorch 版本的矩阵乘法
print(C)

2. @ 在不同维度张量上的计算规则

2.1 规则概述

@ 的运算规则依赖于输入张量的维度:

  1. 两个标量(0D):返回标量
  2. 标量和张量:标量与张量的元素逐个相乘
  3. 一维向量(1D)
    • (N,) @ (N,) → 标量(点积)
    • (N,) @ (N, M) → (M,)(左向量 × 矩阵)
    • (N, M) @ (M,) → (N,)(矩阵 × 右向量)
  4. 二维矩阵(2D)
    • (N, M) @ (M, K) → (N, K)(标准矩阵乘法)
  5. 高维张量(≥3D)
    • (A, B, C) @ (C, D) → (A, B, D)(批量矩阵乘法)

3. 重点解析 (d, k) @ (d, 1)

PyTorch 中,如果 A.shape = (d, k)B.shape = (d, 1)A @ B非法操作 ,因为矩阵乘法要求 A 的列数(k)等于 B 的行数(d) ,但这里 B 的形状 (d, 1) 无法与 (d, k) 匹配。

3.1 (d, k) @ (d, 1) 为什么不合法?

假设:

python 复制代码
import torch
d, k = 4, 3

A = torch.randn(d, k)  # (4, 3)
B = torch.randn(d, 1)  # (4, 1)

C = A @ B  # ❌ 错误:形状不匹配

会报错:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x1)

原因:

  • 矩阵乘法规则: A 的列数(k)必须等于 B 的行数(d)。
  • (d, k) @ (d, 1) 不符合 这个规则,因为 d ≠ k

3.2 如何让 (d, k) @ (d, 1) 变成合法操作?

我们需要 调整矩阵的形状,使其满足矩阵乘法的规则。

方法 1:交换操作数顺序

如果计算 B.T @ A

python 复制代码
C = B.T @ A  # shape (1, d) @ (d, k) → (1, k)

就变成了合法操作。

方法 2:转置 A

如果我们计算:

python 复制代码
C = A.T @ B  # shape (k, d) @ (d, 1) → (k, 1)

这个计算是 合法的 ,因为 A.T.shape = (k, d)B.shape = (d, 1),满足矩阵乘法规则。

示例:

python 复制代码
C = A.T @ B  # (k, d) @ (d, 1) → (k, 1)

现在 A.T 变成 (k, d)B 仍然是 (d, 1),最终 C 的形状是 (k, 1)


3.3 PyTorch 如何正确处理 (d, k) @ (d,)

在 PyTorch 代码中,我们常见这样的计算:

python 复制代码
q = P_q @ x  # (h, d, k) @ (d,)

为什么这里不需要转置 P_q

  • x.shape = (d,),PyTorch 自动扩展为 (d, 1) 使其成为列向量
  • 计算 (d, k) @ (d, 1)非法的,PyTorch 自动调整计算规则
  • PyTorch 实际执行的是 P_q.T @ x,确保计算正确
  • 最终返回 (h, k),去掉了多余的维度

因此 PyTorch 不需要我们手动转置 P_q,它会自动处理 x 为列向量进行计算!


4. 代码示例

python 复制代码
import torch

d, k = 4, 3
torch.manual_seed(42)

A = torch.randn(d, k)  # (4, 3)
x = torch.randn(d)     # (4,)

# PyTorch 自动扩展 x,使其符合矩阵乘法规则
C = A.T @ x  # (k, d) @ (d,) → (k,)

print("A shape:", A.shape)  # (4, 3)
print("x shape:", x.shape)  # (4,)
print("C shape:", C.shape)  # (3,)

5. 结论

  • @矩阵乘法运算符 ,等价于 torch.matmul(A, B)
  • (d, k) @ (d, 1) 是不合法的矩阵乘法
  • PyTorch 会自动扩展 (d,) → (d, 1) 并进行正确的矩阵计算
  • (d, k) @ (d,) 实际等价于 (k, d) @ (d, 1),避免了显式转置

🚀 PyTorch 的 @ 计算规则很智能,能够自动扩展维度,让矩阵乘法符合数学规则! 🎯

q = P_q @ x 计算中,P_q.T 转置的是哪个维度?如何判断?

在 PyTorch 代码:

python 复制代码
q = P_q @ x  # (h, d, k) @ (d,)

核心问题

  • P_q.shape = (h, d, k)
  • x.shape = (d,)

为什么 不需要手动转置 P_q ?以及 PyTorch 在计算 P_q @ x 时转置了哪个维度


1. @ 运算规则

PyTorch 处理 torch.matmul(A, B) 时,遵循 广播机制矩阵乘法规则

  1. 最后两个维度 参与矩阵乘法
  2. 如果 B 是 1D 张量 (即 B.shape = (d,)),PyTorch 会自动扩展为 (d, 1) 但不会影响计算逻辑

2. q = P_q @ x 具体计算

2.1 P_q.shape = (h, d, k), x.shape = (d,)

按照 PyTorch 规则:

  1. 扩展 x 形状

    • x.shape = (d,) 自动扩展为 (d, 1),使其符合矩阵乘法规则:
    python 复制代码
    x = x.unsqueeze(-1)  # (d,) → (d, 1)
  2. 选择 P_q 参与矩阵乘法的维度

    • P_q.shape = (h, d, k),表示:
      • h:注意力头数(不参与矩阵计算)
      • d:输入维度(x 匹配
      • k:查询维度(计算目标)
    • P_q @ x 的计算目标是:
      ( h , d , k ) @ ( d , 1 ) (h, d, k) @ (d, 1) (h,d,k)@(d,1)
      需要 P_q d 维度与 xd 维度对齐,才能进行矩阵乘法。

2.2 PyTorch 自动调整 P_q 计算方式

PyTorch 不会转置完整的 P_q ,但会 调整最后两个维度 (d, k) 进行计算

  • 等价于
    q = ( h , k , d ) @ ( d , 1 ) = ( h , k , 1 ) q = (h, k, d) @ (d, 1) = (h, k, 1) q=(h,k,d)@(d,1)=(h,k,1)

  • 等价于

    python 复制代码
    q = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1))  # shape (h, k, 1)

    其中 P_q.transpose(-2, -1) 交换 (d, k)(k, d)

最终 PyTorch 计算:

python 复制代码
q = (h, d, k) @ (d,) = (h, k)

其中 PyTorch 自动去除了 1 维度 ,返回 (h, k),而不是 (h, k, 1)


3. 如何判断 PyTorch 进行了哪些维度调整?

我们可以用 transpose()matmul() 手动验证

python 复制代码
import torch

h, d, k = 2, 4, 3  # 2 个注意力头, 输入维度 4, 投影到 3 维
torch.manual_seed(42)

P_q = torch.randn(h, d, k)  # shape (h, d, k)
x = torch.randn(d)  # shape (d,)

# PyTorch 计算
q1 = P_q @ x  # (h, d, k) @ (d,) → (h, k)

# 手动转置 + matmul
q2 = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1)).squeeze(-1)  # (h, k)

print("q1 shape:", q1.shape)  # (h, k)
print("q2 shape:", q2.shape)  # (h, k)
print(torch.allclose(q1, q2))  # True

结果:

q1 shape: torch.Size([2, 3])
q2 shape: torch.Size([2, 3])
True

说明 PyTorch 自动进行了 P_q.transpose(-2, -1),使 d 维度匹配 xd 维度


4. 结论

💡 PyTorch 只会转置 P_qd, k 维度,确保矩阵乘法合法,但不会改变 h 维度

判断 PyTorch 何时自动调整维度

操作 等效 PyTorch 计算
(d, k) @ (d,) 自动转置 (d, k)(k, d), 计算 (k, d) @ (d, 1)
(h, d, k) @ (d,) 自动调整 (d, k)(k, d), 计算 (h, k, d) @ (d, 1)
(d, k) @ (k, 1) 直接符合矩阵乘法规则,正常计算
(h, d, k) @ (k, 1) 符合矩阵乘法规则,正常计算

5. 关键点总结

P_qd, k 维度会被 PyTorch 自动调整,以匹配 x.shape = (d,)

PyTorch 计算 (h, d, k) @ (d,),本质等价于 P_q.transpose(-2, -1) @ x.unsqueeze(-1)

最终 q.shape = (h, k),符合多头注意力计算要求

🚀 PyTorch 的 @ 操作非常智能,会自动调整张量的形状,使矩阵乘法符合数学规则! 🎯

后记

2025年2月23日07点49分于上海,在GPT4o大模型辅助下完成。

相关推荐
糖葫芦君15 分钟前
TD时间差分算法
人工智能·算法
邹霍梁@开源软件GoodERP42 分钟前
【AI+智造】DeepSeek价值重构:当采购与物控遇上数字化转型的化学反应
运维·人工智能·制造
zhulu5062 小时前
PyTorch 源码学习:Dispatch & Autograd & Operators
人工智能·pytorch·学习
山海青风3 小时前
从零开始玩转TensorFlow:小明的机器学习故事 5
人工智能·机器学习·tensorflow
小森( ﹡ˆoˆ﹡ )3 小时前
DeepSeek 全面分析报告
人工智能·自然语言处理·nlp
刘大猫263 小时前
十、MyBatis的缓存
大数据·数据结构·人工智能
deephub3 小时前
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
人工智能·pytorch·python·深度学习·deepseek
人类群星闪耀时3 小时前
大数据平台上的机器学习模型部署:从理论到实
大数据·人工智能·机器学习
仙人掌_lz4 小时前
DeepSeek开源周首日:发布大模型加速核心技术可变长度高效FlashMLA 加持H800算力解码性能狂飙升至3000GB/s
人工智能·深度学习·开源