引言:看似双生,实则不同
在 PyTorch 的张量操作中,torch.cat (concatenate) 和 torch.stack 是两个最高频出现的"兄弟"函数。初学者往往觉得它们都能合并张量,但在构建复杂的神经网络(如 SKNet、特征融合模块)时,选错函数往往会导致维度冲突或逻辑错误。
本质上,这两者的区别在于:你是要在"原有的地盘"扩张,还是"开辟新的维度"?
1. 概念解剖:操作的本质
1.1 torch.cat (拼接/级联)
torch.cat 是在现有的维度上将一系列张量连接起来。
- 物理类比: 像是在一列火车的末尾再挂上几节车厢。火车的"轨道"(维度)没变,只是"长度"增加了。
- 要求: 除了拼接的那个维度外,其他所有维度的尺寸必须完全一致。
1.2 torch.stack (堆叠)
torch.stack 会创建一个新的维度,并将张量沿着这个新维度排列。
- 物理类比: 像是把几张平面的照片叠在一起,构成一本相册。相册比照片多了一个"厚度"的维度。
- 要求: 所有参与堆叠的张量,其形状(Shape)必须完全相同。
2. 深度对比:维度的几何演变
假设我们有两个形状为 的张量 和 :
维度计算演示
-
执行
torch.cat([A, B], dim=0: -
在第 0 维拼接:
-
结果仍然是 2 维。
-
执行
torch.stack([A, B], dim=0): -
在最前面新增一维:
-
结果变成了 3 维。
数学等价性
这是理解两者的核心深度所在:stack 实际上是 unsqueeze + cat 的组合技。
通过这个公式可以发现,stack 的本质是先给每个张量"升位",再进行拼接。
3. 性能与内存布局
在底层实现中,张量是存储在连续的内存块中的。
cat的效率: 如果拼接后的张量依然能保持内存连续,其效率极高。stack的考量: 由于引入了新维度,stack往往涉及到更多的数据重排。在处理海量数据(如大型视频序列)时,频繁的stack可能会带来细微的开销。
4. 实战场景:我该选哪一个?
4.1 使用 torch.cat 的场景
- 特征融合 (Feature Fusion): 在 U-Net 等网络中,将 Encoder 的浅层特征和 Decoder 的深层特征在通道维度拼接。
- 数据增强: 将原图和翻转后的图拼接在一起形成一个更大的 Batch。
4.2 使用 torch.stack 的场景
- 多分支决策 (SKAttention/Attention): 如我们在讨论 SKNet 时看到的,将不同卷积核的结果堆叠起来,以便后续通过
softmax算出每个分支的权重。 - 序列处理: 将一帧一帧的 2D 图像叠成 3D 的视频片段。
5. 总结对照表
| 特性 | torch.cat | torch.stack |
|---|---|---|
| 维度数量 | 保持不变 | 增加 1 维 |
| 输入约束 | 仅拼接维可不等,其余必相等 | 所有维度必须完全相等 |
| 底层逻辑 | 在现有轴上延伸 | 先 unsqueeze 再 cat |
| 典型应用 | 通道拼接、特征融合 | 多分支加权、序列建模 |
6. 代码深度实战:从维度看本质
6.1 torch.cat:原位扩张的"胶水"
cat 操作不会增加维度的数量,它只是在已有的某个维度上增加"长度"。
python
import torch
# 构造两个形状为 (2, 3) 的张量 (2行3列)
A = torch.ones((2, 3))
B = torch.zeros((2, 3))
# 情况 1:在第 0 维(行)拼接
res_cat_0 = torch.cat([A, B], dim=0)
print(f"cat (dim=0) 形状: {res_cat_0.shape}")
# 输出: torch.Size([4, 3]) -> 行数增加了
# 情况 2:在第 1 维(列)拼接
res_cat_1 = torch.cat([A, B], dim=1)
print(f"cat (dim=1) 形状: {res_cat_1.shape}")
# 输出: torch.Size([2, 6]) -> 列数增加了
6.2 torch.stack:开辟新维度的"层叠"
stack 之后,你会发现张量的维度从 2 维变成了 3 维。这多出来的一维,就是用来区分 A 和 B 的"层"。
python
# 执行 stack 操作
res_stack = torch.stack([A, B], dim=0)
print(f"stack (dim=0) 形状: {res_stack.shape}")
# 输出: torch.Size([2, 2, 3]) -> 变成 3 维了!
# 查看内容分布
print(f"第 0 层内容 (A):\n{res_stack[0]}")
print(f"第 1 层内容 (B):\n{res_stack[1]}")
6.3 深度证明:stack 的等价逻辑
为了深刻理解 stack 的底层逻辑,我们可以用代码验证那个经典的公式:stack = unsqueeze + cat。
python
# 1. 先给 A 和 B 手动升维,从 (2, 3) 变成 (1, 2, 3)
A_un = A.unsqueeze(0)
B_un = B.unsqueeze(0)
# 2. 在新开辟的第 0 维上进行 cat 拼接
res_manual = torch.cat([A_un, B_un], dim=0)
# 3. 验证是否与直接 stack 的结果一致
print(f"两者是否相等: {torch.equal(res_stack, res_manual)}")
# 输出: True
7. 常见坑点与报错分析
错误 1:非拼接维度的尺寸不一致
- 场景 :尝试
cat两个形状分别为 (2,3)(2, 3)(2,3) 和 (2,4)(2, 4)(2,4) 的张量,但dim=0。 - 后果 :报错
RuntimeError。 - 结论 :
cat要求除了拼接维度外,其他维度必须完全相等。
错误 2:stack 的形状不严谨
- 场景 :尝试
stack两个形状分别为 (2,3)(2, 3)(2,3) 和 (2,4)(2, 4)(2,4) 的张量。 - 后果:直接报错。
- 结论 :
stack要求所有输入张量的形状必须一模一样,因为它要整齐地叠在一起。
8. 总结
在深度学习的代码实现中:
- 如果你想把多个特征图合并成一个更厚的特征图 (通道数增加),请找
cat。 - 如果你想把多个独立的分支结果排好队 ,以便后续做多分支加权(如 SKNet、Video 处理),请找
stack。
理解 stack 与 cat 的区别,是通往深度学习高手之路的必经门槛。cat 是广度 的积累,而 stack 是深度的跨越。在设计你的网络结构时,请根据你是否需要"独立的分支维度"来做出明智的选择。