Pytorch张量拼接秘籍:cat与stack的深度解析与实战

Pytorch张量拼接秘籍:cat与stack的深度解析与实战

在Pytorch的张量操作体系中,拼接是数据处理与模型构建里高频出现的核心操作,而torch.cat()torch.stack()作为实现张量拼接的两大核心函数,常常让初学者陷入混淆。二者虽都服务于张量的组合,但在维度处理、使用要求、应用场景上有着本质区别。今天,我们就透过底层逻辑+实战代码,彻底拆解这两个函数的奥秘,让你轻松掌握张量拼接的正确打开方式✨。

一、核心概念:读懂cat与stack的本质差异

张量拼接的核心矛盾,在于是否改变原张量的维度数 ,这也是catstack最根本的区别。为了更直观的对比,我们先通过表格梳理二者的核心特性:

函数 维度变化 形状要求 核心逻辑 应用场景
torch.cat() 不改变维度数 拼接维度外,其余维度形状必须完全一致 沿指定维度"平铺"拼接,仅扩展拼接维度的尺寸 同维度数据的合并,如批量整合样本、拼接特征图
torch.stack() 增加新维度 所有输入张量的全部维度形状必须完全一致 沿新维度"堆叠"张量,生成更高维的新张量 构建新的维度维度,如整合多个同形状的特征张量、构建批次维度
简单来说,cat同维度内的拼接 ,只把张量在指定方向拉长;stack跨维度的堆叠,为张量新增一个维度后再组合,相当于给多个张量套上了一个"新的大括号"。

二、torch.cat():不改变维度的"平铺式"拼接

torch.cat()的核心要义是沿指定维度拼接,维度数不变,这就要求待拼接的张量,除了我们指定的拼接维度可以有不同尺寸,其余所有维度的尺寸必须严格一致,否则会直接触发形状不匹配的报错。

2.1 基础语法

python 复制代码
import torch
# 基础格式
torch.cat(tensors, dim=0, out=None)
  • tensors:待拼接的张量序列(列表/元组形式)

  • dim:指定的拼接维度,支持正整数(0,1,2...)和负整数(-1表示最后一个维度)

  • out:可选参数,指定输出张量的存储位置

2.2 实战演示:二维张量的cat拼接

我们以2行3列的二维张量为基础,分别演示沿0维度、1维度的拼接效果,同时验证"非拼接维度必须一致"的规则。

步骤1:创建基础张量

python 复制代码
import torch

# 创建两个2行3列的二维张量,元素范围1~10
T1 = torch.randint(1, 10, (2, 3))
T2 = torch.randint(1, 10, (2, 3))
print("张量T1:\n", T1)
print("张量T2:\n", T2)
print("T1形状:", T1.shape)  # 输出:torch.Size([2, 3])
print("T2形状:", T2.shape)  # 输出:torch.Size([2, 3])

步骤2:沿0维度拼接(行方向)

0维度是二维张量的行维度,沿0维度拼接,就是将多个张量的行按顺序平铺,最终行维度尺寸相加,列维度尺寸不变。

python 复制代码
# 沿0维度拼接
T3 = torch.cat((T1, T2), dim=0)
print("沿0维度拼接结果T3:\n", T3)
print("T3形状:", T3.shape)  # 输出:torch.Size([4, 3])

效果:2行3列 + 2行3列 → 4行3列,维度数仍为2,仅行维度从2扩展为4。

步骤3:沿1维度拼接(列方向)

1维度是二维张量的列维度,沿1维度拼接,就是将多个张量的列按顺序平铺,最终列维度尺寸相加,行维度尺寸不变。

python 复制代码
# 沿1维度拼接
T4 = torch.cat((T1, T2), dim=1)
print("沿1维度拼接结果T4:\n", T4)
print("T4形状:", T4.shape)  # 输出:torch.Size([2, 6])

效果:2行3列 + 2行3列 → 2行6列,维度数仍为2,仅列维度从3扩展为6。

