5min看懂torch.einsum()计算方法-torch.einsum()手动推导详解

引言: torch.einsum()的分析和介绍已经有很多博客介绍过了, 但大多数的落脚点都是爱因斯坦求和约定,许多篇幅是用于介绍爱因斯坦求和约定到的各项法则,而实际案例分析方面只是草草给出一笔带过,涉及到的案例也较为简单。而实际我们要用到或者看到torch.einsum()的时候往往是在计算非常复杂的情况下。 因此本文将从实际复杂案例的角度对torch.einsum()的计算过程进行分析,一步一步的推导最终输出的每个元素和输入元素之间的关系。

爱因斯坦求和约定

首先,torch.einsum()的基础原理是爱因斯坦求和约定,此处为了行文的整体性将对其进行简要的介绍,如果只关注计算本身,可以跳到下一节。爱因斯坦求和约定是为了简化计算而诞生的一种"记法",就类似于我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> × \times </math>×来标记乘法一样,不同之处在于爱因斯坦求和约定可表示的运算更为复杂、灵活性也更高。爱因斯坦求和约定的典型写法为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i 1 i 2 . . . i N , j 1 j 2 . . . j M → i k 1 i k 2 . . j l 1 j l 1 , k 1 . . . ∈ N , l 1 . . l ∈ M i_1i_2...i_N,j_1j_2...j_M\rightarrow i_{k_1}i_{k_2}..j_{l_1}j_{l_1},k_1...\in N,l_1..l\in M </math>i1i2...iN,j1j2...jM→ik1ik2..jl1jl1,k1...∈N,l1..l∈M

其中左端 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 1 i 2 . . . i N , j 1 j 2 . . . j M i_1i_2...i_N,j_1j_2...j_M </math>i1i2...iN,j1j2...jM就表示了输入两个矩阵元素的坐标索引,右端 <math xmlns="http://www.w3.org/1998/Math/MathML"> i k 1 i k 2 . . j l 1 j l 1 i_{k_1}i_{k_2}..j_{l_1}j_{l_1} </math>ik1ik2..jl1jl1为输出矩阵元素的坐标索引,可以看到输出矩阵元素索引相较于输入端的索引可能会缺少几项,运算就是发生这几个维度上的乘累加操作。

其中同时出现在左端和右端的坐标索引为自由索引 ,只用于标记位置;而仅仅出现在右端的索引为求和索引 ,爱因斯坦求和约定的本质就是沿着求和索引的方向计算两个输入逐元素乘累加和的结果放到输出自由索引的位置上 ,更为细致的介绍参见:一文学会 Pytorch 中的 einsumeinsum:爱因斯坦求和约定 举例而言:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i j , j k → i k ij,jk\rightarrow ik </math>ij,jk→ik

就表示沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j这个维度进行乘累加操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> O i k = ∑ j A i j B j k O_{ik}=\sum_{j}A_{ij}B_{jk} </math>Oik=j∑AijBjk

输出的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i , k ) (i,k) </math>(i,k)个元素为 <math xmlns="http://www.w3.org/1998/Math/MathML"> A i ⋅ A_{i \cdot} </math>Ai⋅的行向量和 <math xmlns="http://www.w3.org/1998/Math/MathML"> B ⋅ k B_{\cdot k} </math>B⋅k列向量逐元素乘累加,实际上就是矩阵相乘。

复杂案例推导

正如第一节中所介绍的,torch.enisum()的核心计算过程就是沿着只在算式右边出现的轴对输入矩阵元素进行乘累加得到对应位置的输出元素。因此,想要弄清一个复杂的torch.eisum()表达式含义需要做的也只是将这个求和公式写出来再仔细分析。

案例. 四维张量乘三维张量

给出一个复杂案例:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> n c j t , n p j − > n c p t ncjt,npj->ncpt </math>ncjt,npj−>ncpt

则其输出元素可以表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C n c p t = ∑ j A n c j t B n p j C_{ncpt}=\sum_j A_{ncjt}B_{npj} </math>Cncpt=j∑AncjtBnpj

首先我们可以注意到对于C的第一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n而言,它同时出现在A和B的首位,也就是对于这一维的每个元素,都是会逐元素的执行A和B剩余维度的计算再在当前维度上排布,用深度学习中的描述来说就是对BATCH中的每个元素都独立的执行后续子操作,子操作可以记为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C c p t = ∑ j A c j t B p j C_{cpt}=\sum_j A_{cjt}B_{pj} </math>Ccpt=j∑AcjtBpj

紧接着,对当前算式的第一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c来说它只出现在 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A中,每沿着 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c计算一个不同的元素都要和"相同"的B计算,也就出现了广播机制,B有了个隐藏的、元素重复的维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c,计算变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> C c p t = ∑ j A c j t B c p j C_{cpt}=\sum_j A_{cjt}B_{cpj} </math>Ccpt=∑jAcjtBcpj,同第一步计算的原理,这里又可以化简成逐元素的子操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C p t = ∑ j A j t B p j = ∑ j B p j A j t C_{pt}=\sum_jA_{jt}B_{pj}=\sum_jB_{pj}A_{jt} </math>Cpt=j∑AjtBpj=j∑BpjAjt

此时易看出 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( p , t ) (p,t) </math>(p,t)元素就是B的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p行向量和A的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t列向量求内积。  从而我们可以得出结论,这一表达式的意思是,对于BATCH内的每个元素(A'三维,B'二维),对B在第一维度进行广播(A''三维,B''三维),最后沿着第二维和第三维计算矩阵相乘B'''A'''(A'''二维,B'''二维)。

而整个的推导过程可以总结为以下几要点

  1. 沿着维数较高输入的第一维开始,判断是否存在于B中,如果在的话就可认为是逐元素操作,暂时忽略该维度;
  2. 如果该维度指示不在B中,则进行广播操作,重新回到1,否则3.
  3. 判断当前最简表达式的意义。
相关推荐
墨染天姬2 分钟前
【AI】Hermes的GEPA算法
人工智能·算法
小超同学你好4 分钟前
OpenClaw 深度解析系列 · 第8篇:Learning & Adaptation(学习与自适应)
人工智能·语言模型·chatgpt
紫微AI13 分钟前
前端文本测量成了卡死一切创新的最后瓶颈,pretext实现突破了
前端·人工智能·typescript
码途漫谈22 分钟前
Easy-Vibe开发篇阅读笔记(四)——前端开发之结合 Agent Skills 美化界面
人工智能·笔记·ai·开源·ai编程
易连EDI—EasyLink28 分钟前
易连EDI–EasyLink实现OCR智能数据采集
网络·人工智能·安全·汽车·ocr·edi
冬奇Lab39 分钟前
RAG 系列(二):用 LangChain 搭建你的第一个 RAG Pipeline
人工智能·langchain·llm
学习论之费曼学习法1 小时前
多模态大模型实战:用 GPT-4o API 打造 AI 助手,能看、能听、能说!
人工智能
昨夜见军贴06161 小时前
IACheck与AI报告审核,开启供应商资质核验报告审核新篇章
人工智能
m0_726365831 小时前
Ai漫剧系统 几分钟,让AI 把一篇小说变成了一部漫剧成片:从剧本到视频的全流程系统实现
人工智能·语言模型·ai作画·音视频
AIwenIPgeolocation1 小时前
出海应用合规与风控平衡术:可信ID的全球安全实践
人工智能·安全