PyTorch中乘法运算详细介绍

文章目录

    • [1. 元素级乘法 (Element-wise Multiplication)](#1. 元素级乘法 (Element-wise Multiplication))
      • [1.1 `torch.mul()` 或 `*` 运算符](#1.1 torch.mul()* 运算符)
    • [2. 矩阵乘法 (Matrix Multiplication)](#2. 矩阵乘法 (Matrix Multiplication))
      • [2.1 `torch.matmul()` 或 `@` 运算符](#2.1 torch.matmul()@ 运算符)
      • [2.2 `torch.mm()` - 专门用于2D矩阵](#2.2 torch.mm() - 专门用于2D矩阵)
      • [2.3 `torch.bmm()` - 批量矩阵乘法](#2.3 torch.bmm() - 批量矩阵乘法)
    • [3. 点积 (Dot Product)](#3. 点积 (Dot Product))
      • [3.1 `torch.dot()` - 向量点积](#3.1 torch.dot() - 向量点积)
    • [4. 其他相关乘法](#4. 其他相关乘法)
      • [4.1 `torch.einsum()` - 爱因斯坦求和约定](#4.1 torch.einsum() - 爱因斯坦求和约定)
      • [4.2 `torch.outer()` - 外积](#4.2 torch.outer() - 外积)
    • [5. 对比总结表](#5. 对比总结表)
    • [6. 重要区别示例](#6. 重要区别示例)
    • [7. 广播机制示例](#7. 广播机制示例)
    • [8. 性能建议](#8. 性能建议)
    • [9. 常见错误避免](#9. 常见错误避免)

在PyTorch中,乘法运算有多种形式,每种都有不同的用途。以下详细介绍各种乘法运算:

1. 元素级乘法 (Element-wise Multiplication)

1.1 torch.mul()* 运算符

python 复制代码
import torch

# 元素级乘法 - 对应位置相乘
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# 方法1: * 运算符
result1 = a * b  # tensor([[ 5, 12], [21, 32]])

# 方法2: torch.mul()
result2 = torch.mul(a, b)  # 同上

# 广播机制同样适用
c = torch.tensor([2, 3])
d = torch.tensor([[1, 2], [3, 4]])
result3 = d * c  # 广播: c被广播为[[2,3], [2,3]]

2. 矩阵乘法 (Matrix Multiplication)

2.1 torch.matmul()@ 运算符

python 复制代码
# 矩阵乘法 (二维)
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])

# 方法1: @ 运算符 (Python 3.5+)
result1 = A @ B  # 标准的矩阵乘法

# 方法2: torch.matmul()
result2 = torch.matmul(A, B)  # tensor([[19, 22], [43, 50]])

# 方法3: torch.mm() (仅限2D)
result3 = torch.mm(A, B)  # 同上,但只支持2D

2.2 torch.mm() - 专门用于2D矩阵

python 复制代码
# 只能用于2D张量
x = torch.randn(2, 3)
y = torch.randn(3, 4)
result = torch.mm(x, y)  # 输出形状: (2, 4)

2.3 torch.bmm() - 批量矩阵乘法

python 复制代码
# 批量矩阵乘法 (3D张量)
batch_size = 10
x = torch.randn(batch_size, 3, 4)
y = torch.randn(batch_size, 4, 5)
result = torch.bmm(x, y)  # 输出形状: (10, 3, 5)

3. 点积 (Dot Product)

3.1 torch.dot() - 向量点积

python 复制代码
# 一维向量的点积
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
result = torch.dot(x, y)  # 1*4 + 2*5 + 3*6 = 32.0

4. 其他相关乘法

4.1 torch.einsum() - 爱因斯坦求和约定

python 复制代码
# 灵活的乘法表示
A = torch.randn(2, 3)
B = torch.randn(3, 4)

# 矩阵乘法
result = torch.einsum('ij,jk->ik', A, B)

# 向量点积
x = torch.randn(3)
y = torch.randn(3)
dot_product = torch.einsum('i,i->', x, y)

# 批量矩阵乘法
batch_A = torch.randn(10, 2, 3)
batch_B = torch.randn(10, 3, 4)
batch_result = torch.einsum('bij,bjk->bik', batch_A, batch_B)

4.2 torch.outer() - 外积

python 复制代码
# 向量外积
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
result = torch.outer(x, y)  # 3x3矩阵

5. 对比总结表

函数/运算符 维度要求 说明 示例
*torch.mul() 任意,可广播 元素级乘法 a * b
@torch.matmul() 支持多种维度 灵活的矩阵乘法 a @ b
torch.mm() 必须2D 2D矩阵乘法 torch.mm(a, b)
torch.bmm() 必须3D 批量矩阵乘法 torch.bmm(a, b)
torch.dot() 必须1D 向量点积 torch.dot(a, b)
torch.einsum() 任意 灵活的多维运算 见上例

6. 重要区别示例

python 复制代码
import torch

# 创建测试张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

print("元素级乘法 (a * b):")
print(a * b)
# tensor([[ 5, 12],
#         [21, 32]])

print("\n矩阵乘法 (a @ b):")
print(a @ b)
# tensor([[19, 22],
#         [43, 50]])

# 计算过程说明:
# 元素级: [1,2] * [5,6] = [5,12]
# 矩阵乘: [1,2] · [5,7] = 1*5 + 2*7 = 19 (第一行第一列)

7. 广播机制示例

python 复制代码
# 广播在元素级乘法中的应用
a = torch.randn(3, 4, 5)
b = torch.randn(5)  # 会被广播到(1,1,5)然后到(3,4,5)

# 有效的广播乘法
result = a * b  # 形状: (3, 4, 5)

# 矩阵乘法中的广播
x = torch.randn(10, 3, 4)
y = torch.randn(4, 5)  # 会被广播到(10, 4, 5)
result2 = torch.matmul(x, y)  # 形状: (10, 3, 5)

8. 性能建议

  1. 元素级乘法 :使用 * 运算符(最简洁)
  2. 矩阵乘法
    • 2D矩阵:@torch.mm()
    • 批量3D:torch.bmm()
    • 更复杂情况:torch.matmul()
  3. 复杂运算 :考虑使用 torch.einsum() 可读性更好

9. 常见错误避免

python 复制代码
# 错误示例
a = torch.randn(3, 4)
b = torch.randn(3, 4)

# 错误: 想用*做矩阵乘法
# wrong = a * b  # 这是元素级乘法,不是矩阵乘法

# 正确: 转置后进行矩阵乘法
correct = a @ b.T  # 或者 torch.mm(a, b.T)

# 检查维度
print(f"a shape: {a.shape}, b.T shape: {b.T.shape}")

关键区别:* 是元素级乘法,@ 是矩阵乘法

相关推荐
小超同学你好7 分钟前
面向 LLM 的程序设计 6:Tool Calling 的完整生命周期——从定义、决策、执行到观测回注
人工智能·语言模型
明日清晨13 分钟前
python扫码登录dy
开发语言·python
智星云算力26 分钟前
本地GPU与租用GPU混合部署:混合算力架构搭建指南
人工智能·架构·gpu算力·智星云·gpu租用
bazhange26 分钟前
python如何像matlab一样使用向量化替代for循环
开发语言·python·matlab
jinanwuhuaguo27 分钟前
截止到4月8日,OpenClaw 2026年4月更新深度解读剖析:从“能力回归”到“信任内建”的范式跃迁
android·开发语言·人工智能·深度学习·kotlin
xiaozhazha_31 分钟前
效率提升80%:2026年AI CRM与ERP深度集成的架构设计与实现
人工智能
枫叶林FYL32 分钟前
【自然语言处理 NLP】7.2.2 安全性评估与Constitutional AI
人工智能·自然语言处理
AI人工智能+39 分钟前
基于高精度身份证OCR识别、炫彩活体检测及人脸比对技术的人脸核身系统,为通信行业数字化转型提供了坚实的安全底座
人工智能·计算机视觉·人脸识别·ocr·人脸核身
人工干智能40 分钟前
科普:python中你写的模块找不到了——`ModuleNotFoundError`
服务器·python
unicrom_深圳市由你创科技1 小时前
做虚拟示波器这种实时波形显示的上位机,用什么语言?
c++·python·c#