一.前言
本章节来介绍一下张量拼接的操作,掌握torch.cat torch.stack使⽤,张量的拼接操作在神经⽹络搭建过程中是⾮常常⽤的⽅法,例如: 在后⾯将要学习到的残差⽹络、注意⼒机 制中都使⽤到了张量拼接。
二.torch.cat 函数的使用
torch.cat 函数可以将两个张量根据指定的维度拼接起来.
python
import torch
def test():
data1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])
print(data1)
print(data2)
print('-' * 50)
# 1. 按0维度拼接
new_data = torch.cat([data1, data2], dim=0)
print(new_data.shape)
print('-' * 50)
# 2. 按1维度拼接
new_data = torch.cat([data1, data2], dim=1)
print(new_data.shape)
print('-' * 50)
# 3. 按2维度拼接
new_data = torch.cat([data1, data2], dim=2)
print(new_data.shape)
if __name__ == '__main__':
test()
结果展示:
tensor([[[6, 7, 2, 6],
4, 6, 4, 3\], \[5, 3, 4, 9\], \[8, 8, 6, 7\], \[0, 3, 3, 0\]\], \[\[6, 1, 2, 0\], \[5, 6, 7, 0\], \[6, 4, 8, 0\], \[2, 2, 8, 3\], \[0, 1, 6, 8\]\], \[\[3, 5, 0, 8\], \[6, 2, 1, 7\], \[8, 9, 9, 8\], \[3, 8, 8, 0\], \[5, 8, 4, 4\]\]\]) tensor(\[\[\[7, 2, 2, 1\], \[8, 0, 6, 6\], \[9, 0, 6, 5\], \[1, 3, 7, 7\], \[7, 0, 5, 1\]\], \[\[0, 7, 3, 1\], \[9, 2, 9, 0\], \[9, 6, 2, 1\], \[9, 3, 5, 0\], \[8, 8, 6, 2\]\], \[\[1, 8, 9, 9\], \[4, 3, 0, 9\], \[7, 3, 3, 8\], \[2, 4, 6, 9\], \[2, 1, 0, 5\]\]\]) -------------------------------------------------- torch.Size(\[6, 5, 4\]) -------------------------------------------------- torch.Size(\[3, 10, 4\]) -------------------------------------------------- torch.Size(\[3, 5, 8\])
torch.stack 函数可以将两个张量根据指定的维度叠加起来.
python
import torch
def test():
data1 = torch.randint(0, 10, [2, 3])
data2 = torch.randint(0, 10, [2, 3])
print(data1)
print(data2)
print("="*50)
new_data = torch.stack([data1, data2], dim=0)
print(new_data.shape)
print(new_data)
print("=" * 50)
new_data = torch.stack([data1, data2], dim=1)
print(new_data.shape)
print(new_data)
print("=" * 50)
new_data = torch.stack([data1, data2], dim=2)
print(new_data.shape)
print(new_data)
if __name__ == '__main__':
test()
结果展示:
tensor([[6, 9, 6],
3, 2, 7\]\]) tensor(\[\[3, 3, 4\], \[9, 1, 4\]\]) ================================================== torch.Size(\[2, 2, 3\]) tensor(\[\[\[6, 9, 6\], \[3, 2, 7\]\], \[\[3, 3, 4\], \[9, 1, 4\]\]\]) ================================================== torch.Size(\[2, 2, 3\]) tensor(\[\[\[6, 9, 6\], \[3, 3, 4\]\], \[\[3, 2, 7\], \[9, 1, 4\]\]\]) ================================================== torch.Size(\[2, 3, 2\]) tensor(\[\[\[6, 3\], \[9, 3\], \[6, 4\]\], \[\[3, 9\], \[2, 1\], \[7, 4\]\]\])
这里十分的不好理解,大家拷贝完代码自己执行理解一下。
四.总结
张量的拼接操作也是在后⾯我们经常使⽤⼀种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。