步骤4:负维度拼接(dim=-1)

dim=-1表示最后一个维度 ,对于二维张量,最后一个维度就是1维度,因此沿dim=-1拼接与dim=1效果完全一致:

python 复制代码
# 沿-1维度拼接
T5 = torch.cat((T1, T2), dim=-1)
print("沿-1维度拼接结果T5:\n", T5)
print("T5形状:", T5.shape)  # 输出:torch.Size([2, 6])

步骤5:规则验证:非拼接维度不一致会报错

若我们将T2改为2行6列,沿1维度拼接时,行维度(非拼接维度)均为2,满足要求;但沿0维度拼接时,列维度(非拼接维度)3≠6,会直接报错:

python 复制代码
# 重构T2为2行6列
T2_new = torch.randint(1, 10, (2, 6))
# 沿1维度拼接:可行,行维度均为2
T6 = torch.cat((T1, T2_new), dim=1)
print("T1与T2_new沿1维度拼接形状:", T6.shape)  # 输出:torch.Size([2, 9])

# 沿0维度拼接:报错,列维度3≠6
try:
    T7 = torch.cat((T1, T2_new), dim=0)
except Exception as e:
    print("报错信息:", e)  # 输出:Size mismatch

步骤6:维度越界报错

cat不会改变原张量的维度数,因此指定的拼接维度不能超过原张量的维度范围 。二维张量的维度只有0和1,若指定dim=2,会直接触发维度越界报错:

python 复制代码
try:
    T8 = torch.cat((T1, T2), dim=2)
except Exception as e:
    print("报错信息:", e)  # 输出:Dimension out of range

2.3 cat的核心逻辑总结

torch.cat()的拼接逻辑可以用一句话概括:"指定维度自由扩展,其余维度严格对齐"。无论原张量是几维,只要满足"非拼接维度形状一致",就能实现平铺式拼接,且始终保持原有的维度数不变。

三、torch.stack():新增维度的"堆叠式"拼接

torch.stack()是比cat要求更严格的拼接方式,其核心是先新增一个维度,再沿该维度堆叠张量 ,因此要求所有待拼接张量的全部维度形状必须完全一致,哪怕有一个维度尺寸不同,都会触发报错。

stack的拼接过程,就像把多本相同大小的书,放进一个新的书立里------书的大小(张量形状)必须完全一样,而书立就是新增的维度。

3.1 基础语法

python 复制代码
import torch
# 基础格式
torch.stack(tensors, dim=0, out=None)

参数含义与cat一致,但dim的含义变为新维度的插入位置,而非原张量的拼接维度。

3.2 实战演示:二维张量的stack堆叠

我们仍以2行3列的二维张量T1、T2为基础,分别演示沿0、1、2维度堆叠的效果,理解"新维度插入"的核心逻辑。

步骤1:创建基础张量(设置随机种子保证结果固定)

为了让每次运行的张量值一致,我们设置随机种子,再创建相同形状的张量:

python 复制代码
import torch
torch.manual_seed(1)  # 设置随机种子
T1 = torch.randint(1, 10, (2, 3))
T2 = torch.randint(1, 10, (2, 3))
print("张量T1:\n", T1)
print("张量T2:\n", T2)
print("T1形状:", T1.shape)  # 输出:torch.Size([2, 3])
print("T2形状:", T2.shape)  # 输出:torch.Size([2, 3])

步骤2:沿0维度堆叠(新维度插入在最外层)

沿0维度堆叠,就是在原张量的最外层插入新维度,将两个2行3列的二维张量,堆叠成一个3维张量,新维度尺寸为2(对应待拼接的张量个数)。

python 复制代码
# 沿0维度堆叠
T9 = torch.stack((T1, T2), dim=0)
print("沿0维度堆叠结果T9:\n", T9)
print("T9形状:", T9.shape)  # 输出:torch.Size([2, 2, 3])

