torch.stack 张量维度的变化

torch.stack 是 PyTorch 中用于将一系列张量沿一个新的维度堆叠的函数。与 torch.cat 不同的是,torch.stack会在指定的维度上增加一个新的维度,而不是将张量直接拼接。

基本用法

语法:

复制代码
torch.stack(tensors, dim=0)
  • tensors: 一个张量列表,包含多个形状相同的张量(shape 必须相同)。
  • dim: 新增维度的位置,默认是 0

举例说明

假设有三个形状为 (2, 3) 的张量:

复制代码
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
c = torch.tensor([[13, 14, 15], [16, 17, 18]])

沿 dim=0 堆叠

复制代码
stacked = torch.stack([a, b, c], dim=0)
print(stacked.shape)  # torch.Size([3, 2, 3])
  • 在维度 0 上增加一个新的维度,原始的 (2, 3) 形状变成 (3, 2, 3)
  • stacked 的第 0 维度有 3 个元素,对应原来的 a, b, c 张量。

沿 dim=1 堆叠

复制代码
stacked = torch.stack([a, b, c], dim=1)
print(stacked.shape)  # torch.Size([2, 3, 3])
  • 新的维度插入到原第 1 维的位置。
  • stacked 的第 1 维度有 3 个元素,对应原来的 a, b, c 张量。

沿 dim=2 堆叠

复制代码
stacked = torch.stack([a, b, c], dim=2)
print(stacked.shape)  # torch.Size([2, 3, 3])
  • 新的维度插入到原第 2 维的位置,形状变为 (2, 3, 3)

torch.stack 的形状变化总结

假设堆叠前的每个张量形状是 (A, B, C),在 dim=0dim=1dim=2 堆叠后的形状分别为:

  • dim=0: (N, A, B, C)
  • dim=1: (A, N, B, C)
  • dim=2: (A, B, N, C)

其中 N 是堆叠的张量数量。

和torch.cat函数的区别:

cat:在指定维度拼接多个张量。不增加维度。

复制代码
c1 = torch.tensor([[1, 2], [3, 4]])
c2 = torch.tensor([[5, 6], [7, 8]])
c_cat = torch.cat([c1, c2], dim=0)  # shape (4, 2)
相关推荐
AKAMAI26 分钟前
运维逆袭志·第1期 | 数据黑洞吞噬一切 :自建系统的美丽陷阱
运维·人工智能·云计算
飞哥数智坊1 小时前
AI编程实战:AI要独立开发了?TRAE SOLO 后端生成能力深度实测
人工智能·trae
SamtecChina20231 小时前
应用科普 | 漫谈6G通信的未来
大数据·网络·人工智能·科技
Java与Android技术栈2 小时前
LLM + 图像处理的第一步:用自然语言驱动调色逻辑
图像处理·人工智能
F_D_Z2 小时前
计算机视觉的四项基本任务辨析
人工智能·计算机视觉
LetsonH2 小时前
⭐CVPR2025 MatAnyone:稳定且精细的视频抠图新框架
人工智能·python·深度学习·计算机视觉·音视频
格林威2 小时前
Baumer相机如何通过YoloV8深度学习模型实现工厂自动化产线牛奶瓶盖实时装配的检测识别(C#代码UI界面版)
人工智能·深度学习·数码相机·yolo·机器学习·计算机视觉·c#
Xyz_Overlord2 小时前
NLP——BERT模型全面解析:从基础架构到优化演进
人工智能·自然语言处理·bert·transformer·迁移学习
星期天要睡觉2 小时前
机器学习——K 折交叉验证(K-Fold Cross Validation),案例:逻辑回归 交叉寻找最佳惩罚因子C
人工智能·机器学习
Sunhen_Qiletian2 小时前
机器学习实战:逻辑回归核心技术全面解析与银行风控深度应用(一)
人工智能·机器学习·逻辑回归