使用torch普通算子组合替代torch.einsum爱因斯坦求和

1. torch.einsum('bnd, bmd->bnm', x, y)

torch.einsum('bnd, bmd->bnm', x, y) 表示的是对张量 x 和 y 进行特定的求和和维度变换。

具体来说,这个操作的输入是两个形状为 [b, n, d] 和 [b, m, d] 的张量 x 和 y,输出是一个形状为 [b, n, m] 的张量 z。其计算过程可以理解为:对于每个 b,z[b, n, m] 等于 x[b, n, :] 和 y[b, m, :] 之间的点积。

为了用普通的 torch 操作符来替代 einsum,我们可以通过 torch.matmul 函数实现。这个函数可以用来执行批量矩阵乘法,并且能够很好地替代这个 einsum 操作。

具体实现如下:

python 复制代码
import torch

# 假设 x 和 y 的形状分别为 (b, n, d) 和 (b, m, d)
x = torch.randn(10, 20, 30)  # 举例
y = torch.randn(10, 15, 30)  # 举例

# einsum: z = torch.einsum('bnd, bmd->bnm', x, y)
# 可以转换为以下操作:
z = torch.matmul(x, y.transpose(-1, -2))  # z 的形状为 (b, n, m)

# 检查 z 的形状是否正确
print(z.shape)

2. torch.einsum('ij,jk->ik', A, B)

可以用普通的矩阵乘法 torch.matmul 替代

具体实现如下:

python 复制代码
import torch

A = torch.rand(3, 4)
B = torch.rand(4, 5)

# 使用 einsum
result_einsum = torch.einsum('ij,jk->ik', A, B)

# 使用 matmul
result_matmul = torch.matmul(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_matmul))

3. torch.einsum('bij,bjk->bik', A, B)

可以用 torch.bmm 来替代

具体实现如下:

python 复制代码
import torch

A = torch.rand(10, 3, 4)
B = torch.rand(10, 4, 5)

# 使用 einsum
result_einsum = torch.einsum('bij,bjk->bik', A, B)

# 使用 bmm
result_bmm = torch.bmm(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_bmm))

4. torch.einsum('i,i->', A, B)

向量内积,可以用 torch.dot 来替代

具体实现如下:

python 复制代码
import torch

A = torch.rand(4)
B = torch.rand(4)

# 使用 einsum
result_einsum = torch.einsum('i,i->', A, B)

# 使用 dot
result_dot = torch.dot(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_dot))

5. torch.einsum('i,j->ij', A, B)

向量外积,可以用 torch.outer 来替代

具体实现如下:

python 复制代码
import torch

A = torch.rand(4)
B = torch.rand(5)

# 使用 einsum
result_einsum = torch.einsum('i,j->ij', A, B)

# 使用 outer
result_outer = torch.outer(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_outer))

不同的 einsum 表达式会对应不同的替代操作,有时可能需要组合多个普通操作来达到相同的效果。如果某些 einsum 表达式太复杂,使用普通算子替代时会比较繁琐,此时建议继续使用 einsum,因为它不仅更简洁,而且通常性能优化得很好。
后续遇到其余需替换的 op 再进行更新

相关推荐
qq_5710993537 分钟前
学习周报三十九
人工智能·深度学习·机器学习
Ulyanov42 分钟前
卡尔曼滤波技术博客系列:第四篇:多目标跟踪:数据关联与航迹管理
python·目标跟踪·系统仿真·雷达电子战·仿真引擎
Three~stone1 小时前
MATLAB vs Python 两者区别和安装教程
开发语言·python·matlab
soragui1 小时前
【Python】第 1 章:Python 解释器原理
开发语言·python
Ulyanov1 小时前
卡尔曼滤波技术博客系列:第三篇 雷达目标跟踪:运动模型与坐标转换
python·目标跟踪·系统仿真·雷达电子战
AI医影跨模态组学1 小时前
Radiology子刊(IF=6.3)复旦大学附属金山医院强金伟教授等团队:基于多参数MRI的深度学习和影像组学评估早期宫颈癌淋巴结转移
人工智能·深度学习·论文·医学·医学影像
nimadan121 小时前
生成剧本杀软件2025推荐,创新剧情设计工具引领潮流
人工智能·python
极光代码工作室1 小时前
基于深度学习的智能垃圾分类系统
python·深度学习·神经网络·机器学习·ai
MediaTea2 小时前
Pandas 操作指南(二):数据选取与条件筛选
人工智能·python·机器学习·数据挖掘·pandas
小陈工2 小时前
Python Web开发入门(十二):使用Flask-RESTful构建API——让后端开发更优雅
开发语言·前端·python·安全·oracle·flask·restful