效果:2个2行3列的二维张量 → 形状为[2,2,3]的三维张量,新维度为最外层的0维度,尺寸为2(代表有2个原始张量)。

步骤3:沿1维度堆叠(新维度插入在中间)

沿1维度堆叠,就是在原张量的中间维度插入新维度,最终仍生成[2,2,3]的三维张量,但堆叠逻辑变为"按原张量的行维度对应堆叠"。

python 复制代码
# 沿1维度堆叠
T10 = torch.stack((T1, T2), dim=1)
print("沿1维度堆叠结果T10:\n", T10)
print("T10形状:", T10.shape)  # 输出:torch.Size([2, 2, 3])

效果:原张量的每一行分别对应堆叠,比如T1的第一行与T2的第一行组成新维度的一个元素,最终仍为[2,2,3]的三维张量。

步骤4:沿2维度堆叠(新维度插入在最内层)

沿2维度堆叠,就是在原张量的最内层插入新维度,生成的三维张量形状仍为[2,2,3],堆叠逻辑变为"按原张量的每个元素对应堆叠"。

python 复制代码
# 沿2维度堆叠
T11 = torch.stack((T1, T2), dim=2)
print("沿2维度堆叠结果T11:\n", T11)
print("T11形状:", T11.shape)  # 输出:torch.Size([2, 2, 3])

效果:原张量的每个位置的元素一一对应堆叠,比如T1[0,0]与T2[0,0]组成新维度的一个元素,实现元素级的堆叠。

步骤5:规则验证1:张量形状不一致会报错

stack要求所有维度完全一致,若将T2改为3行3列,哪怕只有一个维度尺寸不同,也会直接报错:

python 复制代码
# 重构T2为3行3列,与T1形状不一致
T2_error = torch.randint(1, 10, (3, 3))
try:
    T12 = torch.stack((T1, T2_error), dim=0)
except Exception as e:
    print("报错信息:", e)  # 输出:Size mismatch

步骤6:规则验证2:新维度越界会报错

对于二维张量,stack支持的新维度插入位置为0、1、2(对应原维度前、原维度间、原维度后),若指定dim=3,会触发维度越界报错:

python 复制代码
try:
    T13 = torch.stack((T1, T2), dim=3)
except Exception as e:
    print("报错信息:", e)  # 输出:Dimension out of range

3.3 stack的核心逻辑总结

torch.stack()的核心是**"先插新维,再做堆叠"**,三个关键点需牢记:

  1. 待拼接张量形状必须完全一致,无任何灵活空间;

  2. 拼接后维度数会增加1,新维度的尺寸等于待拼接的张量个数;

  3. dim参数表示新维度的插入位置,而非原张量的拼接维度。

四、cat与stack的拼接逻辑可视化

为了更直观的理解二者的拼接差异,我们用Mermaid的图形语法,可视化二维张量(2,3)的cat(dim=0)和stack(dim=0)操作过程:

4.1 torch.cat(dim=0) 拼接可视化

行平铺
行平铺
张量T1(2,3)
拼接结果(4,3)
张量T2(2,3)
注:仅扩展行维度,维度数仍为2

说明 :该图展示了cat沿0维度的拼接逻辑,T1和T2的行按顺序直接平铺,最终行维度从2+2=4,列维度保持3不变,整个过程未新增任何维度,张量仍为二维。

4.2 torch.stack(dim=0) 堆叠可视化

套新维度
套新维度
堆叠
张量T1(2,3)
新维度0 尺寸=2
张量T2(2,3)
堆叠结果(2,2,3)
注:插入新维度0,维度数变为3

说明 :该图展示了stack沿0维度的堆叠逻辑,先为T1和T2插入一个新的外层维度(尺寸为2,对应2个张量),再将两个张量放入新维度中完成堆叠,最终张量从二维变为三维,形状为(2,2,3)。

五、实战选型:什么时候用cat?什么时候用stack?

