5min看懂torch.einsum()计算方法-torch.einsum()手动推导详解

引言: torch.einsum()的分析和介绍已经有很多博客介绍过了, 但大多数的落脚点都是爱因斯坦求和约定,许多篇幅是用于介绍爱因斯坦求和约定到的各项法则,而实际案例分析方面只是草草给出一笔带过,涉及到的案例也较为简单。而实际我们要用到或者看到torch.einsum()的时候往往是在计算非常复杂的情况下。 因此本文将从实际复杂案例的角度对torch.einsum()的计算过程进行分析,一步一步的推导最终输出的每个元素和输入元素之间的关系。

爱因斯坦求和约定

首先,torch.einsum()的基础原理是爱因斯坦求和约定,此处为了行文的整体性将对其进行简要的介绍,如果只关注计算本身,可以跳到下一节。爱因斯坦求和约定是为了简化计算而诞生的一种"记法",就类似于我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> × \times </math>×来标记乘法一样,不同之处在于爱因斯坦求和约定可表示的运算更为复杂、灵活性也更高。爱因斯坦求和约定的典型写法为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i 1 i 2 . . . i N , j 1 j 2 . . . j M → i k 1 i k 2 . . j l 1 j l 1 , k 1 . . . ∈ N , l 1 . . l ∈ M i_1i_2...i_N,j_1j_2...j_M\rightarrow i_{k_1}i_{k_2}..j_{l_1}j_{l_1},k_1...\in N,l_1..l\in M </math>i1i2...iN,j1j2...jM→ik1ik2..jl1jl1,k1...∈N,l1..l∈M

其中左端 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 1 i 2 . . . i N , j 1 j 2 . . . j M i_1i_2...i_N,j_1j_2...j_M </math>i1i2...iN,j1j2...jM就表示了输入两个矩阵元素的坐标索引,右端 <math xmlns="http://www.w3.org/1998/Math/MathML"> i k 1 i k 2 . . j l 1 j l 1 i_{k_1}i_{k_2}..j_{l_1}j_{l_1} </math>ik1ik2..jl1jl1为输出矩阵元素的坐标索引,可以看到输出矩阵元素索引相较于输入端的索引可能会缺少几项,运算就是发生这几个维度上的乘累加操作。

其中同时出现在左端和右端的坐标索引为自由索引 ,只用于标记位置;而仅仅出现在右端的索引为求和索引 ,爱因斯坦求和约定的本质就是沿着求和索引的方向计算两个输入逐元素乘累加和的结果放到输出自由索引的位置上 ,更为细致的介绍参见:一文学会 Pytorch 中的 einsumeinsum:爱因斯坦求和约定 举例而言:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i j , j k → i k ij,jk\rightarrow ik </math>ij,jk→ik

就表示沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j这个维度进行乘累加操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> O i k = ∑ j A i j B j k O_{ik}=\sum_{j}A_{ij}B_{jk} </math>Oik=j∑AijBjk

输出的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i , k ) (i,k) </math>(i,k)个元素为 <math xmlns="http://www.w3.org/1998/Math/MathML"> A i ⋅ A_{i \cdot} </math>Ai⋅的行向量和 <math xmlns="http://www.w3.org/1998/Math/MathML"> B ⋅ k B_{\cdot k} </math>B⋅k列向量逐元素乘累加,实际上就是矩阵相乘。

复杂案例推导

正如第一节中所介绍的,torch.enisum()的核心计算过程就是沿着只在算式右边出现的轴对输入矩阵元素进行乘累加得到对应位置的输出元素。因此,想要弄清一个复杂的torch.eisum()表达式含义需要做的也只是将这个求和公式写出来再仔细分析。

案例. 四维张量乘三维张量

给出一个复杂案例:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> n c j t , n p j − > n c p t ncjt,npj->ncpt </math>ncjt,npj−>ncpt

则其输出元素可以表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C n c p t = ∑ j A n c j t B n p j C_{ncpt}=\sum_j A_{ncjt}B_{npj} </math>Cncpt=j∑AncjtBnpj

首先我们可以注意到对于C的第一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n而言,它同时出现在A和B的首位,也就是对于这一维的每个元素,都是会逐元素的执行A和B剩余维度的计算再在当前维度上排布,用深度学习中的描述来说就是对BATCH中的每个元素都独立的执行后续子操作,子操作可以记为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C c p t = ∑ j A c j t B p j C_{cpt}=\sum_j A_{cjt}B_{pj} </math>Ccpt=j∑AcjtBpj

紧接着,对当前算式的第一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c来说它只出现在 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A中,每沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c计算一个不同的元素都要和"相同"的B计算,也就出现了广播机制,B有了个隐藏的、元素重复的维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c,计算变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> C c p t = ∑ j A c j t B c p j C_{cpt}=\sum_j A_{cjt}B_{cpj} </math>Ccpt=∑jAcjtBcpj,同第一步计算的原理,这里又可以化简成逐元素的子操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C p t = ∑ j A j t B p j = ∑ j B p j A j t C_{pt}=\sum_jA_{jt}B_{pj}=\sum_jB_{pj}A_{jt} </math>Cpt=j∑AjtBpj=j∑BpjAjt

此时易看出 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( p , t ) (p,t) </math>(p,t)元素就是B的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p行向量和A的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t列向量求内积。  从而我们可以得出结论,这一表达式的意思是,对于BATCH内的每个元素(A'三维,B'二维),对B在第一维度进行广播(A''三维,B''三维),最后沿着第二维和第三维计算矩阵相乘B'''A'''(A'''二维,B'''二维)。

而整个的推导过程可以总结为以下几要点

  1. 沿着维数较高输入的第一维开始,判断是否存在于B中,如果在的话就可认为是逐元素操作,暂时忽略该维度;
  2. 如果该维度指示不在B中,则进行广播操作,重新回到1,否则3.
  3. 判断当前最简表达式的意义。
相关推荐
开发者每周简报5 分钟前
微软的AI转型故事
人工智能·microsoft
古希腊掌管学习的神9 分钟前
[机器学习]sklearn入门指南(1)
人工智能·python·算法·机器学习·sklearn
普密斯科技37 分钟前
手机外观边框缺陷视觉检测智慧方案
人工智能·计算机视觉·智能手机·自动化·视觉检测·集成测试
四口鲸鱼爱吃盐1 小时前
Pytorch | 利用AI-FGTM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
lishanlu1361 小时前
Pytorch分布式训练
人工智能·ddp·pytorch并行训练
日出等日落1 小时前
从零开始使用MaxKB打造本地大语言模型智能问答系统与远程交互
人工智能·语言模型·自然语言处理
三木吧1 小时前
开发微信小程序的过程与心得
人工智能·微信小程序·小程序
whaosoft-1431 小时前
w~视觉~3D~合集5
人工智能
猫头虎1 小时前
新纪天工 开物焕彩:重大科技成就发布会参会感
人工智能·开源·aigc·开放原子·开源软件·gpu算力·agi
正在走向自律2 小时前
京东物流营销 Agent:智能驱动,物流新篇(13/30)
人工智能·ai agent·ai智能体·京东物流agent