pytorch中torch.stack()用法虽简单,但不好理解

函数功能

沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量...以此类推,也就是在增加新的维度上面进行堆叠。

参数列表

tensors :为一系列输入张量,类型为turple和List

dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接

返回值:输出新增维度后的张量

情况一:输入数据为1维数据

dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)

python 复制代码
import torch

a = torch.tensor([1, 2, 3])

b = torch.tensor([11, 22, 33])

#在第0维进行连接,相当于在行上进行组合,取a的一行,b的一行,构成一个新的tensor(输入张量为一维,输出张量为两维)

c = torch.stack([a, b],dim=0)          

print(a)

print(b)

print(c.size())

print(c)

输出:
tensor([1, 2, 3])
tensor([11, 22, 33])
torch.Size([2, 3])
tensor([[ 1,  2,  3],
        [11, 22, 33]])

dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)

python 复制代码
import torch

a = torch.tensor([1, 2, 3])

b = torch.tensor([11, 22, 33])

print(a)

print(b)

#在第1维进行连接,相当于在对应行上面对列元素进行组合,取a的一列,b的一列,构成新的tensor的一行(输入张量为一维,输出张量为两维)

c = torch.stack([a, b],dim=1)

print(c.size())

print(c)

输出:
tensor([1, 2, 3])
tensor([11, 22, 33])
torch.Size([3, 2])
tensor([[ 1, 11],
        [ 2, 22],
        [ 3, 33]])

情况二:输入数据为2维数据

dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。

python 复制代码
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])

print(a)

print(b)
#在第0维进行连接,相当于在通道维度上进行组合
#即取a的所有数据,作为新tensor的一个分量
#取b的所有数据,作为新tensor的另一个分量
#(输入张量为两维,输出张量为三维)

c = torch.stack([a, b],dim=0)

print(c.size())

print(c)

输出:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
torch.Size([2, 3, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])

dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

python 复制代码
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])

print(a)

print(b)

#在第1维(行)进行连接,相当于对相应通道中每个行进行组合
#取a的一行,b的一行,作为新tensor的第1行和第2行
#原来a:3*3,b:3*3,新tensor:3*2*3

c = torch.stack([a, b], 1)

print(c.size())

print(c)

输出:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [11, 22, 33]],

        [[ 4,  5,  6],
         [44, 55, 66]],

        [[ 7,  8,  9],
         [77, 88, 99]]])

dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

python 复制代码
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])

print(a)

print(b)

#在第2维进行连接,相当于对相应行中每个列元素进行组合
#针对每行,取a、b的第一列数据,构成tensor的第一行
#针对每行,取a、b的第二列数据,构成tensor的第二行
#,针对每行取a、b的第三列数据,构成tensor的第三行
#原来a:3*3,b:3*3,新tensor:3*3*2
c = torch.stack([a, b], 2)

print(c.size())

print(c)

输出:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
torch.Size([3, 3, 2])
tensor([[[ 1, 11],
         [ 2, 22],
         [ 3, 33]],

        [[ 4, 44],
         [ 5, 55],
         [ 6, 66]],

        [[ 7, 77],
         [ 8, 88],
         [ 9, 99]]])

情况三:输入数据为3维数据

dim=0:表示在第0维进行连接,相当于在通道维进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

python 复制代码
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])

b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])

print(a)

print(b)

#表示在第0维进行连接,取整个a作为新tensor的一个分量,取整个b作为新tensor的一个分量
c = torch.stack([a, b], 0)

print(c)

输出:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
torch.Size([2, 3, 3])
tensor([[[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[110, 220, 330],
         [440, 550, 660],
         [770, 880, 990]]])
torch.Size([2, 3, 3])
torch.Size([2, 2, 3, 3])
tensor([[[[  1,   2,   3],
          [  4,   5,   6],
          [  7,   8,   9]],

         [[ 10,  20,  30],
          [ 40,  50,  60],
          [ 70,  80,  90]]],


        [[[ 11,  22,  33],
          [ 44,  55,  66],
          [ 77,  88,  99]],

         [[110, 220, 330],
          [440, 550, 660],
          [770, 880, 990]]]])

dim=1:表示在第1维进行连接,取各自的第1维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

python 复制代码
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])

b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])

print(a)
print(a.size())

print(b)
print(b.size())

#表示在第1维进行连接,取a的第一维数据[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
#取b的第一维数据[[11, 22, 33], [44, 55, 66], [77, 88, 99]]作为新tensor的一个分量

#取a的第一维数据[[10, 20, 30], [40, 50, 60], [70, 80, 90]]
#取b的第一维数据[[110, 220, 330], [440, 550, 660], [770, 880, 990]]作为新tensor的另一个分量
c = torch.stack([a, b], 1)

print(c.size())

print(c)

