torch.stack 是 PyTorch 中用于将一系列张量沿一个新的维度堆叠的函数。与 torch.cat 不同的是,torch.stack会在指定的维度上增加一个新的维度,而不是将张量直接拼接。
基本用法
语法:
torch.stack(tensors, dim=0)
tensors: 一个张量列表,包含多个形状相同的张量(shape 必须相同)。dim: 新增维度的位置,默认是0。
举例说明
假设有三个形状为 (2, 3) 的张量:
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
c = torch.tensor([[13, 14, 15], [16, 17, 18]])
沿 dim=0 堆叠
stacked = torch.stack([a, b, c], dim=0)
print(stacked.shape) # torch.Size([3, 2, 3])
- 在维度
0上增加一个新的维度,原始的(2, 3)形状变成(3, 2, 3)。 stacked的第0维度有3个元素,对应原来的a,b,c张量。
沿 dim=1 堆叠
stacked = torch.stack([a, b, c], dim=1)
print(stacked.shape) # torch.Size([2, 3, 3])
- 新的维度插入到原第
1维的位置。 stacked的第1维度有3个元素,对应原来的a,b,c张量。
沿 dim=2 堆叠
stacked = torch.stack([a, b, c], dim=2)
print(stacked.shape) # torch.Size([2, 3, 3])
- 新的维度插入到原第
2维的位置,形状变为(2, 3, 3)。
torch.stack 的形状变化总结
假设堆叠前的每个张量形状是 (A, B, C),在 dim=0、dim=1 和 dim=2 堆叠后的形状分别为:
dim=0:(N, A, B, C)dim=1:(A, N, B, C)dim=2:(A, B, N, C)
其中 N 是堆叠的张量数量。
和torch.cat函数的区别:
cat:在指定维度拼接多个张量。不增加维度。
c1 = torch.tensor([[1, 2], [3, 4]])
c2 = torch.tensor([[5, 6], [7, 8]])
c_cat = torch.cat([c1, c2], dim=0) # shape (4, 2)