深度学习02-pytorch-02-张量的拼接操作

在 PyTorch 中,张量的拼接操作可以通过以下几种主要方法来实现,最常用的包括 torch.cat(), torch.stack(), 以及 torch.chunk()。这些操作可以将多个张量沿某个维度拼接在一起或拆分张量。下面将详细介绍如何使用这些操作。

1. torch.cat()

torch.cat() 是最常用的拼接函数,它沿着指定维度将张量拼接在一起。需要确保拼接时除拼接维度外的其他维度大小相同。

python 复制代码
import torch
​
# 定义两个形状为 (2, 3) 的张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
​
# 沿着维度 0 进行拼接 (行方向)
concat_tensor = torch.cat((tensor1, tensor2), dim=0)
print(concat_tensor.size())  # 输出: torch.Size([4, 3])
​
# 沿着维度 1 进行拼接 (列方向)
concat_tensor = torch.cat((tensor1, tensor2), dim=1)
print(concat_tensor.size())  # 输出: torch.Size([2, 6])

注意:在使用 torch.cat() 时,拼接的张量在除拼接维度外的其他维度必须相同。

2. torch.stack()

torch.stack() 会在新的维度上将多个张量堆叠在一起,返回的张量维度会比输入张量多一维。

python 复制代码
# 定义两个形状为 (2, 3) 的张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
​
# 在维度 0 进行堆叠
stacked_tensor = torch.stack((tensor1, tensor2), dim=0)
print(stacked_tensor.size())  # 输出: torch.Size([2, 2, 3])
​
# 在维度 1 进行堆叠
stacked_tensor = torch.stack((tensor1, tensor2), dim=1)
print(stacked_tensor.size())  # 输出: torch.Size([2, 2, 3])

cat() 不同,stack() 会增加一个新的维度。

3. torch.chunk()

torch.chunk() 将张量按指定的数量切分成多个张量。

python 复制代码
tensor = torch.randn(4, 6)  # 形状为 (4, 6)
​
# 将张量沿着第 1 维切分成 2 份
chunks = torch.chunk(tensor, 2, dim=1)
for chunk in chunks:
   print(chunk.size())  # 输出两个 (4, 3) 的张量

4. torch.split()

torch.split() 按照指定的大小切分张量。

python 复制代码
tensor = torch.randn(4, 6)
​
# 将张量沿着第 1 维,按照每块大小为 2 切分
splits = torch.split(tensor, 2, dim=1)
for split in splits:
   print(split.size())  # 输出三个 (4, 2) 的张量

5. torch.hstack()torch.vstack()

这些函数分别是 torch.cat() 在水平方向(列方向)和垂直方向(行方向)的简便形式。

python 复制代码
# 水平堆叠
hstacked_tensor = torch.hstack((tensor1, tensor2))
print(hstacked_tensor.size())  # 输出: torch.Size([2, 6])
​
# 垂直堆叠
vstacked_tensor = torch.vstack((tensor1, tensor2))
print(vstacked_tensor.size())  # 输出: torch.Size([4, 3])

总结:

  • torch.cat():沿某一维度拼接张量。

  • torch.stack():在新维度上堆叠张量。

  • torch.chunk():按指定份数拆分张量。

  • torch.split():按指定大小拆分张量。

  • torch.hstack() / torch.vstack():简便的横向和纵向拼接方法。

根据不同的场景选择合适的拼接或拆分方式。

torch.cat(), torch.stack(), torch.hstack(), torch.vstack() 等函数都可以用于拼接多个张量。它们的功能稍有不同,主要是在拼接的方式和增加新维度上有所差别:

  1. torch.cat(): 沿着现有维度拼接多个张量,不会增加新的维度。

    • 适合需要在某个维度上连接多个张量的情况。
  2. torch.stack(): 在指定的维度上增加一个新维度,然后堆叠张量。

    • 适合需要在新维度上堆叠多个张量的情况。
  3. torch.hstack() : 沿水平方向(列方向)拼接多个张量,本质上是 torch.cat()dim=1 的简写。

    • 适合将张量沿列方向拼接。
  4. torch.vstack() : 沿垂直方向(行方向)拼接多个张量,本质上是 torch.cat()dim=0 的简写。

    • 适合将张量沿行方向拼接。

这些函数都可以用于拼接多个张量,不过需要注意的是,拼接时要确保除拼接维度以外,其他维度大小相同。例如,使用 torch.cat() 沿维度 0 拼接时,所有张量的列数需要相同。

相关推荐
井底哇哇22 分钟前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证26 分钟前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩1 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控1 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
ℳ₯㎕ddzོꦿ࿐2 小时前
解决Python 在 Flask 开发模式下定时任务启动两次的问题
开发语言·python·flask
CodeClimb2 小时前
【华为OD-E卷 - 第k个排列 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od
一水鉴天2 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
Channing Lewis2 小时前
什么是 Flask 的蓝图(Blueprint)
后端·python·flask
倔强的石头1062 小时前
解锁辅助驾驶新境界:基于昇腾 AI 异构计算架构 CANN 的应用探秘
人工智能·架构
B站计算机毕业设计超人2 小时前
计算机毕业设计hadoop+spark股票基金推荐系统 股票基金预测系统 股票基金可视化系统 股票基金数据分析 股票基金大数据 股票基金爬虫
大数据·hadoop·python·spark·课程设计·数据可视化·推荐算法