PyTorch深度学习总结
第四章 PyTorch中张量(Tensor)拼接和拆分操作
文章目录
前言
上文介绍了PyTorch中张量(Tensor)的切片操作,本文主要介绍张量的拆分和拼接操作。
一、张量拼接
| 函数 | 描述 | 
|---|---|
| torch.cat() | 将张量按照 指定维度关系进行拼接 | 
| torch.stack() | 将张量按照 指定维度关系进行拼接(用法同cat相同) | 
python# 引入库 import torch # 创建张量 A = torch.arange(9).reshape(1, 3, 3) print(A)输出结果为:
tensor(
\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]\]) *** ** * ** *** **1、按照维度1进行拼接:** ```python B0 = torch.cat((A, A), dim=0) print(B0) ``` 输出结果为: tensor(\[\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\], \[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]\]) *** ** * ** *** **1、按照维度2(`行`)进行拼接:** ```python B1 = torch.cat((A, A), dim=2) print(B1) ``` 输出结果为: tensor(\[\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\], \[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]\]) *** ** * ** *** **1、按照维度3(`列`)进行拼接:** ```python B2 = torch.cat((A, A), dim=2) print(B2) ``` 输出结果为: tensor(\[\[\[0, 1, 2, 0, 1, 2\], \[3, 4, 5, 3, 4, 5\], \[6, 7, 8, 6, 7, 8\]\]\])
二、张量拆分
| 函数 | 描述 | 
|---|---|
| torch.chunk() | 将张量分割为特定数量的块(当张量对应维度元素数量不足以拆分时会按照可以拆分数量进行拆分,且会出现不均等拆分情况) | 
| torch.split() | 将张量分割为特定数量的块,可以指定块的大小 | 
注意:
torch.chunk():当张量对应维度元素数量不足以拆分时,会按照可以拆分的最大数量进行拆分,且会出现不均等拆分情况,且最后一个块最小
下文使用B0进行示例
B0 = tensor([[[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]],
        [[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]])1、
torch.chunk()按照维度1进行拆分:
pythonC1, C2 = torch.chunk(B0, 2, dim=1) # 维度1只有三组元素,所以会按照2:1的比例进行拆分 print(C1, C2)输出结果为:
tensor([[[0, 1, 2],
3, 4, 5\], \[6, 7, 8\]\]\]) tensor(\[\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]\]) *** ** * ** *** **1、`torch.chunk()`按照维度2进行拆分:** ```python D1, D2 = torch.chunk(B0, 2, dim=1) # 3表示指定拆分数,但由于不足以拆分,所以只会拆分两组 print(D1, D2) ``` 输出结果为: tensor(\[\[\[0, 1, 2\], \[3, 4, 5\]\], \[\[0, 1, 2\], \[3, 4, 5\]\]\]) tensor(\[\[\[6, 7, 8\]\], \[\[6, 7, 8\]\]\])