输出:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
torch.Size([2, 3, 3])
tensor([[[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[110, 220, 330],
         [440, 550, 660],
         [770, 880, 990]]])
torch.Size([2, 3, 3])
torch.Size([2, 2, 3, 3])
tensor([[[[  1,   2,   3],
          [  4,   5,   6],
          [  7,   8,   9]],

         [[ 11,  22,  33],
          [ 44,  55,  66],
          [ 77,  88,  99]]],


        [[[ 10,  20,  30],
          [ 40,  50,  60],
          [ 70,  80,  90]],

         [[110, 220, 330],
          [440, 550, 660],
          [770, 880, 990]]]])

dim=2:表示在第2维进行连接,取各自的第2维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

python 复制代码
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])

b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])

print(a)
print(a.size())

print(b)
print(b.size())

#表示在第1维进行连接,取a的第2维数据[1, 2, 3]
#取b的第2维数据[11, 22, 33]作为新tensor的一个分量

#取a的第2维数据[4, 5, 6]
#取b的第2维数据[44, 55, 66]作为新tensor的一个分量

#取a的第2维数据[4, 5, 6]
#取b的第2维数据[44, 55, 66]作为新tensor的一个分量

#取a的第2维数据[7, 8, 9]
#取b的第2维数据[77, 88, 99]作为新tensor的一个分量

#取a的第2维数据[10, 20, 30]
#取b的第2维数据[110, 220, 330]作为新tensor的一个分量

#取a的第2维数据[40, 50, 60]
#取b的第2维数据[440, 550, 660]作为新tensor的一个分量

#取a的第2维数据[70, 80, 90]
#取b的第2维数据[770, 880, 990]作为新tensor的一个分量
c = torch.stack([a, b], 2)
print(c.size())

print(c)

输出:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
torch.Size([2, 3, 3])
tensor([[[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[110, 220, 330],
         [440, 550, 660],
         [770, 880, 990]]])
torch.Size([2, 3, 3])
torch.Size([2, 3, 2, 3])
tensor([[[[  1,   2,   3],
          [ 11,  22,  33]],

         [[  4,   5,   6],
          [ 44,  55,  66]],

         [[  7,   8,   9],
          [ 77,  88,  99]]],


        [[[ 10,  20,  30],
          [110, 220, 330]],

         [[ 40,  50,  60],
          [440, 550, 660]],

         [[ 70,  80,  90],
          [770, 880, 990]]]])

dim=3:表示在第3维进行连接,取各自的第3维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

python 复制代码
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])

b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])

print(a)
print(a.size())

print(b)
print(b.size())

#针对第二维数据,在每个第二维度相同的情况下,取各自的列数据,构成新tensor的一行
c = torch.stack([a, b], 3)
print(c.size())

print(c)

输出:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
torch.Size([2, 3, 3])
tensor([[[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[110, 220, 330],
         [440, 550, 660],
         [770, 880, 990]]])
torch.Size([2, 3, 3])
torch.Size([2, 3, 3, 2])
tensor([[[[  1,  11],
          [  2,  22],
          [  3,  33]],

         [[  4,  44],
          [  5,  55],
          [  6,  66]],

         [[  7,  77],
          [  8,  88],
          [  9,  99]]],


        [[[ 10, 110],
          [ 20, 220],
          [ 30, 330]],

         [[ 40, 440],
          [ 50, 550],
          [ 60, 660]],

         [[ 70, 770],
          [ 80, 880],
          [ 90, 990]]]])

总结:m个序列数据,在某个维度k进行拼接,该维度大小为n,则拼接后形成了*n*m*大小,具体拼接过程是取m个序列数据,k-1维(设k-1维大小为x,从x=1开始取)相同情况下的第1个数据,构成新tensor的一个行;第二个数据...,第三个数据...构成tensor的新行;然后从x=2开始执行同样的操作

相关推荐
沈浩(种子思维作者)3 分钟前
真的能精准医疗吗?癌症能提前发现吗?
人工智能·python·网络安全·健康医疗·量子计算
minhuan5 分钟前
大模型应用:大模型越大越好?模型参数量与效果的边际效益分析.51
人工智能·大模型参数评估·边际效益分析·大模型参数选择
Cherry的跨界思维10 分钟前
28、AI测试环境搭建与全栈工具实战:从本地到云平台的完整指南
java·人工智能·vue3·ai测试·ai全栈·测试全栈·ai测试全栈
MM_MS13 分钟前
Halcon变量控制类型、数据类型转换、字符串格式化、元组操作
开发语言·人工智能·深度学习·算法·目标检测·计算机视觉·视觉检测
ASF1231415sd26 分钟前
【基于YOLOv10n-CSP-PTB的大豆花朵检测与识别系统详解】
人工智能·yolo·目标跟踪
水如烟1 小时前
孤能子视角:“意识“的阶段性回顾,“感质“假说
人工智能
Carl_奕然1 小时前
【数据挖掘】数据挖掘必会技能之:A/B测试
人工智能·python·数据挖掘·数据分析
旅途中的宽~1 小时前
《European Radiology》:2024血管瘤分割—基于MRI T1序列的分割算法
人工智能·计算机视觉·mri·sci一区top·血管瘤·t1
岁月宁静1 小时前
当 AI 越来越“聪明”,人类真正的护城河是什么:智商、意识与认知主权
人工智能
CareyWYR1 小时前
每周AI论文速递(260105-260109)
人工智能