5min看懂torch.einsum()计算方法-torch.einsum()手动推导详解

引言: torch.einsum()的分析和介绍已经有很多博客介绍过了, 但大多数的落脚点都是爱因斯坦求和约定,许多篇幅是用于介绍爱因斯坦求和约定到的各项法则,而实际案例分析方面只是草草给出一笔带过,涉及到的案例也较为简单。而实际我们要用到或者看到torch.einsum()的时候往往是在计算非常复杂的情况下。 因此本文将从实际复杂案例的角度对torch.einsum()的计算过程进行分析,一步一步的推导最终输出的每个元素和输入元素之间的关系。

爱因斯坦求和约定

首先,torch.einsum()的基础原理是爱因斯坦求和约定,此处为了行文的整体性将对其进行简要的介绍,如果只关注计算本身,可以跳到下一节。爱因斯坦求和约定是为了简化计算而诞生的一种"记法",就类似于我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> × \times </math>×来标记乘法一样,不同之处在于爱因斯坦求和约定可表示的运算更为复杂、灵活性也更高。爱因斯坦求和约定的典型写法为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i 1 i 2 . . . i N , j 1 j 2 . . . j M → i k 1 i k 2 . . j l 1 j l 1 , k 1 . . . ∈ N , l 1 . . l ∈ M i_1i_2...i_N,j_1j_2...j_M\rightarrow i_{k_1}i_{k_2}..j_{l_1}j_{l_1},k_1...\in N,l_1..l\in M </math>i1i2...iN,j1j2...jM→ik1ik2..jl1jl1,k1...∈N,l1..l∈M

其中左端 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 1 i 2 . . . i N , j 1 j 2 . . . j M i_1i_2...i_N,j_1j_2...j_M </math>i1i2...iN,j1j2...jM就表示了输入两个矩阵元素的坐标索引,右端 <math xmlns="http://www.w3.org/1998/Math/MathML"> i k 1 i k 2 . . j l 1 j l 1 i_{k_1}i_{k_2}..j_{l_1}j_{l_1} </math>ik1ik2..jl1jl1为输出矩阵元素的坐标索引,可以看到输出矩阵元素索引相较于输入端的索引可能会缺少几项,运算就是发生这几个维度上的乘累加操作。

其中同时出现在左端和右端的坐标索引为自由索引 ,只用于标记位置;而仅仅出现在右端的索引为求和索引 ,爱因斯坦求和约定的本质就是沿着求和索引的方向计算两个输入逐元素乘累加和的结果放到输出自由索引的位置上 ,更为细致的介绍参见:一文学会 Pytorch 中的 einsumeinsum:爱因斯坦求和约定 举例而言:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i j , j k → i k ij,jk\rightarrow ik </math>ij,jk→ik

就表示沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j这个维度进行乘累加操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> O i k = ∑ j A i j B j k O_{ik}=\sum_{j}A_{ij}B_{jk} </math>Oik=j∑AijBjk

输出的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i , k ) (i,k) </math>(i,k)个元素为 <math xmlns="http://www.w3.org/1998/Math/MathML"> A i ⋅ A_{i \cdot} </math>Ai⋅的行向量和 <math xmlns="http://www.w3.org/1998/Math/MathML"> B ⋅ k B_{\cdot k} </math>B⋅k列向量逐元素乘累加,实际上就是矩阵相乘。

复杂案例推导

正如第一节中所介绍的,torch.enisum()的核心计算过程就是沿着只在算式右边出现的轴对输入矩阵元素进行乘累加得到对应位置的输出元素。因此,想要弄清一个复杂的torch.eisum()表达式含义需要做的也只是将这个求和公式写出来再仔细分析。

案例. 四维张量乘三维张量

给出一个复杂案例:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> n c j t , n p j − > n c p t ncjt,npj->ncpt </math>ncjt,npj−>ncpt

则其输出元素可以表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C n c p t = ∑ j A n c j t B n p j C_{ncpt}=\sum_j A_{ncjt}B_{npj} </math>Cncpt=j∑AncjtBnpj

首先我们可以注意到对于C的第一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n而言,它同时出现在A和B的首位,也就是对于这一维的每个元素,都是会逐元素的执行A和B剩余维度的计算再在当前维度上排布,用深度学习中的描述来说就是对BATCH中的每个元素都独立的执行后续子操作,子操作可以记为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C c p t = ∑ j A c j t B p j C_{cpt}=\sum_j A_{cjt}B_{pj} </math>Ccpt=j∑AcjtBpj

紧接着,对当前算式的第一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c来说它只出现在 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A中,每沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c计算一个不同的元素都要和"相同"的B计算,也就出现了广播机制,B有了个隐藏的、元素重复的维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c,计算变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> C c p t = ∑ j A c j t B c p j C_{cpt}=\sum_j A_{cjt}B_{cpj} </math>Ccpt=∑jAcjtBcpj,同第一步计算的原理,这里又可以化简成逐元素的子操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C p t = ∑ j A j t B p j = ∑ j B p j A j t C_{pt}=\sum_jA_{jt}B_{pj}=\sum_jB_{pj}A_{jt} </math>Cpt=j∑AjtBpj=j∑BpjAjt

此时易看出 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( p , t ) (p,t) </math>(p,t)元素就是B的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p行向量和A的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t列向量求内积。  从而我们可以得出结论,这一表达式的意思是,对于BATCH内的每个元素(A'三维,B'二维),对B在第一维度进行广播(A''三维,B''三维),最后沿着第二维和第三维计算矩阵相乘B'''A'''(A'''二维,B'''二维)。

而整个的推导过程可以总结为以下几要点

  1. 沿着维数较高输入的第一维开始,判断是否存在于B中,如果在的话就可认为是逐元素操作,暂时忽略该维度;
  2. 如果该维度指示不在B中,则进行广播操作,重新回到1,否则3.
  3. 判断当前最简表达式的意义。
相关推荐
春末的南方城市22 分钟前
FLUX的ID保持项目也来了! 字节开源PuLID-FLUX-v0.9.0,开启一致性风格写真新纪元!
人工智能·计算机视觉·stable diffusion·aigc·图像生成
zmjia11124 分钟前
AI大语言模型进阶应用及模型优化、本地化部署、从0-1搭建、智能体构建技术
人工智能·语言模型·自然语言处理
jndingxin38 分钟前
OpenCV视频I/O(14)创建和写入视频文件的类:VideoWriter介绍
人工智能·opencv·音视频
AI完全体1 小时前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
GZ_TOGOGO1 小时前
【2024最新】华为HCIE认证考试流程
大数据·人工智能·网络协议·网络安全·华为
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
新缸中之脑1 小时前
Ollama 运行视觉语言模型LLaVA
人工智能·语言模型·自然语言处理
卷心菜小温2 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
胡耀超2 小时前
知识图谱入门——3:工具分类与对比(知识建模工具:Protégé、 知识抽取工具:DeepDive、知识存储工具:Neo4j)
人工智能·知识图谱
陈苏同学2 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm