pytorch小记(二):pytorch中的连接操作:torch.cat(tensors, dim=0)

pytorch小记(二):pytorch矩阵乘法:torch.cat(tensors, dim=0)

      • 语法
      • 使用规则
      • [示例 1:在第 0 维(行)拼接](#示例 1:在第 0 维(行)拼接)
      • [示例 2:在第 1 维(列)拼接](#示例 2:在第 1 维(列)拼接)
      • [示例 3:在高维张量上拼接](#示例 3:在高维张量上拼接)
        • 初始张量
        • [1. 在 `dim=0` 拼接](#1. 在 dim=0 拼接)
        • [2. 在 `dim=1` 拼接](#2. 在 dim=1 拼接)
        • [3. 在 `dim=2` 拼接](#3. 在 dim=2 拼接)
        • 总结
      • [示例 4:拼接不同形状的张量(错误示范)](#示例 4:拼接不同形状的张量(错误示范))
      • 总结

在 PyTorch 中,torch.cat() 是一种用于在指定维度上连接张量的操作。它能够将多个张量沿某个轴拼接成一个新的张量。


语法

python 复制代码
torch.cat(tensors, dim=0)
  • tensors :一个包含多个待拼接张量的列表或元组。这些张量在指定的 dim 维度以外的所有维度上必须具有相同的形状。
  • dim:指定在哪个维度上进行拼接操作。

使用规则

  1. 在指定维度上,张量的形状可以不同(因为会拼接)。
  2. 在其他维度上,张量的形状必须相同。

示例 1:在第 0 维(行)拼接

python 复制代码
x = torch.tensor([[1, 2], 
				  [3, 4]])  # 形状 (2, 2)
y = torch.tensor([[5, 6], 
				  [7, 8]])  # 形状 (2, 2)

result = torch.cat((x, y), dim=0)  # 在第 0 维拼接
print(result)

输出

复制代码
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
  • 原始张量 xy 在第 0 维上(行方向)拼接,因此新张量的形状为 (4, 2)

示例 2:在第 1 维(列)拼接

python 复制代码
x = torch.tensor([[1, 2], 
				  [3, 4]])  # 形状 (2, 2)
y = torch.tensor([[5, 6], 
				  [7, 8]])  # 形状 (2, 2)

result = torch.cat((x, y), dim=1)  # 在第 1 维拼接
print(result)

输出

复制代码
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
  • 原始张量 xy 在第 1 维上(列方向)拼接,因此新张量的形状为 (2, 4)

示例 3:在高维张量上拼接

我们来创建两个高维张量 xy,并分别在不同维度(dim=0, dim=1, dim=2)上使用 torch.cat 进行拼接,展示具体计算结果。


初始张量
python 复制代码
x = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])  # 形状 (2, 2, 3)

y = torch.tensor([
    [[13, 14, 15], [16, 17, 18]],
    [[19, 20, 21], [22, 23, 24]]
])  # 形状 (2, 2, 3)
  • xy 是形状为 (2, 2, 3) 的 3D 张量:

    复制代码
    x:
    [[[ 1,  2,  3], 
      [ 4,  5,  6]],       
                             
     [[ 7,  8,  9],      
      [10, 11, 12]]]       
      
    y:
    [[[13, 14, 15],
      [16, 17, 18]],
    
     [[19, 20, 21],
      [22, 23, 24]]]

1. 在 dim=0 拼接
python 复制代码
result_dim0 = torch.cat((x, y), dim=0)
print(result_dim0.shape)  # torch.Size([4, 2, 3])
print(result_dim0)

拼接逻辑

  • 在第 0 维度(最外层)拼接,结果张量包含 4 个"块",每个"块"的形状仍然是 (2, 3)

结果

复制代码
result_dim0:
[[[  1,   2,   3],
  [  4,   5,   6]],

 [[  7,   8,   9],
  [ 10,  11,  12]],

 [[ 13,  14,  15],
  [ 16,  17,  18]],

 [[ 19,  20,  21],
  [ 22,  23,  24]]]

2. 在 dim=1 拼接
python 复制代码
result_dim1 = torch.cat((x, y), dim=1)
print(result_dim1.shape)  # torch.Size([2, 4, 3])
print(result_dim1)

拼接逻辑

  • 在第 1 维度(每个"块"中的行)拼接,结果张量包含 2 个"块",每个"块"增加了 2 行,形状从 (2, 3) 变为 (4, 3)

结果

复制代码
result_dim1:
[[[  1,   2,   3],
  [  4,   5,   6],
  [ 13,  14,  15],
  [ 16,  17,  18]],

 [[  7,   8,   9],
  [ 10,  11,  12],
  [ 19,  20,  21],
  [ 22,  23,  24]]]

3. 在 dim=2 拼接
python 复制代码
result_dim2 = torch.cat((x, y), dim=2)
print(result_dim2.shape)  # torch.Size([2, 2, 6])
print(result_dim2)

拼接逻辑

  • 在第 2 维度(每行中的列)拼接,结果张量包含 2 个"块",每个"块"有 2 行,但每行的列数增加了一倍,从 3 列变为 6 列。

结果

复制代码
result_dim2:
[[[  1,   2,   3,  13,  14,  15],
  [  4,   5,   6,  16,  17,  18]],

 [[  7,   8,   9,  19,  20,  21],
  [ 10,  11,  12,  22,  23,  24]]]

总结
dim 拼接维度 结果形状 拼接效果
dim=0 最外层 (4, 2, 3) 增加块的数量(纵向堆叠)
dim=1 每块的行数 (2, 4, 3) 增加每块的行数(横向堆叠行)
dim=2 每行的列数 (2, 2, 6) 增加每行的列数(横向堆叠列)

通过改变 dimtorch.cat 可以在不同维度上灵活地拼接张量。


示例 4:拼接不同形状的张量(错误示范)

如果张量在非拼接维度上的形状不同,会抛出错误:

python 复制代码
x = torch.tensor([[1, 2], [3, 4]])  # 形状 (2, 2)
y = torch.tensor([[5, 6, 7]])       # 形状 (1, 3)

result = torch.cat((x, y), dim=0)  # 抛出错误

错误信息

复制代码
RuntimeError: Sizes of tensors must match except in dimension 0. Got 2 and 3 in dimension 1

如果希望在行方向 dim=0 拼接,可以通过 补零裁剪 等方式使列数一致。

补零torch.nn.functional.pad:

python 复制代码
import torch
import torch.nn.functional as F

x = torch.tensor([[1, 2], [3, 4]])  # 形状 (2, 2)
y = torch.tensor([[5, 6, 7]])       # 形状 (1, 3)

# 对 x 补零到列数 3
x_padded = F.pad(x, (0, 1))  # 在列方向右侧补 1 列零
# x_padded 形状: (2, 3)

# 在 dim=0 拼接
result = torch.cat((x_padded, y), dim=0)
print(result)

结果

复制代码
tensor([[1, 2, 0],
        [3, 4, 0],
        [5, 6, 7]])

但是result = torch.cat((x_padded, y), dim=1)则还是错误的!!!


总结

  • torch.cat() 用于连接张量,指定的 dim 决定了在哪个维度上进行拼接。
  • 拼接维度的大小是累加的,其他维度的大小必须一致。
  • 如果不满足上述规则,会抛出错误。

通过这种操作,你可以灵活地调整和组织张量的数据结构。

相关推荐
智启七月13 分钟前
谷歌 Gemini 3.0 正式发布:一键生成 Web OS,编程能力碾压竞品
人工智能·python
Juchecar14 分钟前
物质导光导电的微观原理与半导体
人工智能
2401_8414956414 分钟前
【强化学习】动态规划算法
人工智能·python·算法·动态规划·强化学习·策略迭代·价值迭代
WWZZ202515 分钟前
快速上手大模型:机器学习5(逻辑回归及其代价函数)
人工智能·算法·机器学习·计算机视觉·机器人·slam·具身感知
FreeCode19 分钟前
深度解析Agent Skills:为智能体构建专业特长
人工智能·agent
测试199823 分钟前
自动化测试报告生成(Allure)
自动化测试·软件测试·python·selenium·测试工具·职场和发展·测试用例
_一两风34 分钟前
用 n8n 自动化生成「每日科技热点速递」:从 RSS 到 AI 写作全流程实战(小白必看)
人工智能·rss·deepseek
极昆仑智慧41 分钟前
OpenAI推出了支持人工智能的浏览器ChatGPT Atlas
人工智能·chatgpt
hunteritself43 分钟前
阿里千问上线记忆,Manus 1.5 全栈升级,ChatGPT 将推成人模式!| AI Weekly 10.13-10.19
大数据·人工智能·深度学习·机器学习·chatgpt
姓刘的哦1 小时前
基于线程池的配电房图像检测
人工智能·计算机视觉·目标跟踪