张量拼接操作

一.前言

本章节来介绍一下张量拼接的操作,掌握torch.cat torch.stack使⽤,张量的拼接操作在神经⽹络搭建过程中是⾮常常⽤的⽅法,例如: 在后⾯将要学习到的残差⽹络、注意⼒机 制中都使⽤到了张量拼接。

二.torch.cat 函数的使用

torch.cat 函数可以将两个张量根据指定的维度拼接起来.

python 复制代码
import torch


def test():
    data1 = torch.randint(0, 10, [3, 5, 4])
    data2 = torch.randint(0, 10, [3, 5, 4])
    print(data1)
    print(data2)
    print('-' * 50)
    # 1. 按0维度拼接
    new_data = torch.cat([data1, data2], dim=0)
    print(new_data.shape)
    print('-' * 50)
    # 2. 按1维度拼接
    new_data = torch.cat([data1, data2], dim=1)
    print(new_data.shape)
    print('-' * 50)
    # 3. 按2维度拼接
    new_data = torch.cat([data1, data2], dim=2)
    print(new_data.shape)

if __name__ == '__main__':
    test()

结果展示:

tensor(\[\[6, 7, 2, 6,

4, 6, 4, 3,

5, 3, 4, 9,

8, 8, 6, 7,

0, 3, 3, 0],

\[6, 1, 2, 0,

5, 6, 7, 0,

6, 4, 8, 0,

2, 2, 8, 3,

0, 1, 6, 8],

\[3, 5, 0, 8,

6, 2, 1, 7,

8, 9, 9, 8,

3, 8, 8, 0,

5, 8, 4, 4]])

tensor(\[\[7, 2, 2, 1,

8, 0, 6, 6,

9, 0, 6, 5,

1, 3, 7, 7,

7, 0, 5, 1],

\[0, 7, 3, 1,

9, 2, 9, 0,

9, 6, 2, 1,

9, 3, 5, 0,

8, 8, 6, 2],

\[1, 8, 9, 9,

4, 3, 0, 9,

7, 3, 3, 8,

2, 4, 6, 9,

2, 1, 0, 5]])


torch.Size(6, 5, 4)


torch.Size(3, 10, 4)


torch.Size(3, 5, 8)

三.torch.stack 函数的使用

torch.stack 函数可以将两个张量根据指定的维度叠加起来.

python 复制代码
import torch

def test():
    data1 = torch.randint(0, 10, [2, 3])
    data2 = torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)
    print("="*50)
    new_data = torch.stack([data1, data2], dim=0)
    print(new_data.shape)
    print(new_data)
    print("=" * 50)
    new_data = torch.stack([data1, data2], dim=1)
    print(new_data.shape)
    print(new_data)
    print("=" * 50)
    new_data = torch.stack([data1, data2], dim=2)
    print(new_data.shape)
    print(new_data)

if __name__ == '__main__':
    test()

结果展示:

tensor(\[6, 9, 6,

3, 2, 7])

tensor(\[3, 3, 4,

9, 1, 4])

==================================================

torch.Size(2, 2, 3)

tensor(\[\[6, 9, 6,

3, 2, 7],

\[3, 3, 4,

9, 1, 4]])

==================================================

torch.Size(2, 2, 3)

tensor(\[\[6, 9, 6,

3, 3, 4],

\[3, 2, 7,

9, 1, 4]])

==================================================

torch.Size(2, 3, 2)

tensor(\[\[6, 3,

9, 3,

6, 4],

\[3, 9,

2, 1,

7, 4]])

这里十分的不好理解,大家拷贝完代码自己执行理解一下。

四.总结

张量的拼接操作也是在后⾯我们经常使⽤⼀种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。