Pytorch张量拼接秘籍:cat与stack的深度解析与实战
- 一、核心概念:读懂cat与stack的本质差异
- 二、torch.cat():不改变维度的"平铺式"拼接
-
- [2.1 基础语法](#2.1 基础语法)
- [2.2 实战演示:二维张量的cat拼接](#2.2 实战演示:二维张量的cat拼接)
- [2.3 cat的核心逻辑总结](#2.3 cat的核心逻辑总结)
- 三、torch.stack():新增维度的"堆叠式"拼接
-
- [3.1 基础语法](#3.1 基础语法)
- [3.2 实战演示:二维张量的stack堆叠](#3.2 实战演示:二维张量的stack堆叠)
- [3.3 stack的核心逻辑总结](#3.3 stack的核心逻辑总结)
- 四、cat与stack的拼接逻辑可视化
-
- [4.1 torch.cat(dim=0) 拼接可视化](#4.1 torch.cat(dim=0) 拼接可视化)
- [4.2 torch.stack(dim=0) 堆叠可视化](#4.2 torch.stack(dim=0) 堆叠可视化)
- 五、实战选型:什么时候用cat?什么时候用stack?
-
- [5.1 优先使用torch.cat()的场景](#5.1 优先使用torch.cat()的场景)
- [5.2 优先使用torch.stack()的场景](#5.2 优先使用torch.stack()的场景)
- 六、核心知识点梳理
- 七、写在最后
在Pytorch的张量操作体系中,拼接是数据处理与模型构建里高频出现的核心操作,而torch.cat()与torch.stack()作为实现张量拼接的两大核心函数,常常让初学者陷入混淆。二者虽都服务于张量的组合,但在维度处理、使用要求、应用场景上有着本质区别。今天,我们就透过底层逻辑+实战代码,彻底拆解这两个函数的奥秘,让你轻松掌握张量拼接的正确打开方式✨。
一、核心概念:读懂cat与stack的本质差异
张量拼接的核心矛盾,在于是否改变原张量的维度数 ,这也是cat和stack最根本的区别。为了更直观的对比,我们先通过表格梳理二者的核心特性:
| 函数 | 维度变化 | 形状要求 | 核心逻辑 | 应用场景 |
|---|---|---|---|---|
torch.cat() |
不改变维度数 | 除拼接维度外,其余维度形状必须完全一致 | 沿指定维度"平铺"拼接,仅扩展拼接维度的尺寸 | 同维度数据的合并,如批量整合样本、拼接特征图 |
torch.stack() |
增加新维度 | 所有输入张量的全部维度形状必须完全一致 | 沿新维度"堆叠"张量,生成更高维的新张量 | 构建新的维度维度,如整合多个同形状的特征张量、构建批次维度 |
简单来说,cat是同维度内的拼接 ,只把张量在指定方向拉长;stack是跨维度的堆叠,为张量新增一个维度后再组合,相当于给多个张量套上了一个"新的大括号"。 |
二、torch.cat():不改变维度的"平铺式"拼接
torch.cat()的核心要义是沿指定维度拼接,维度数不变,这就要求待拼接的张量,除了我们指定的拼接维度可以有不同尺寸,其余所有维度的尺寸必须严格一致,否则会直接触发形状不匹配的报错。
2.1 基础语法
python
import torch
# 基础格式
torch.cat(tensors, dim=0, out=None)
-
tensors:待拼接的张量序列(列表/元组形式) -
dim:指定的拼接维度,支持正整数(0,1,2...)和负整数(-1表示最后一个维度) -
out:可选参数,指定输出张量的存储位置
2.2 实战演示:二维张量的cat拼接
我们以2行3列的二维张量为基础,分别演示沿0维度、1维度的拼接效果,同时验证"非拼接维度必须一致"的规则。
步骤1:创建基础张量
python
import torch
# 创建两个2行3列的二维张量,元素范围1~10
T1 = torch.randint(1, 10, (2, 3))
T2 = torch.randint(1, 10, (2, 3))
print("张量T1:\n", T1)
print("张量T2:\n", T2)
print("T1形状:", T1.shape) # 输出:torch.Size([2, 3])
print("T2形状:", T2.shape) # 输出:torch.Size([2, 3])
步骤2:沿0维度拼接(行方向)
0维度是二维张量的行维度,沿0维度拼接,就是将多个张量的行按顺序平铺,最终行维度尺寸相加,列维度尺寸不变。
python
# 沿0维度拼接
T3 = torch.cat((T1, T2), dim=0)
print("沿0维度拼接结果T3:\n", T3)
print("T3形状:", T3.shape) # 输出:torch.Size([4, 3])
效果:2行3列 + 2行3列 → 4行3列,维度数仍为2,仅行维度从2扩展为4。
步骤3:沿1维度拼接(列方向)
1维度是二维张量的列维度,沿1维度拼接,就是将多个张量的列按顺序平铺,最终列维度尺寸相加,行维度尺寸不变。
python
# 沿1维度拼接
T4 = torch.cat((T1, T2), dim=1)
print("沿1维度拼接结果T4:\n", T4)
print("T4形状:", T4.shape) # 输出:torch.Size([2, 6])
效果:2行3列 + 2行3列 → 2行6列,维度数仍为2,仅列维度从3扩展为6。
步骤4:负维度拼接(dim=-1)
dim=-1表示最后一个维度 ,对于二维张量,最后一个维度就是1维度,因此沿dim=-1拼接与dim=1效果完全一致:
python
# 沿-1维度拼接
T5 = torch.cat((T1, T2), dim=-1)
print("沿-1维度拼接结果T5:\n", T5)
print("T5形状:", T5.shape) # 输出:torch.Size([2, 6])
步骤5:规则验证:非拼接维度不一致会报错
若我们将T2改为2行6列,沿1维度拼接时,行维度(非拼接维度)均为2,满足要求;但沿0维度拼接时,列维度(非拼接维度)3≠6,会直接报错:
python
# 重构T2为2行6列
T2_new = torch.randint(1, 10, (2, 6))
# 沿1维度拼接:可行,行维度均为2
T6 = torch.cat((T1, T2_new), dim=1)
print("T1与T2_new沿1维度拼接形状:", T6.shape) # 输出:torch.Size([2, 9])
# 沿0维度拼接:报错,列维度3≠6
try:
T7 = torch.cat((T1, T2_new), dim=0)
except Exception as e:
print("报错信息:", e) # 输出:Size mismatch
步骤6:维度越界报错
cat不会改变原张量的维度数,因此指定的拼接维度不能超过原张量的维度范围 。二维张量的维度只有0和1,若指定dim=2,会直接触发维度越界报错:
python
try:
T8 = torch.cat((T1, T2), dim=2)
except Exception as e:
print("报错信息:", e) # 输出:Dimension out of range
2.3 cat的核心逻辑总结
torch.cat()的拼接逻辑可以用一句话概括:"指定维度自由扩展,其余维度严格对齐"。无论原张量是几维,只要满足"非拼接维度形状一致",就能实现平铺式拼接,且始终保持原有的维度数不变。
三、torch.stack():新增维度的"堆叠式"拼接
torch.stack()是比cat要求更严格的拼接方式,其核心是先新增一个维度,再沿该维度堆叠张量 ,因此要求所有待拼接张量的全部维度形状必须完全一致,哪怕有一个维度尺寸不同,都会触发报错。
stack的拼接过程,就像把多本相同大小的书,放进一个新的书立里------书的大小(张量形状)必须完全一样,而书立就是新增的维度。
3.1 基础语法
python
import torch
# 基础格式
torch.stack(tensors, dim=0, out=None)
参数含义与cat一致,但dim的含义变为新维度的插入位置,而非原张量的拼接维度。
3.2 实战演示:二维张量的stack堆叠
我们仍以2行3列的二维张量T1、T2为基础,分别演示沿0、1、2维度堆叠的效果,理解"新维度插入"的核心逻辑。
步骤1:创建基础张量(设置随机种子保证结果固定)
为了让每次运行的张量值一致,我们设置随机种子,再创建相同形状的张量:
python
import torch
torch.manual_seed(1) # 设置随机种子
T1 = torch.randint(1, 10, (2, 3))
T2 = torch.randint(1, 10, (2, 3))
print("张量T1:\n", T1)
print("张量T2:\n", T2)
print("T1形状:", T1.shape) # 输出:torch.Size([2, 3])
print("T2形状:", T2.shape) # 输出:torch.Size([2, 3])
步骤2:沿0维度堆叠(新维度插入在最外层)
沿0维度堆叠,就是在原张量的最外层插入新维度,将两个2行3列的二维张量,堆叠成一个3维张量,新维度尺寸为2(对应待拼接的张量个数)。
python
# 沿0维度堆叠
T9 = torch.stack((T1, T2), dim=0)
print("沿0维度堆叠结果T9:\n", T9)
print("T9形状:", T9.shape) # 输出:torch.Size([2, 2, 3])
效果:2个2行3列的二维张量 → 形状为[2,2,3]的三维张量,新维度为最外层的0维度,尺寸为2(代表有2个原始张量)。
步骤3:沿1维度堆叠(新维度插入在中间)
沿1维度堆叠,就是在原张量的中间维度插入新维度,最终仍生成[2,2,3]的三维张量,但堆叠逻辑变为"按原张量的行维度对应堆叠"。
python
# 沿1维度堆叠
T10 = torch.stack((T1, T2), dim=1)
print("沿1维度堆叠结果T10:\n", T10)
print("T10形状:", T10.shape) # 输出:torch.Size([2, 2, 3])
效果:原张量的每一行分别对应堆叠,比如T1的第一行与T2的第一行组成新维度的一个元素,最终仍为[2,2,3]的三维张量。
步骤4:沿2维度堆叠(新维度插入在最内层)
沿2维度堆叠,就是在原张量的最内层插入新维度,生成的三维张量形状仍为[2,2,3],堆叠逻辑变为"按原张量的每个元素对应堆叠"。
python
# 沿2维度堆叠
T11 = torch.stack((T1, T2), dim=2)
print("沿2维度堆叠结果T11:\n", T11)
print("T11形状:", T11.shape) # 输出:torch.Size([2, 2, 3])
效果:原张量的每个位置的元素一一对应堆叠,比如T1[0,0]与T2[0,0]组成新维度的一个元素,实现元素级的堆叠。
步骤5:规则验证1:张量形状不一致会报错
stack要求所有维度完全一致,若将T2改为3行3列,哪怕只有一个维度尺寸不同,也会直接报错:
python
# 重构T2为3行3列,与T1形状不一致
T2_error = torch.randint(1, 10, (3, 3))
try:
T12 = torch.stack((T1, T2_error), dim=0)
except Exception as e:
print("报错信息:", e) # 输出:Size mismatch
步骤6:规则验证2:新维度越界会报错
对于二维张量,stack支持的新维度插入位置为0、1、2(对应原维度前、原维度间、原维度后),若指定dim=3,会触发维度越界报错:
python
try:
T13 = torch.stack((T1, T2), dim=3)
except Exception as e:
print("报错信息:", e) # 输出:Dimension out of range
3.3 stack的核心逻辑总结
torch.stack()的核心是**"先插新维,再做堆叠"**,三个关键点需牢记:
-
待拼接张量形状必须完全一致,无任何灵活空间;
-
拼接后维度数会增加1,新维度的尺寸等于待拼接的张量个数;
-
dim参数表示新维度的插入位置,而非原张量的拼接维度。
四、cat与stack的拼接逻辑可视化
为了更直观的理解二者的拼接差异,我们用Mermaid的图形语法,可视化二维张量(2,3)的cat(dim=0)和stack(dim=0)操作过程:
4.1 torch.cat(dim=0) 拼接可视化
行平铺
行平铺
张量T1(2,3)
拼接结果(4,3)
张量T2(2,3)
注:仅扩展行维度,维度数仍为2
说明 :该图展示了cat沿0维度的拼接逻辑,T1和T2的行按顺序直接平铺,最终行维度从2+2=4,列维度保持3不变,整个过程未新增任何维度,张量仍为二维。
4.2 torch.stack(dim=0) 堆叠可视化
套新维度
套新维度
堆叠
张量T1(2,3)
新维度0 尺寸=2
张量T2(2,3)
堆叠结果(2,2,3)
注:插入新维度0,维度数变为3
说明 :该图展示了stack沿0维度的堆叠逻辑,先为T1和T2插入一个新的外层维度(尺寸为2,对应2个张量),再将两个张量放入新维度中完成堆叠,最终张量从二维变为三维,形状为(2,2,3)。
五、实战选型:什么时候用cat?什么时候用stack?
理解了二者的差异,核心问题就变成了场景匹配------根据业务需求选择合适的函数,才能让张量操作更高效、更贴合逻辑。
5.1 优先使用torch.cat()的场景
cat因灵活性更高(仅要求非拼接维度一致),是实际开发中使用频率更高 的拼接方式,适合所有同维度数据合并的需求:
-
批量整合样本:比如有两个批次的图片张量,形状分别为(32, 3, 224, 224)和(16, 3, 224, 224),沿0维度(批次维度)拼接为(48, 3, 224, 224),整合为一个大批次;
-
拼接特征图:模型中不同层的特征图,若除通道维度外其余维度一致,可沿通道维度拼接,扩展特征维度;
-
整合序列数据:自然语言处理中,两个同长度的词向量序列,沿列维度拼接,丰富特征信息。
5.2 优先使用torch.stack()的场景
stack因要求严格,适合需要构建新维度的场景,核心是将多个同形状的张量,组合成一个更高维的张量:
-
构建批次维度:若有10张单独的图片张量,形状均为(3, 224, 224),沿0维度stack后,生成(10, 3, 224, 224)的批次张量,直接输入模型;
-
整合多视角特征:同一样本的多个视角特征,形状均为(128,),stack后生成(8, 128)的张量(8为视角数),构建多视角特征维度;
-
生成序列维度:将多个同形状的时间步特征,stack后新增时间维度,构建时序张量。
六、核心知识点梳理
-
torch.cat():平铺拼接,维度不变,非拼接维度需一致,灵活度高,使用频率高; -
torch.stack():堆叠拼接,新增维度,所有维度需完全一致,要求严格,适合构建新维度; -
dim参数:cat中是原张量的拼接维度 ,stack中是新维度的插入位置; -
负维度
dim=-1:均表示最后一个维度,cat和stack中均适用; -
维度越界:二者指定的dim均不能超过自身支持的维度范围,否则报错。
七、写在最后
cat和stack作为Pytorch张量拼接的双核心,看似简单,却是理解维度操作的关键。很多时候初学者的报错,本质都是对"维度是否变化""形状要求是什么"理解不到位。
记住一个简单的判断法则:如果想把张量"拉长",用cat;如果想给张量"套新维度",用stack。掌握这个核心,再结合实战验证,就能彻底避开二者的使用误区。

在后续的Pytorch学习中,维度操作会贯穿始终,从数据预处理到模型构建,从特征提取到结果整合,都离不开对cat和stack的灵活运用。打好这个基础,后续的高维张量操作会变得事半功倍💪