张量拼接操作

一.前言

本章节来介绍一下张量拼接的操作,掌握torch.cat torch.stack使⽤,张量的拼接操作在神经⽹络搭建过程中是⾮常常⽤的⽅法,例如: 在后⾯将要学习到的残差⽹络、注意⼒机 制中都使⽤到了张量拼接。

二.torch.cat 函数的使用

torch.cat 函数可以将两个张量根据指定的维度拼接起来.

python 复制代码
import torch


def test():
    data1 = torch.randint(0, 10, [3, 5, 4])
    data2 = torch.randint(0, 10, [3, 5, 4])
    print(data1)
    print(data2)
    print('-' * 50)
    # 1. 按0维度拼接
    new_data = torch.cat([data1, data2], dim=0)
    print(new_data.shape)
    print('-' * 50)
    # 2. 按1维度拼接
    new_data = torch.cat([data1, data2], dim=1)
    print(new_data.shape)
    print('-' * 50)
    # 3. 按2维度拼接
    new_data = torch.cat([data1, data2], dim=2)
    print(new_data.shape)

if __name__ == '__main__':
    test()

结果展示:

tensor([[[6, 7, 2, 6],

4, 6, 4, 3\], \[5, 3, 4, 9\], \[8, 8, 6, 7\], \[0, 3, 3, 0\]\], \[\[6, 1, 2, 0\], \[5, 6, 7, 0\], \[6, 4, 8, 0\], \[2, 2, 8, 3\], \[0, 1, 6, 8\]\], \[\[3, 5, 0, 8\], \[6, 2, 1, 7\], \[8, 9, 9, 8\], \[3, 8, 8, 0\], \[5, 8, 4, 4\]\]\]) tensor(\[\[\[7, 2, 2, 1\], \[8, 0, 6, 6\], \[9, 0, 6, 5\], \[1, 3, 7, 7\], \[7, 0, 5, 1\]\], \[\[0, 7, 3, 1\], \[9, 2, 9, 0\], \[9, 6, 2, 1\], \[9, 3, 5, 0\], \[8, 8, 6, 2\]\], \[\[1, 8, 9, 9\], \[4, 3, 0, 9\], \[7, 3, 3, 8\], \[2, 4, 6, 9\], \[2, 1, 0, 5\]\]\]) -------------------------------------------------- torch.Size(\[6, 5, 4\]) -------------------------------------------------- torch.Size(\[3, 10, 4\]) -------------------------------------------------- torch.Size(\[3, 5, 8\])

torch.stack 函数可以将两个张量根据指定的维度叠加起来.

python 复制代码
import torch

def test():
    data1 = torch.randint(0, 10, [2, 3])
    data2 = torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)
    print("="*50)
    new_data = torch.stack([data1, data2], dim=0)
    print(new_data.shape)
    print(new_data)
    print("=" * 50)
    new_data = torch.stack([data1, data2], dim=1)
    print(new_data.shape)
    print(new_data)
    print("=" * 50)
    new_data = torch.stack([data1, data2], dim=2)
    print(new_data.shape)
    print(new_data)

if __name__ == '__main__':
    test()

结果展示:

tensor([[6, 9, 6],

3, 2, 7\]\]) tensor(\[\[3, 3, 4\], \[9, 1, 4\]\]) ================================================== torch.Size(\[2, 2, 3\]) tensor(\[\[\[6, 9, 6\], \[3, 2, 7\]\], \[\[3, 3, 4\], \[9, 1, 4\]\]\]) ================================================== torch.Size(\[2, 2, 3\]) tensor(\[\[\[6, 9, 6\], \[3, 3, 4\]\], \[\[3, 2, 7\], \[9, 1, 4\]\]\]) ================================================== torch.Size(\[2, 3, 2\]) tensor(\[\[\[6, 3\], \[9, 3\], \[6, 4\]\], \[\[3, 9\], \[2, 1\], \[7, 4\]\]\])

这里十分的不好理解,大家拷贝完代码自己执行理解一下。

四.总结

张量的拼接操作也是在后⾯我们经常使⽤⼀种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。