李沐《动手学深度学习》torch.cat() 和 torch.stack()的区别及思考

一、问题引出

好久没更新啦!最近在学习沐神《动手学深度学习》6.5节池化层的时候,发现沐神在两处相似的地方使用了两种Python拼接函数torch.cat()和torch.stack():


百思不得其解,于是查阅相关文档之后终于弄清楚了两者之间的区别,遂做总结如下。

二、问题解决

1.torch.cat()

torch.cat()函数可以将多个张量拼接成一个张量。torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度

python 复制代码
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
          [40, 50, 60],
          [70, 80, 90]])
print("T1.shape: ", T1.shape, "T2.shape: ", T2.shape)
print(torch.cat((T1,T2),dim=0).shape)
print(torch.cat((T1,T2),dim=1).shape)

输出为:

2.torch.stack()

torch.stack()函数同样有张量列表和维度两个参数。stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

python 复制代码
print("T1.shape: ", T1.shape, "T2.shape: ", T2.shape)
print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)

输出为:

三、总结

总的来说,cat 和 stack的区别在于 cat会增加现有维度的值,可以理解为续接,stack会新加增加一个维度,可以理解为叠加。

使用stack可以保留两个信息:1. 序列2. 张量矩阵 信息,属于【扩张再拼接】的函数。形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面(矩阵)按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。

欢迎大家一起跟着学习沐神的《动手学深度学习》,我建了一个github网站,发布了我的日常学习笔记,欢迎大家star!,网址为https://github.com/BugMaker2002/DeepLearningAction-LiMu

相关推荐
金融小师妹2 分钟前
AI因子共振模型显示:金银比突破区间上沿,白银定价逻辑进入再校准阶段
人工智能·算法·均值算法·线性回归
奶油话梅糖3 分钟前
IMA 知识库体验(内有资源分享):把资料变成可以提问的 AI 知识助手
人工智能·ai·aigc·知识图谱·知识库·学习工具·ima
Orchestrator_me6 分钟前
Python pip install报SSL错误
python·ssl·pip
老金带你玩AI6 分钟前
用ChatGPT管项目,让Codex只做Ticket
人工智能
开源量化GO7 分钟前
期货 K 线算信号 tick 级止损:天勤双序列 wait_update 触发规则
linux·运维·服务器·python
聆春烟雨簌簌15 分钟前
LangChain4j使用文档
开发语言·python
前端不太难15 分钟前
从模型部署到智能运营:企业AI的新挑战
人工智能
ZFSS23 分钟前
VS Code + Luma MCP 使用教程
人工智能·ai·ai作画·copilot·ai编程·ai写作
某林21223 分钟前
ROS2 语音机器人实战:从 KCF 跟随失效到 RTAB-Map 建图闭环的完整排障
人工智能·机器人·语音识别·ros2·架构重构·技术复盘·c++底层排错