einsum,基于爱因斯坦求和约定,主要用于指定张量的乘法操作。
einops,更高级、更直观的语法,专注于张量重塑和重新排列,容易理解。
两者使用了,模式字符串的形式,'字符串'的每个字付相当于代数式,代表每一个维度数。主要语法在于,维度之间,不能使用逗号间隔('b,c,h,w'),只能使用空格('b c h w')。变量之间,使用逗号分隔,变量内部可以不加空格。
下面有一个典型的例子结果展示用法,用于理解功能和操作原理:
python
import torch
def test_einsum():
#基于爱因斯坦求和约定,用于指定张量操作
a = torch.randn(3,4)
b = torch.randn(4, 5)
print("-------a")
print(a)
print("-------b")
print(b)
c = torch.einsum("ij, jk -> ik",a,b) # 等价于 A @ B,相乘
print("-------c")
print(c)
d = torch.einsum("mn, nj -> mj",a,b)# 矩阵乘法
print("-------d")
print(d)
# 批量矩阵乘法
batch_A = torch.randn(2, 3, 4)
batch_B = torch.randn(2, 4, 5)
batch_C = torch.einsum('bij,bjk->bik', batch_A, batch_B)#'字符串'相当于代数式,代表每一个维度
print("-------batch_C")
print(batch_C.shape)
# 向量点积
v1 = torch.randn(5)# 5*1
v2 = torch.randn(5)# 5*1
dot = torch.einsum('i,i->', v1, v2)#相当于先换成1*5的形状,再点乘
print("-------dot")
print(v1)
print(v2.shape)
print(dot)
# 外积
v1 = torch.randn(3)# 3*1
v2 = torch.randn(4)# 4*1
outer = torch.einsum('i,j->ij', v1, v2)# 相当于先换成[3*1]*[1*4]的形状
print("-------outer")
print(v1)
print(v2.shape)
print(outer)
# 逐元素乘法求和(张量缩并),逐元素乘法
A = torch.randn(3, 4, 5)
B = torch.randn(5, 3, 4)
C = torch.einsum('ijk,kij->', A, B)# 等价于 torch.sum(A * B),[4,5]*[3,4]
print("-------C")
print(C)
# 张量缩并,再解释一下,其实是Frobenius 内积:衡量两个矩阵在所有元素上的"相似度",类似于向量的点积。
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
C_einsum = torch.einsum('ij,ij->', A, B)#对应位置相乘,然后求和
print("einsum结果:", C_einsum.item()) # 输出: 70 (1*5 + 2*6 + 3*7 + 4*8)
# 转置
A = torch.randn(2, 3, 4)
B = torch.einsum('i j k->k j i', A) # 维度顺序反转
print("-------B")
print(B.shape)
# 也可以不加空格
B = torch.einsum('ijk->kji', A)
print("-------B")
print(B.shape)
def test_einops():
############
# einops
# 更高级、更直观的语法,专注于张量重塑和重新排列。
############
import einops
import torch
# 重塑和重新排列
a = torch.randn(3,4,5,6) #假设3个batch,4张,5*6高宽的图片
print(a.shape)
# 展平
flat = einops.rearrange(a, 'b c h w-> b(c h w)')#不能使用逗号间隔,只能使用空格 b,c,h,w
print(flat.shape)#torch.Size([3, 4*5*6])
# 空间展平
spatial_flat = einops.rearrange(a, 'b c h w->b c (h w)')#()代表维度合并,结果是维度的相乘
print(spatial_flat.shape)#torch.Size([3, 4, 5*6])
# 改变维度顺序
reordered = einops.rearrange(a, 'b c h w->b h w c')#b h w c
print(reordered.shape)#torch.Size([3, 5, 6, 4])
# 分割维度,相当于维度上数值的因式分解,相当于将5*6 重新分解为15*2;确保 h1*h2=5 且 w1*w2=6(均为整数)
reshaped = einops.rearrange(a, 'b c (h1 h2) (w1 w2)->b (h1 w1) (h2 w2) c', h1=5, w1=3)
# (3, 5*3=15, 1*2=2, 4) → (3, 15, 2, 4)
print(reshaped.shape)#torch.Size([3, 15, 2, 4])
# 合并维度
a = torch.randn(3, 4, 5, 6)
batch_merged = einops.rearrange(a, ' b c h w -> (b c) h w')
print(batch_merged.shape) #torch.Size([12, 5, 6])
# 重复模式
repeated = einops.repeat(a[0], 'h w c -> (tile h) w c', tile=3) # 沿高度重复3次
print(repeated.shape) # torch.Size([15, 6, 4])
if __name__ == "__main__":
test_einsum()
test_einops()