一.前言
本章节来介绍一下张量拼接的操作,掌握torch.cat torch.stack使⽤,张量的拼接操作在神经⽹络搭建过程中是⾮常常⽤的⽅法,例如: 在后⾯将要学习到的残差⽹络、注意⼒机 制中都使⽤到了张量拼接。
二.torch.cat 函数的使用
torch.cat 函数可以将两个张量根据指定的维度拼接起来.
python
import torch
def test():
data1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])
print(data1)
print(data2)
print('-' * 50)
# 1. 按0维度拼接
new_data = torch.cat([data1, data2], dim=0)
print(new_data.shape)
print('-' * 50)
# 2. 按1维度拼接
new_data = torch.cat([data1, data2], dim=1)
print(new_data.shape)
print('-' * 50)
# 3. 按2维度拼接
new_data = torch.cat([data1, data2], dim=2)
print(new_data.shape)
if __name__ == '__main__':
test()
结果展示:
tensor(\[\[6, 7, 2, 6,
4, 6, 4, 3,
5, 3, 4, 9,
8, 8, 6, 7,
0, 3, 3, 0],
\[6, 1, 2, 0,
5, 6, 7, 0,
6, 4, 8, 0,
2, 2, 8, 3,
0, 1, 6, 8],
\[3, 5, 0, 8,
6, 2, 1, 7,
8, 9, 9, 8,
3, 8, 8, 0,
5, 8, 4, 4]])
tensor(\[\[7, 2, 2, 1,
8, 0, 6, 6,
9, 0, 6, 5,
1, 3, 7, 7,
7, 0, 5, 1],
\[0, 7, 3, 1,
9, 2, 9, 0,
9, 6, 2, 1,
9, 3, 5, 0,
8, 8, 6, 2],
\[1, 8, 9, 9,
4, 3, 0, 9,
7, 3, 3, 8,
2, 4, 6, 9,
2, 1, 0, 5]])
torch.Size(6, 5, 4)
torch.Size(3, 10, 4)
torch.Size(3, 5, 8)
三.torch.stack 函数的使用
torch.stack 函数可以将两个张量根据指定的维度叠加起来.
python
import torch
def test():
data1 = torch.randint(0, 10, [2, 3])
data2 = torch.randint(0, 10, [2, 3])
print(data1)
print(data2)
print("="*50)
new_data = torch.stack([data1, data2], dim=0)
print(new_data.shape)
print(new_data)
print("=" * 50)
new_data = torch.stack([data1, data2], dim=1)
print(new_data.shape)
print(new_data)
print("=" * 50)
new_data = torch.stack([data1, data2], dim=2)
print(new_data.shape)
print(new_data)
if __name__ == '__main__':
test()
结果展示:
tensor(\[6, 9, 6,
3, 2, 7])
tensor(\[3, 3, 4,
9, 1, 4])
==================================================
torch.Size(2, 2, 3)
tensor(\[\[6, 9, 6,
3, 2, 7],
\[3, 3, 4,
9, 1, 4]])
==================================================
torch.Size(2, 2, 3)
tensor(\[\[6, 9, 6,
3, 3, 4],
\[3, 2, 7,
9, 1, 4]])
==================================================
torch.Size(2, 3, 2)
tensor(\[\[6, 3,
9, 3,
6, 4],
\[3, 9,
2, 1,
7, 4]])
这里十分的不好理解,大家拷贝完代码自己执行理解一下。
四.总结
张量的拼接操作也是在后⾯我们经常使⽤⼀种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。