文章目录
-
- [1. 元素级乘法 (Element-wise Multiplication)](#1. 元素级乘法 (Element-wise Multiplication))
-
- [1.1 `torch.mul()` 或 `*` 运算符](#1.1
torch.mul()或*运算符)
- [1.1 `torch.mul()` 或 `*` 运算符](#1.1
- [2. 矩阵乘法 (Matrix Multiplication)](#2. 矩阵乘法 (Matrix Multiplication))
-
- [2.1 `torch.matmul()` 或 `@` 运算符](#2.1
torch.matmul()或@运算符) - [2.2 `torch.mm()` - 专门用于2D矩阵](#2.2
torch.mm()- 专门用于2D矩阵) - [2.3 `torch.bmm()` - 批量矩阵乘法](#2.3
torch.bmm()- 批量矩阵乘法)
- [2.1 `torch.matmul()` 或 `@` 运算符](#2.1
- [3. 点积 (Dot Product)](#3. 点积 (Dot Product))
-
- [3.1 `torch.dot()` - 向量点积](#3.1
torch.dot()- 向量点积)
- [3.1 `torch.dot()` - 向量点积](#3.1
- [4. 其他相关乘法](#4. 其他相关乘法)
-
- [4.1 `torch.einsum()` - 爱因斯坦求和约定](#4.1
torch.einsum()- 爱因斯坦求和约定) - [4.2 `torch.outer()` - 外积](#4.2
torch.outer()- 外积)
- [4.1 `torch.einsum()` - 爱因斯坦求和约定](#4.1
- [5. 对比总结表](#5. 对比总结表)
- [6. 重要区别示例](#6. 重要区别示例)
- [7. 广播机制示例](#7. 广播机制示例)
- [8. 性能建议](#8. 性能建议)
- [9. 常见错误避免](#9. 常见错误避免)
在PyTorch中,乘法运算有多种形式,每种都有不同的用途。以下详细介绍各种乘法运算:
1. 元素级乘法 (Element-wise Multiplication)
1.1 torch.mul() 或 * 运算符
python
import torch
# 元素级乘法 - 对应位置相乘
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 方法1: * 运算符
result1 = a * b # tensor([[ 5, 12], [21, 32]])
# 方法2: torch.mul()
result2 = torch.mul(a, b) # 同上
# 广播机制同样适用
c = torch.tensor([2, 3])
d = torch.tensor([[1, 2], [3, 4]])
result3 = d * c # 广播: c被广播为[[2,3], [2,3]]
2. 矩阵乘法 (Matrix Multiplication)
2.1 torch.matmul() 或 @ 运算符
python
# 矩阵乘法 (二维)
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
# 方法1: @ 运算符 (Python 3.5+)
result1 = A @ B # 标准的矩阵乘法
# 方法2: torch.matmul()
result2 = torch.matmul(A, B) # tensor([[19, 22], [43, 50]])
# 方法3: torch.mm() (仅限2D)
result3 = torch.mm(A, B) # 同上,但只支持2D
2.2 torch.mm() - 专门用于2D矩阵
python
# 只能用于2D张量
x = torch.randn(2, 3)
y = torch.randn(3, 4)
result = torch.mm(x, y) # 输出形状: (2, 4)
2.3 torch.bmm() - 批量矩阵乘法
python
# 批量矩阵乘法 (3D张量)
batch_size = 10
x = torch.randn(batch_size, 3, 4)
y = torch.randn(batch_size, 4, 5)
result = torch.bmm(x, y) # 输出形状: (10, 3, 5)
3. 点积 (Dot Product)
3.1 torch.dot() - 向量点积
python
# 一维向量的点积
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
result = torch.dot(x, y) # 1*4 + 2*5 + 3*6 = 32.0
4. 其他相关乘法
4.1 torch.einsum() - 爱因斯坦求和约定
python
# 灵活的乘法表示
A = torch.randn(2, 3)
B = torch.randn(3, 4)
# 矩阵乘法
result = torch.einsum('ij,jk->ik', A, B)
# 向量点积
x = torch.randn(3)
y = torch.randn(3)
dot_product = torch.einsum('i,i->', x, y)
# 批量矩阵乘法
batch_A = torch.randn(10, 2, 3)
batch_B = torch.randn(10, 3, 4)
batch_result = torch.einsum('bij,bjk->bik', batch_A, batch_B)
4.2 torch.outer() - 外积
python
# 向量外积
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
result = torch.outer(x, y) # 3x3矩阵
5. 对比总结表
| 函数/运算符 | 维度要求 | 说明 | 示例 |
|---|---|---|---|
* 或 torch.mul() |
任意,可广播 | 元素级乘法 | a * b |
@ 或 torch.matmul() |
支持多种维度 | 灵活的矩阵乘法 | a @ b |
torch.mm() |
必须2D | 2D矩阵乘法 | torch.mm(a, b) |
torch.bmm() |
必须3D | 批量矩阵乘法 | torch.bmm(a, b) |
torch.dot() |
必须1D | 向量点积 | torch.dot(a, b) |
torch.einsum() |
任意 | 灵活的多维运算 | 见上例 |
6. 重要区别示例
python
import torch
# 创建测试张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
print("元素级乘法 (a * b):")
print(a * b)
# tensor([[ 5, 12],
# [21, 32]])
print("\n矩阵乘法 (a @ b):")
print(a @ b)
# tensor([[19, 22],
# [43, 50]])
# 计算过程说明:
# 元素级: [1,2] * [5,6] = [5,12]
# 矩阵乘: [1,2] · [5,7] = 1*5 + 2*7 = 19 (第一行第一列)
7. 广播机制示例
python
# 广播在元素级乘法中的应用
a = torch.randn(3, 4, 5)
b = torch.randn(5) # 会被广播到(1,1,5)然后到(3,4,5)
# 有效的广播乘法
result = a * b # 形状: (3, 4, 5)
# 矩阵乘法中的广播
x = torch.randn(10, 3, 4)
y = torch.randn(4, 5) # 会被广播到(10, 4, 5)
result2 = torch.matmul(x, y) # 形状: (10, 3, 5)
8. 性能建议
- 元素级乘法 :使用
*运算符(最简洁) - 矩阵乘法 :
- 2D矩阵:
@或torch.mm() - 批量3D:
torch.bmm() - 更复杂情况:
torch.matmul()
- 2D矩阵:
- 复杂运算 :考虑使用
torch.einsum()可读性更好
9. 常见错误避免
python
# 错误示例
a = torch.randn(3, 4)
b = torch.randn(3, 4)
# 错误: 想用*做矩阵乘法
# wrong = a * b # 这是元素级乘法,不是矩阵乘法
# 正确: 转置后进行矩阵乘法
correct = a @ b.T # 或者 torch.mm(a, b.T)
# 检查维度
print(f"a shape: {a.shape}, b.T shape: {b.T.shape}")
关键区别:* 是元素级乘法,@ 是矩阵乘法。