torch.cat和torch.stack的区别

torch.cattorch.stack 是 PyTorch 中用于组合张量的两个常用函数,它们的核心区别在于输入张量的维度和输出张量的维度变化。以下是详细对比:

1. torch.cat (Concatenate)

  • 作用 :沿现有维度 拼接多个张量,不创建新维度

  • 输入要求 :所有张量的形状必须除拼接维度外完全相同

  • 语法

    python 复制代码
    torch.cat(tensors, dim=0)  # dim 指定拼接的维度
  • 示例

    python 复制代码
    a = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
    b = torch.tensor([[5, 6]])           # shape (1, 2)
    
    # 沿 dim=0 拼接(行方向)
    c = torch.cat([a, b], dim=0)
    print(c)
    # tensor([[1, 2],
    #         [3, 4],
    #         [5, 6]])  # shape (3, 2)
  • 特点

    • 拼接后的张量在指定维度上的大小是输入张量该维度大小的总和。

    • 其他维度必须完全一致。

2. torch.stack

  • 作用 :沿新维度 堆叠多个张量,创建新维度

  • 输入要求 :所有张量的形状必须完全相同

  • 语法

    python 复制代码
    torch.stack(tensors, dim=0)  # dim 指定新维度的位置
  • 示例

    python 复制代码
    a = torch.tensor([1, 2])  # shape (2,)
    b = torch.tensor([3, 4])  # shape (2,)
    
    # 沿新维度 dim=0 堆叠
    c = torch.stack([a, b], dim=0)
    print(c)
    # tensor([[1, 2],
    #         [3, 4]])  # shape (2, 2)
    
    # 沿新维度 dim=1 堆叠
    d = torch.stack([a, b], dim=1)
    print(d)
    # tensor([[1, 3],
    #         [2, 4]])  # shape (2, 2)
  • 特点

    • 输出张量比输入张量多一个维度

    • 适用于将多个相同形状的张量合并为批次(如 batch_size 维度)。

3. 关键区别总结

4. 直观对比示例

假设有两个张量:

python 复制代码
x = torch.tensor([1, 2])  # shape (2,)
y = torch.tensor([3, 4])  # shape (2,)

torch.cat 结果

python 复制代码
torch.cat([x, y], dim=0)  # tensor([1, 2, 3, 4]), shape (4,)

torch.stack 结果

python 复制代码
torch.stack([x, y], dim=0)  # tensor([[1, 2], [3, 4]]), shape (2, 2)

5. 如何选择?

  • torch.cat 当需要扩展现有维度(如拼接多个特征图)。

  • torch.stack 当需要创建新维度(如构建批次数据或堆叠不同模型的输出)

通过理解两者的维度变化逻辑,可以避免常见的形状错误(如 size mismatch)。

相关推荐
csdn_aspnet6 小时前
如何用 C# 和 Gemma 3 在本地构建一个真正能完成工作的 AI 代理的
人工智能·ai·c#·gemma
啊哈哈哈哈哈啊哈哈6 小时前
边缘计算与轮廓检测
人工智能·opencv·计算机视觉
cskywit6 小时前
从DFL到无NMS推理:一文拆解YOLO26背后的工程取舍与数学原理
人工智能·机器学习
PPHT-H6 小时前
【人工智能笔记】第四十四节:OpenClaw封神工具 openclaw-free-openai-proxy 免费AI模型批量调用,零token费+稳到不翻车!
人工智能·深度学习·openclaw·免费openai·ai服务代理
yiyu07167 小时前
3分钟搞懂深度学习AI:实操篇:RNN
人工智能·深度学习
uzong7 小时前
CoPaw是什么?-- 2026年开源的国产个人AI助手
人工智能·后端
海盗儿7 小时前
TensorRT-LLM 框架与源码分析
人工智能
无心水7 小时前
【任务调度:框架】11、分布式任务调度进阶:高可用、幂等性、性能优化三板斧
人工智能·分布式·后端·性能优化·架构·2025博客之星·分布式调度框架
码森林7 小时前
小龙虾居然比你更健忘?OpenClaw 记忆系统指南,让它永远记住你
人工智能·ai编程·全栈
ghie90908 小时前
维纳滤波器语音增强MATLAB实现
人工智能·matlab·语音识别