李沐《动手学深度学习》torch.cat() 和 torch.stack()的区别及思考

一、问题引出

好久没更新啦!最近在学习沐神《动手学深度学习》6.5节池化层的时候,发现沐神在两处相似的地方使用了两种Python拼接函数torch.cat()和torch.stack():


百思不得其解,于是查阅相关文档之后终于弄清楚了两者之间的区别,遂做总结如下。

二、问题解决

1.torch.cat()

torch.cat()函数可以将多个张量拼接成一个张量。torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度

python 复制代码
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
          [40, 50, 60],
          [70, 80, 90]])
print("T1.shape: ", T1.shape, "T2.shape: ", T2.shape)
print(torch.cat((T1,T2),dim=0).shape)
print(torch.cat((T1,T2),dim=1).shape)

输出为:

2.torch.stack()

torch.stack()函数同样有张量列表和维度两个参数。stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

python 复制代码
print("T1.shape: ", T1.shape, "T2.shape: ", T2.shape)
print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)

输出为:

三、总结

总的来说,cat 和 stack的区别在于 cat会增加现有维度的值,可以理解为续接,stack会新加增加一个维度,可以理解为叠加。

使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面(矩阵)按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。

欢迎大家一起跟着学习沐神的《动手学深度学习》,我建了一个github网站,发布了我的日常学习笔记,欢迎大家star!,网址为https://github.com/BugMaker2002/DeepLearningAction-LiMu

相关推荐
AAD555888992 小时前
数字仪表LCD显示识别与读数:数字0-9、小数点及单位kwh检测识别实战
python
开源技术4 小时前
Python Pillow 优化,打开和保存速度最快提高14倍
开发语言·python·pillow
Niuguangshuo4 小时前
深入解析Stable Diffusion基石——潜在扩散模型(LDMs)
人工智能·计算机视觉·stable diffusion
迈火4 小时前
SD - Latent - Interposer:解锁Stable Diffusion潜在空间的创意工具
人工智能·gpt·计算机视觉·stable diffusion·aigc·语音识别·midjourney
wfeqhfxz25887824 小时前
YOLO13-C3k2-GhostDynamicConv烟雾检测算法实现与优化
人工智能·算法·计算机视觉
芝士爱知识a4 小时前
2026年AI面试软件推荐
人工智能·面试·职场和发展·大模型·ai教育·考公·智蛙面试
Li emily5 小时前
解决港股实时行情数据 API 接入难题
人工智能·python·fastapi
Aaron15885 小时前
基于RFSOC的数字射频存储技术应用分析
c语言·人工智能·驱动开发·算法·fpga开发·硬件工程·信号处理
J_Xiong01175 小时前
【Agents篇】04:Agent 的推理能力——思维链与自我反思
人工智能·ai agent·推理
wfeqhfxz25887825 小时前
农田杂草检测与识别系统基于YOLO11实现六种杂草自动识别_1
python