PyTorch乘法全解析

在PyTorch里做乘法,可不是随便选个符号就行------不同的乘法对应着完全不同的运算逻辑,用错了轻则结果不对,重则直接报错。今天就把PyTorch里常用的几种乘法掰开揉碎了讲,从原理到代码例子,保证你看完能分清什么时候该用哪种。

一、逐元素乘法:* 运算符与torch.mul()

原理

这种乘法最直观,就是两个张量对应位置的元素相乘,要求两个张量的形状要么完全相同 ,要么满足广播机制(Broadcast)------简单说就是形状可以通过自动扩展维度来匹配,比如一个(3,)的向量和一个(1,3)的矩阵,会自动把向量扩展成(1,3)再逐元素相乘。

代码例子

python 复制代码
import torch

# 形状完全相同的情况
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
result1 = a * b
result2 = torch.mul(a, b)
print("逐元素乘法结果(*):\n", result1)
print("逐元素乘法结果(torch.mul):\n", result2)

# 广播机制的情况
c = torch.tensor([10, 20])  # 形状(2,)
d = torch.tensor([[1], [2]])  # 形状(2,1)
result3 = c * d
print("广播逐元素乘法结果:\n", result3)

输出结果

复制代码
逐元素乘法结果(*):
 tensor([[ 5, 12],
        [21, 32]])
逐元素乘法结果(torch.mul):
 tensor([[ 5, 12],
        [21, 32]])
广播逐元素乘法结果:
 tensor([[10, 20],
        [20, 40]])

可以看到*torch.mul()的效果完全一致,前者是后者的语法糖,日常写代码用*更简洁。

二、矩阵乘法:@ 运算符与torch.matmul()

原理

这就是线性代数里标准的矩阵乘法,要求第一个张量的最后一个维度大小 等于第二个张量的倒数第二个维度大小,比如形状为(m, n)的矩阵和(n, p)的矩阵相乘,结果是(m, p)的矩阵。

它还支持更高维的张量运算,比如处理批量矩阵:如果输入是(batch_size, m, n)和(batch_size, n, p),结果就是(batch_size, m, p),会自动对每个batch单独做矩阵乘法。

代码例子

python 复制代码
import torch

# 二维矩阵乘法
a = torch.tensor([[1, 2], [3, 4]])  # 形状(2,2)
b = torch.tensor([[5, 6], [7, 8]])  # 形状(2,2)
result1 = a @ b
result2 = torch.matmul(a, b)
print("矩阵乘法结果(@):\n", result1)
print("矩阵乘法结果(torch.matmul):\n", result2)

# 批量矩阵乘法
batch_a = torch.randn(3, 2, 3)  # 3个(2,3)的矩阵
batch_b = torch.randn(3, 3, 4)  # 3个(3,4)的矩阵
result3 = batch_a @ batch_b
print("批量矩阵乘法结果形状:", result3.shape)

输出结果

复制代码
矩阵乘法结果(@):
 tensor([[19, 22],
        [43, 50]])
矩阵乘法结果(torch.matmul):
 tensor([[19, 22],
        [43, 50]])
批量矩阵乘法结果形状: torch.Size([3, 2, 4])

这里@同样是torch.matmul()的语法糖,对于二维矩阵来说,两者完全等价;高维张量运算时,torch.matmul()会自动识别批量维度,非常适合深度学习里的批量数据处理。

三、向量点积:torch.dot()

原理

专门用于两个一维张量(向量)的点积,即对应元素相乘后求和,要求两个向量的长度必须相同。注意它只能处理一维张量,输入高维张量会报错。

代码例子

python 复制代码
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.dot(a, b)
print("向量点积结果:", result)

输出结果

复制代码
向量点积结果: tensor(32)

计算过程是14 + 25 + 3*6 = 4+10+18=32,和数学上的点积定义完全一致。

四、批量矩阵乘法专用:torch.bmm()

原理

torch.matmul()的批量矩阵乘法类似,但它要求输入必须是三维张量,且第一个维度是batch_size,后面两个维度是矩阵的行和列,即输入形状为(batch_size, m, n)和(batch_size, n, p),输出是(batch_size, m, p)。

它的限制比torch.matmul()更严格,不支持广播,必须保证两个输入的batch_size一致,且中间维度匹配。

代码例子

python 复制代码
import torch

batch_a = torch.randn(2, 3, 4)  # 2个(3,4)的矩阵
batch_b = torch.randn(2, 4, 5)  # 2个(4,5)的矩阵
result = torch.bmm(batch_a, batch_b)
print("torch.bmm结果形状:", result.shape)

输出结果

复制代码
torch.bmm结果形状: torch.Size([2, 3, 5])

如果尝试输入非三维张量,比如二维矩阵,torch.bmm()会直接报错,适合明确知道是批量矩阵乘法的场景,避免不小心触发广播导致错误。

五、总结对比

乘法方式 运算逻辑 形状要求 适用场景
* / torch.mul() 逐元素相乘 形状相同或满足广播 对应元素的乘积运算
@ / torch.matmul() 矩阵/批量矩阵乘法 第一个最后一维=第二个倒数第二维,支持广播 线性变换、网络层运算
torch.dot() 一维向量点积 必须是一维张量,长度相同 向量相似度计算等
torch.bmm() 三维批量矩阵乘法 必须是三维张量,batch_size一致,中间维度匹配 明确的批量矩阵运算,避免广播

人能力有限,有问题随时联系。

相关推荐
程序员cxuan4 小时前
一句话,让你用上 GPT-5.6
人工智能·后端·程序员
机器之心4 小时前
AI圈刚开始谈Loop Engineering,两位95后博士已经盯上了人类闭环数据
人工智能·openai
澄旭4 小时前
一文讲清 MCP:AI 应用连接外部世界的标准协议
人工智能
机器之心4 小时前
不只DeepSeek,阶跃等开源JetSpec:大模型解码提速近10倍
人工智能·openai
moMo5 小时前
当LLM学会"递纸条",AI是如何调用工具的
人工智能
拾年2755 小时前
大模型的"聪明"从哪来?聊聊 AI 数据集的那些事儿
人工智能·深度学习·机器学习
拾年2755 小时前
从 Prompt 到 Context 再到 Harness:AI 工程化的三年三级跳
人工智能
小九九的爸爸5 小时前
前端想要入门Agent开发,要具备哪些Python基础?
python·agent·ai编程
用户3090463613945 小时前
Claude 不会直接执行你的函数,它只会生成一段结构化的工具调用请求。真正执行函数、访问数据库、请求外部 API 的动作,必须由你的后端完成。
人工智能