PyTorch中乘法运算详细介绍

文章目录

    • [1. 元素级乘法 (Element-wise Multiplication)](#1. 元素级乘法 (Element-wise Multiplication))
      • [1.1 `torch.mul()` 或 `*` 运算符](#1.1 torch.mul()* 运算符)
    • [2. 矩阵乘法 (Matrix Multiplication)](#2. 矩阵乘法 (Matrix Multiplication))
      • [2.1 `torch.matmul()` 或 `@` 运算符](#2.1 torch.matmul()@ 运算符)
      • [2.2 `torch.mm()` - 专门用于2D矩阵](#2.2 torch.mm() - 专门用于2D矩阵)
      • [2.3 `torch.bmm()` - 批量矩阵乘法](#2.3 torch.bmm() - 批量矩阵乘法)
    • [3. 点积 (Dot Product)](#3. 点积 (Dot Product))
      • [3.1 `torch.dot()` - 向量点积](#3.1 torch.dot() - 向量点积)
    • [4. 其他相关乘法](#4. 其他相关乘法)
      • [4.1 `torch.einsum()` - 爱因斯坦求和约定](#4.1 torch.einsum() - 爱因斯坦求和约定)
      • [4.2 `torch.outer()` - 外积](#4.2 torch.outer() - 外积)
    • [5. 对比总结表](#5. 对比总结表)
    • [6. 重要区别示例](#6. 重要区别示例)
    • [7. 广播机制示例](#7. 广播机制示例)
    • [8. 性能建议](#8. 性能建议)
    • [9. 常见错误避免](#9. 常见错误避免)

在PyTorch中,乘法运算有多种形式,每种都有不同的用途。以下详细介绍各种乘法运算:

1. 元素级乘法 (Element-wise Multiplication)

1.1 torch.mul()* 运算符

python 复制代码
import torch

# 元素级乘法 - 对应位置相乘
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# 方法1: * 运算符
result1 = a * b  # tensor([[ 5, 12], [21, 32]])

# 方法2: torch.mul()
result2 = torch.mul(a, b)  # 同上

# 广播机制同样适用
c = torch.tensor([2, 3])
d = torch.tensor([[1, 2], [3, 4]])
result3 = d * c  # 广播: c被广播为[[2,3], [2,3]]

2. 矩阵乘法 (Matrix Multiplication)

2.1 torch.matmul()@ 运算符

python 复制代码
# 矩阵乘法 (二维)
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])

# 方法1: @ 运算符 (Python 3.5+)
result1 = A @ B  # 标准的矩阵乘法

# 方法2: torch.matmul()
result2 = torch.matmul(A, B)  # tensor([[19, 22], [43, 50]])

# 方法3: torch.mm() (仅限2D)
result3 = torch.mm(A, B)  # 同上,但只支持2D

2.2 torch.mm() - 专门用于2D矩阵

python 复制代码
# 只能用于2D张量
x = torch.randn(2, 3)
y = torch.randn(3, 4)
result = torch.mm(x, y)  # 输出形状: (2, 4)

2.3 torch.bmm() - 批量矩阵乘法

python 复制代码
# 批量矩阵乘法 (3D张量)
batch_size = 10
x = torch.randn(batch_size, 3, 4)
y = torch.randn(batch_size, 4, 5)
result = torch.bmm(x, y)  # 输出形状: (10, 3, 5)

3. 点积 (Dot Product)

3.1 torch.dot() - 向量点积

python 复制代码
# 一维向量的点积
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
result = torch.dot(x, y)  # 1*4 + 2*5 + 3*6 = 32.0

4. 其他相关乘法

4.1 torch.einsum() - 爱因斯坦求和约定

python 复制代码
# 灵活的乘法表示
A = torch.randn(2, 3)
B = torch.randn(3, 4)

# 矩阵乘法
result = torch.einsum('ij,jk->ik', A, B)

# 向量点积
x = torch.randn(3)
y = torch.randn(3)
dot_product = torch.einsum('i,i->', x, y)

# 批量矩阵乘法
batch_A = torch.randn(10, 2, 3)
batch_B = torch.randn(10, 3, 4)
batch_result = torch.einsum('bij,bjk->bik', batch_A, batch_B)

4.2 torch.outer() - 外积

python 复制代码
# 向量外积
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
result = torch.outer(x, y)  # 3x3矩阵

5. 对比总结表

函数/运算符 维度要求 说明 示例
*torch.mul() 任意,可广播 元素级乘法 a * b
@torch.matmul() 支持多种维度 灵活的矩阵乘法 a @ b
torch.mm() 必须2D 2D矩阵乘法 torch.mm(a, b)
torch.bmm() 必须3D 批量矩阵乘法 torch.bmm(a, b)
torch.dot() 必须1D 向量点积 torch.dot(a, b)
torch.einsum() 任意 灵活的多维运算 见上例

6. 重要区别示例

python 复制代码
import torch

# 创建测试张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

print("元素级乘法 (a * b):")
print(a * b)
# tensor([[ 5, 12],
#         [21, 32]])

print("\n矩阵乘法 (a @ b):")
print(a @ b)
# tensor([[19, 22],
#         [43, 50]])

# 计算过程说明:
# 元素级: [1,2] * [5,6] = [5,12]
# 矩阵乘: [1,2] · [5,7] = 1*5 + 2*7 = 19 (第一行第一列)

7. 广播机制示例

python 复制代码
# 广播在元素级乘法中的应用
a = torch.randn(3, 4, 5)
b = torch.randn(5)  # 会被广播到(1,1,5)然后到(3,4,5)

# 有效的广播乘法
result = a * b  # 形状: (3, 4, 5)

# 矩阵乘法中的广播
x = torch.randn(10, 3, 4)
y = torch.randn(4, 5)  # 会被广播到(10, 4, 5)
result2 = torch.matmul(x, y)  # 形状: (10, 3, 5)

8. 性能建议

  1. 元素级乘法 :使用 * 运算符(最简洁)
  2. 矩阵乘法
    • 2D矩阵:@torch.mm()
    • 批量3D:torch.bmm()
    • 更复杂情况:torch.matmul()
  3. 复杂运算 :考虑使用 torch.einsum() 可读性更好

9. 常见错误避免

python 复制代码
# 错误示例
a = torch.randn(3, 4)
b = torch.randn(3, 4)

# 错误: 想用*做矩阵乘法
# wrong = a * b  # 这是元素级乘法,不是矩阵乘法

# 正确: 转置后进行矩阵乘法
correct = a @ b.T  # 或者 torch.mm(a, b.T)

# 检查维度
print(f"a shape: {a.shape}, b.T shape: {b.T.shape}")

关键区别:* 是元素级乘法,@ 是矩阵乘法

相关推荐
高工智能汽车4 小时前
爱芯元智通过港交所聆讯,智能汽车芯片市场格局加速重构
人工智能·重构·汽车
大力财经4 小时前
悬架、底盘、制动被同时重构,星空计划想把“驾驶”变成一种系统能力
人工智能
喵手5 小时前
Python爬虫零基础入门【第九章:实战项目教学·第15节】搜索页采集:关键词队列 + 结果去重 + 反爬友好策略!
爬虫·python·爬虫实战·python爬虫工程化实战·零基础python爬虫教学·搜索页采集·关键词队列
梁下轻语的秋缘5 小时前
Prompt工程核心指南:从入门到精通,让AI精准响应你的需求
大数据·人工智能·prompt
FreeBuf_5 小时前
ChatGPT引用马斯克AI生成的Grokipedia是否陷入“内容陷阱“?
人工智能·chatgpt
Suchadar5 小时前
if判断语句——Python
开发语言·python
ʚB҉L҉A҉C҉K҉.҉基҉德҉^҉大5 小时前
自动化机器学习(AutoML)库TPOT使用指南
jvm·数据库·python
福客AI智能客服5 小时前
工单智转:电商智能客服与客服AI系统重构售后服务效率
大数据·人工智能
柳鲲鹏5 小时前
OpenCV:超分辨率、超采样及测试性能
人工智能·opencv·计算机视觉
喵手5 小时前
Python爬虫零基础入门【第九章:实战项目教学·第14节】表格型页面采集:多列、多行、跨页(通用表格解析)!
爬虫·python·python爬虫实战·python爬虫工程化实战·python爬虫零基础入门·表格型页面采集·通用表格解析