理解了二者的差异,核心问题就变成了场景匹配------根据业务需求选择合适的函数,才能让张量操作更高效、更贴合逻辑。

5.1 优先使用torch.cat()的场景

cat因灵活性更高(仅要求非拼接维度一致),是实际开发中使用频率更高 的拼接方式,适合所有同维度数据合并的需求:

  1. 批量整合样本:比如有两个批次的图片张量,形状分别为(32, 3, 224, 224)和(16, 3, 224, 224),沿0维度(批次维度)拼接为(48, 3, 224, 224),整合为一个大批次;

  2. 拼接特征图:模型中不同层的特征图,若除通道维度外其余维度一致,可沿通道维度拼接,扩展特征维度;

  3. 整合序列数据:自然语言处理中,两个同长度的词向量序列,沿列维度拼接,丰富特征信息。

5.2 优先使用torch.stack()的场景

stack因要求严格,适合需要构建新维度的场景,核心是将多个同形状的张量,组合成一个更高维的张量:

  1. 构建批次维度:若有10张单独的图片张量,形状均为(3, 224, 224),沿0维度stack后,生成(10, 3, 224, 224)的批次张量,直接输入模型;

  2. 整合多视角特征:同一样本的多个视角特征,形状均为(128,),stack后生成(8, 128)的张量(8为视角数),构建多视角特征维度;

  3. 生成序列维度:将多个同形状的时间步特征,stack后新增时间维度,构建时序张量。

六、核心知识点梳理

  1. torch.cat()平铺拼接,维度不变,非拼接维度需一致,灵活度高,使用频率高;

  2. torch.stack()堆叠拼接,新增维度,所有维度需完全一致,要求严格,适合构建新维度;

  3. dim参数:cat中是原张量的拼接维度 ,stack中是新维度的插入位置

  4. 负维度dim=-1:均表示最后一个维度,cat和stack中均适用;

  5. 维度越界:二者指定的dim均不能超过自身支持的维度范围,否则报错。

七、写在最后

catstack作为Pytorch张量拼接的双核心,看似简单,却是理解维度操作的关键。很多时候初学者的报错,本质都是对"维度是否变化""形状要求是什么"理解不到位。

记住一个简单的判断法则:如果想把张量"拉长",用cat;如果想给张量"套新维度",用stack。掌握这个核心,再结合实战验证,就能彻底避开二者的使用误区。

在后续的Pytorch学习中,维度操作会贯穿始终,从数据预处理到模型构建,从特征提取到结果整合,都离不开对catstack的灵活运用。打好这个基础,后续的高维张量操作会变得事半功倍💪

相关推荐
java1234_小锋1 分钟前
Spring AI 2.0 开发Java Agent智能体 - Spring AI 2.0简介
java·人工智能·spring·spring ai
Jun6262 分钟前
【树莓派】opencv水滴接触角测量
人工智能·opencv·计算机视觉
2401_882273725 分钟前
golang如何处理zip压缩包_golang zip压缩包处理思路
jvm·数据库·python
tankeven5 分钟前
贪心算法(Greedy Algorithm)详解:从理论到C++实践
c++·算法
Hesionberger6 分钟前
LeetCode72.编辑距离(多维动态规划)
java·开发语言·c++·python·算法
zhangfeng11337 分钟前
No space left on device (28) llamafactory微调训练的时候 报错,需要调节 dataloader_num_workers
人工智能·语言模型·llama
lwf0061648 分钟前
逻辑回归学习笔记-梯度下降求解回归方程
算法·机器学习·逻辑回归
流年似水~8 分钟前
iOS 开发进阶之路:从能跑到能维护
人工智能·程序人生·ios·语言模型
QuestLab10 分钟前
【第23期】2026年4月26日 AI日报
人工智能
tjc1990100511 分钟前
Golang怎么实现分布式定时任务_Golang如何保证集群中定时任务不重复执行【进阶】
jvm·数据库·python