einops 库和 PyTorch 的 einsum 的语法

einsum,基于爱因斯坦求和约定,主要用于指定张量的乘法操作。

einops,更高级、更直观的语法,专注于张量重塑和重新排列,容易理解。

两者使用了,模式字符串的形式,'字符串'的每个字付相当于代数式,代表每一个维度数。主要语法在于,维度之间,不能使用逗号间隔('b,c,h,w'),只能使用空格('b c h w')。变量之间,使用逗号分隔,变量内部可以不加空格。

下面有一个典型的例子结果展示用法,用于理解功能和操作原理:

python 复制代码
import torch
def test_einsum():
    #基于爱因斯坦求和约定,用于指定张量操作
    a = torch.randn(3,4)
    b = torch.randn(4, 5)
    print("-------a")
    print(a)
    print("-------b")
    print(b)
    c = torch.einsum("ij, jk -> ik",a,b)  # 等价于 A @ B,相乘
    print("-------c")
    print(c)
    d = torch.einsum("mn, nj -> mj",a,b)# 矩阵乘法
    print("-------d")
    print(d)

    # 批量矩阵乘法
    batch_A = torch.randn(2, 3, 4)
    batch_B = torch.randn(2, 4, 5)
    batch_C = torch.einsum('bij,bjk->bik', batch_A, batch_B)#'字符串'相当于代数式,代表每一个维度
    print("-------batch_C")
    print(batch_C.shape)

    # 向量点积
    v1 = torch.randn(5)# 5*1
    v2 = torch.randn(5)# 5*1
    dot = torch.einsum('i,i->', v1, v2)#相当于先换成1*5的形状,再点乘
    print("-------dot")
    print(v1)
    print(v2.shape)
    print(dot)

    # 外积
    v1 = torch.randn(3)# 3*1
    v2 = torch.randn(4)# 4*1
    outer = torch.einsum('i,j->ij', v1, v2)# 相当于先换成[3*1]*[1*4]的形状
    print("-------outer")
    print(v1)
    print(v2.shape)
    print(outer)

    # 逐元素乘法求和(张量缩并),逐元素乘法
    A = torch.randn(3, 4, 5)
    B = torch.randn(5, 3, 4)
    C = torch.einsum('ijk,kij->', A, B)# 等价于 torch.sum(A * B),[4,5]*[3,4]
    print("-------C")
    print(C)
    # 张量缩并,再解释一下,其实是Frobenius 内积:衡量两个矩阵在所有元素上的"相似度",类似于向量的点积。
    A = torch.tensor([[1, 2], [3, 4]])
    B = torch.tensor([[5, 6], [7, 8]])
    C_einsum = torch.einsum('ij,ij->', A, B)#对应位置相乘,然后求和
    print("einsum结果:", C_einsum.item())      # 输出: 70 (1*5 + 2*6 + 3*7 + 4*8)



    # 转置
    A = torch.randn(2, 3, 4)
    B = torch.einsum('i j k->k j i', A)  # 维度顺序反转
    print("-------B")
    print(B.shape)
    # 也可以不加空格
    B = torch.einsum('ijk->kji', A)  
    print("-------B")
    print(B.shape)


def test_einops():
    ############
    # einops
    # 更高级、更直观的语法,专注于张量重塑和重新排列。
    ############
    import einops
    import torch
    # 重塑和重新排列
    a = torch.randn(3,4,5,6)        #假设3个batch,4张,5*6高宽的图片
    print(a.shape)
    # 展平
    flat = einops.rearrange(a, 'b c h w-> b(c h w)')#不能使用逗号间隔,只能使用空格 b,c,h,w
    print(flat.shape)#torch.Size([3, 4*5*6])
    # 空间展平
    spatial_flat = einops.rearrange(a, 'b c h w->b c (h w)')#()代表维度合并,结果是维度的相乘
    print(spatial_flat.shape)#torch.Size([3, 4, 5*6])
    # 改变维度顺序
    reordered = einops.rearrange(a, 'b c h w->b h w c')#b h w c
    print(reordered.shape)#torch.Size([3, 5, 6, 4])

    # 分割维度,相当于维度上数值的因式分解,相当于将5*6 重新分解为15*2;确保 h1*h2=5 且 w1*w2=6(均为整数)
    reshaped = einops.rearrange(a, 'b c (h1 h2) (w1 w2)->b (h1 w1) (h2 w2) c', h1=5, w1=3)
    # (3, 5*3=15, 1*2=2, 4) → (3, 15, 2, 4)
    print(reshaped.shape)#torch.Size([3, 15, 2, 4])

    # 合并维度
    a = torch.randn(3, 4, 5, 6)
    batch_merged = einops.rearrange(a, ' b  c h w -> (b c)  h w')
    print(batch_merged.shape) #torch.Size([12, 5, 6])
    # 重复模式
    repeated = einops.repeat(a[0], 'h w c -> (tile h) w c', tile=3)  # 沿高度重复3次
    print(repeated.shape)  # torch.Size([15, 6, 4])



if __name__ == "__main__":
    test_einsum()
    test_einops()
相关推荐
小陈phd1 小时前
多模态大模型学习笔记(七)——多模态数据的表征与对齐
人工智能·算法·机器学习
摆烂小白敲代码1 小时前
腾讯云智能结构化OCR在物流行业的应用
大数据·人工智能·经验分享·ocr·腾讯云
CoderJia程序员甲1 小时前
GitHub 热榜项目 - 日榜(2026-02-24)
人工智能·ai·大模型·github·ai教程
nimadan121 小时前
**AI漫剧软件2025推荐,解锁高性价比创意制作新体验**
人工智能·python
前网易架构师-高司机1 小时前
带标注的安全带和车牌识别数据集,识别率在88.8%,可识别挡风玻璃,是否系安全带,车牌区域,支持yolo,coco json,pascal voc xml格式
人工智能·数据集·交通违法·违法拍摄·安全带
Bal炎魔2 小时前
AI 学习专题一,AI 实现的原理
人工智能·学习
kjmkq2 小时前
办公智能体落地:九科信息让AI深度融入企业日常运营
人工智能
NAGNIP2 小时前
一文搞懂神经元模型是什么!
人工智能·算法
Ro Jace2 小时前
分岔机制学习
人工智能·学习·机器学习