tensor连接和拆分

文章目录

连接

torch.cat()

函数目的: 在给定维度上对输入的张量序列 进行连接操作。

案例准备
python 复制代码
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
b = torch.tensor([[10,10,10,],[10,10,10],[10,10,10,]], dtype=torch.float)
python 复制代码
# dim指的是维度,dim = 0就是行,所以下面的代码就是按行拼接
print("按行拼接:\n",torch.cat((a,b),dim=0))
print("按行拼接:\n",torch.cat((a,b),dim=0).shape) #6行3列
python 复制代码
print("按列拼接:\n",torch.cat((a,b),dim=1))
print("按列拼接:\n",torch.cat((a,b),dim=1).shape)#3行6列

torch.stack()

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同 形状。

也就是2维拼成3维,3维拼4维,以此类推。

python 复制代码
print("按行拼接:\n",torch.stack((a,b),dim=0))
print("按行拼接:\n",torch.stack((a,b),dim=0).shape) 
python 复制代码
print("按行拼接:\n",torch.stack((a,b),dim=1))
print("按行拼接:\n",torch.stack((a,b),dim=1).shape)
python 复制代码
print("按行拼接:\n",torch.stack((a,b),dim=2))
print("按行拼接:\n",torch.stack((a,b),dim=2).shape)
区别

stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

python 复制代码
c = torch.tensor([[10,20],[30,40],[50,60]], dtype=torch.float)
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
torch.cat((a,c),dim=1)
python 复制代码
#但是以下情况就会出错
torch.cat((a,c),dim=0)

如图,按行拼接会缺数据,报错吗,应该的。

python 复制代码
torch.stack((a,c),dim=0)
###运行结果
RuntimeError: stack expects each tensor to be equal size, but got [3, 3] at entry 0 and [3, 2] at entry 1

再次验证stack需要两个大小一样的张量

拆分

torch.split()

def split(

tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0

) -> Tuple[Tensor, ...]:

  • 按块大小拆分张量 除不尽的取余数,返回一个元组
python 复制代码
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
print(torch.split(a,2,dim=0))	#按行拆,两行拆成一个
print(torch.split(a,1,dim=0))	#按行拆,一行拆成一个
print(torch.split(a,1,dim=1))	#按列拆,一列拆成一个
print(torch.split(a,2,dim=1)) 	#按列拆,两列拆成一个
  • 按块数拆分张量
python 复制代码
torch.chunk(a,2,dim=0)	#按行拆成两块
torch.split(a,2,dim=1)	#按列拆成两块
相关推荐
Dekesas96955 小时前
【深度学习】基于Faster R-CNN的黄瓜幼苗智能识别与定位系统,农业AI新突破
人工智能·深度学习·r语言
哥布林学者7 小时前
吴恩达深度学习课程四:计算机视觉 第二周:经典网络结构 (三)1×1卷积与Inception网络
深度学习·ai
鼾声鼾语7 小时前
matlab的ros2发布的消息,局域网内其他设备收不到情况吗?但是matlab可以订阅其他局域网的ros2发布的消息(问题总结)
开发语言·人工智能·深度学习·算法·matlab·isaaclab
【建模先锋】9 小时前
特征提取+概率神经网络 PNN 的轴承信号故障诊断模型
人工智能·深度学习·神经网络·信号处理·故障诊断·概率神经网络·特征提取
轲轲019 小时前
Week02 深度学习基本原理
人工智能·深度学习
smile_Iris9 小时前
Day 40 复习日
人工智能·深度学习·机器学习
深度学习实战训练营9 小时前
TransUNet:Transformer 成为医学图像分割的强大编码器,Transformer 编码器 + U-Net 解码器-k学长深度学习专栏
人工智能·深度学习·transformer
火山kim10 小时前
经典论文研读报告:DAGGER (Dataset Aggregation)
人工智能·深度学习·机器学习
Coding茶水间10 小时前
基于深度学习的水果检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
studytosky11 小时前
深度学习理论与实战:反向传播、参数初始化与优化算法全解析
人工智能·python·深度学习·算法·分类·matplotlib