深度学习pytorch——拼接与拆分(持续更新)

cat拼接

使用条件:合并的dim的size可以不同,但是其它的dim的size必须相同。

语法:cat([tensor1,tensor2],dim = n) # 将tensor1和tensor2的第n个维度合并

代码演示:

python 复制代码
# 拼接与拆分
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
print(torch.cat([a,b],dim=0).shape)     # torch.Size([9, 32, 8])

stack拼接

为什么要使用stack?下面会举个例子阐述一下原因:

A [32, 8] # 一个班,一共有32个同学,每个同学有8门成绩

B [32, 8] # 一个班,一共有32个同学,每个同学有8门成绩

cat:[64, 8] # 一个班,一共有64个同学,每个同学有8门成绩,不符合实际

stack: [2, 32, 8] # 2个班,每个班有32个同学,每个同学有8门成绩,符合实际

使用条件:A.shape = B.shape

代码演示:

python 复制代码
a = torch.rand(32,8)
b = torch.rand(32,8)
print(torch.cat([a,b],dim=0).shape)     # torch.Size([64, 8])
print(torch.stack([a,b],dim=0).shape)   # torch.Size([2, 32, 8])

split------根据长度拆分

语法:split(len, dim = n) # 在第n个维度拆分,每个size=len

代码演示:

python 复制代码
# c.shape = torch.Size([2, 32, 8])
aa, bb = c.split(1,dim=0)
print(aa.shape,bb.shape)                # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])

注意:不要超过第0维的总体长度2,等于也不行,别忘了split进行的是拆分。

chunk------根据数量拆分

语法:chunk(num, dim = n) # 在第n维进行拆分,拆分为num份

代码演示:

python 复制代码
# c.shape = torch.Size([2, 32, 8])
aa, bb = c.chunk(2,dim = 0)
print(aa.shape,bb.shape)                # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
相关推荐
哈__3 分钟前
实测VLM:昇腾平台上的视觉语言模型测评与优化实践
人工智能·语言模型·自然语言处理·gitcode·sglang
海森大数据9 分钟前
数据筛选新范式:以质胜量,揭开大模型后训练黑箱
人工智能·语言模型
PNP Robotics10 分钟前
PNP机器人受邀参加英业达具身智能活动
大数据·人工智能·python·学习·机器人
祝余Eleanor15 分钟前
Day 51 神经网络调参指南
深度学习·神经网络·机器学习
智算菩萨17 分钟前
【Python进阶】搭建AI工程:Python模块、包与版本控制
开发语言·人工智能·python
算法熔炉22 分钟前
深度学习面试八股文(4)—— transformer专题
深度学习·面试·transformer
大模型真好玩24 分钟前
LangGraph智能体开发设计模式(一)——提示链模式、路由模式、并行化模式
人工智能·langchain·agent
大学生毕业题目26 分钟前
毕业项目推荐:90-基于yolov8/yolov5/yolo11的工程车辆检测识别系统(Python+卷积神经网络)
人工智能·python·yolo·目标检测·cnn·pyqt·工程车辆检测
是店小二呀27 分钟前
解构 Qwen2 在昇腾 Atlas 800T 上的极限性能:基于 SGLang 的深度评测
人工智能·npu
软件算法开发38 分钟前
基于山羚羊优化的LSTM深度学习网络模型(MGO-LSTM)的一维时间序列预测算法matlab仿真
深度学习·matlab·lstm·一维时间序列预测·山羚羊优化·mgo-lstm