einops 库和 PyTorch 的 einsum 的语法

einsum,基于爱因斯坦求和约定,主要用于指定张量的乘法操作。

einops,更高级、更直观的语法,专注于张量重塑和重新排列,容易理解。

两者使用了,模式字符串的形式,'字符串'的每个字付相当于代数式,代表每一个维度数。主要语法在于,维度之间,不能使用逗号间隔('b,c,h,w'),只能使用空格('b c h w')。变量之间,使用逗号分隔,变量内部可以不加空格。

下面有一个典型的例子结果展示用法,用于理解功能和操作原理:

python 复制代码
import torch
def test_einsum():
    #基于爱因斯坦求和约定,用于指定张量操作
    a = torch.randn(3,4)
    b = torch.randn(4, 5)
    print("-------a")
    print(a)
    print("-------b")
    print(b)
    c = torch.einsum("ij, jk -> ik",a,b)  # 等价于 A @ B,相乘
    print("-------c")
    print(c)
    d = torch.einsum("mn, nj -> mj",a,b)# 矩阵乘法
    print("-------d")
    print(d)

    # 批量矩阵乘法
    batch_A = torch.randn(2, 3, 4)
    batch_B = torch.randn(2, 4, 5)
    batch_C = torch.einsum('bij,bjk->bik', batch_A, batch_B)#'字符串'相当于代数式,代表每一个维度
    print("-------batch_C")
    print(batch_C.shape)

    # 向量点积
    v1 = torch.randn(5)# 5*1
    v2 = torch.randn(5)# 5*1
    dot = torch.einsum('i,i->', v1, v2)#相当于先换成1*5的形状,再点乘
    print("-------dot")
    print(v1)
    print(v2.shape)
    print(dot)

    # 外积
    v1 = torch.randn(3)# 3*1
    v2 = torch.randn(4)# 4*1
    outer = torch.einsum('i,j->ij', v1, v2)# 相当于先换成[3*1]*[1*4]的形状
    print("-------outer")
    print(v1)
    print(v2.shape)
    print(outer)

    # 逐元素乘法求和(张量缩并),逐元素乘法
    A = torch.randn(3, 4, 5)
    B = torch.randn(5, 3, 4)
    C = torch.einsum('ijk,kij->', A, B)# 等价于 torch.sum(A * B),[4,5]*[3,4]
    print("-------C")
    print(C)
    # 张量缩并,再解释一下,其实是Frobenius 内积:衡量两个矩阵在所有元素上的"相似度",类似于向量的点积。
    A = torch.tensor([[1, 2], [3, 4]])
    B = torch.tensor([[5, 6], [7, 8]])
    C_einsum = torch.einsum('ij,ij->', A, B)#对应位置相乘,然后求和
    print("einsum结果:", C_einsum.item())      # 输出: 70 (1*5 + 2*6 + 3*7 + 4*8)



    # 转置
    A = torch.randn(2, 3, 4)
    B = torch.einsum('i j k->k j i', A)  # 维度顺序反转
    print("-------B")
    print(B.shape)
    # 也可以不加空格
    B = torch.einsum('ijk->kji', A)  
    print("-------B")
    print(B.shape)


def test_einops():
    ############
    # einops
    # 更高级、更直观的语法,专注于张量重塑和重新排列。
    ############
    import einops
    import torch
    # 重塑和重新排列
    a = torch.randn(3,4,5,6)        #假设3个batch,4张,5*6高宽的图片
    print(a.shape)
    # 展平
    flat = einops.rearrange(a, 'b c h w-> b(c h w)')#不能使用逗号间隔,只能使用空格 b,c,h,w
    print(flat.shape)#torch.Size([3, 4*5*6])
    # 空间展平
    spatial_flat = einops.rearrange(a, 'b c h w->b c (h w)')#()代表维度合并,结果是维度的相乘
    print(spatial_flat.shape)#torch.Size([3, 4, 5*6])
    # 改变维度顺序
    reordered = einops.rearrange(a, 'b c h w->b h w c')#b h w c
    print(reordered.shape)#torch.Size([3, 5, 6, 4])

    # 分割维度,相当于维度上数值的因式分解,相当于将5*6 重新分解为15*2;确保 h1*h2=5 且 w1*w2=6(均为整数)
    reshaped = einops.rearrange(a, 'b c (h1 h2) (w1 w2)->b (h1 w1) (h2 w2) c', h1=5, w1=3)
    # (3, 5*3=15, 1*2=2, 4) → (3, 15, 2, 4)
    print(reshaped.shape)#torch.Size([3, 15, 2, 4])

    # 合并维度
    a = torch.randn(3, 4, 5, 6)
    batch_merged = einops.rearrange(a, ' b  c h w -> (b c)  h w')
    print(batch_merged.shape) #torch.Size([12, 5, 6])
    # 重复模式
    repeated = einops.repeat(a[0], 'h w c -> (tile h) w c', tile=3)  # 沿高度重复3次
    print(repeated.shape)  # torch.Size([15, 6, 4])



if __name__ == "__main__":
    test_einsum()
    test_einops()
相关推荐
低调小一2 小时前
Google AI Agent 白皮书拆解(1):从《Introduction to Agents》看清 Agent 的工程底座
人工智能
feasibility.2 小时前
混元3D-dit-v2-mv-turbo生成3D模型初体验(ComfyUI)
人工智能·3d·aigc·三维建模·comfyui
极智-9962 小时前
GitHub 热榜项目-日榜精选(2026-02-02)| AI智能体、终端工具、视频生成等 | openclaw、99、Maestro等
人工智能·github·视频生成·终端工具·ai智能体·电子书管理·rust工具
悟纤2 小时前
AI 音乐创作中的音乐织体(Texture)完整指南 | Suno高级篇 | 第30篇
人工智能·suno·suno ai·suno api·ai music
编码者卢布2 小时前
【Azure Storage Account】Azure Table Storage 跨区批量迁移方案
后端·python·flask
可触的未来,发芽的智生2 小时前
狂想:为AGI代称造字ta,《第三类智慧存在,神的赐名》
javascript·人工智能·python·神经网络·程序人生
莱茶荼菜2 小时前
yolo26 阅读笔记
人工智能·笔记·深度学习·ai·yolo26
Dingdangcat863 小时前
【YOLOv8改进实战】使用Ghost模块优化P2结构提升涂胶缺陷检测精度_1
人工智能·yolo·目标跟踪
吴维炜3 小时前
「Python算法」计费引擎系统SKILL.md
python·算法·agent·skill.md·vb coding