A Deep Dive into PyTorch's torch.stack Function: Bilingual (Chinese/English)

Chinese Version

A Deep Dive into the torch.stack Function in PyTorch

When developing deep learning models with PyTorch, you frequently need to manipulate and combine tensors. torch.stack is a very common and important function: it joins a sequence of tensors along a new dimension to form a single tensor. This article covers torch.stack in depth, including its functionality, parameters, caveats, and practical examples.


What is torch.stack?

The main purpose of torch.stack is to stack multiple tensors along a new dimension. Unlike torch.cat, which concatenates along an existing dimension, torch.stack creates a new dimension and stacks the given tensors at the specified position.

The official documentation defines it as follows:

python
torch.stack(tensors, dim=0) → Tensor
Parameters:
  • tensors: the sequence of tensors to stack (a list or tuple); all tensors must have the same shape.
  • dim: the index (position) of the new dimension; defaults to 0.
Returns:

A new tensor containing the input tensors; the result has one more dimension than each input tensor.
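
As a quick illustration of how dim determines the output shape, here is a minimal sketch (the tensor shapes are arbitrary, chosen only for demonstration):

python
import torch

# Three tensors with the same shape [2, 4]
a = torch.zeros(2, 4)
b = torch.ones(2, 4)
c = torch.full((2, 4), 2.0)

# The new dimension of size 3 is inserted at position dim
print(torch.stack([a, b, c], dim=0).shape)  # torch.Size([3, 2, 4])
print(torch.stack([a, b, c], dim=1).shape)  # torch.Size([2, 3, 4])
print(torch.stack([a, b, c], dim=2).shape)  # torch.Size([2, 4, 3])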


Usage Examples

Basic Usage
python
import torch

# Create three tensors with the same shape
t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# Stack along a new dimension
result = torch.stack([t1, t2, t3], dim=0)
print(result)

Output:

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

In the code above, dim=0 adds a new dimension at the outermost level (stacking as rows). The resulting tensor has shape [3, 3], i.e. 3 rows and 3 columns.

Changing the Position of the New Dimension

If we set dim to 1:

python
result = torch.stack([t1, t2, t3], dim=1)
print(result)

Output:

tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])

Here, dim=1 stacks the tensors along the column dimension. The result still has shape [3, 3], but the data are arranged differently.
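
For 1-D inputs like these, the dim=1 result is simply the transpose of the dim=0 result; a small sanity check (illustrative only):

python
stacked_rows = torch.stack([t1, t2, t3], dim=0)  # each input becomes a row
stacked_cols = torch.stack([t1, t2, t3], dim=1)  # each input becomes a column

# The two layouts are transposes of each other
print(torch.equal(stacked_cols, stacked_rows.T))  # True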


Comparison with torch.cat

Both torch.cat and torch.stack can be used to combine tensors, but their behavior and results differ significantly.

Difference 1: whether a new dimension is created
  • torch.cat: concatenates only along an existing dimension and does not create a new one.
  • torch.stack: creates a new dimension.

For example:

python
# Using torch.cat
result_cat = torch.cat([t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)], dim=0)
print(result_cat)

# Using torch.stack
result_stack = torch.stack([t1, t2, t3], dim=0)
print(result_stack)

Both calls produce the same result; the difference is that torch.cat requires the extra dimension to be added manually (via unsqueeze), whereas torch.stack handles it automatically.
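
To confirm the equivalence, a quick check on the two results from above:

python
print(torch.equal(result_cat, result_stack))  # True: same values, same shape [3, 3]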

Difference 2: shape requirements
  • torch.cat: the input tensors must match in every dimension except the concatenation dimension.
  • torch.stack: the input tensors must have exactly the same shape (illustrated below).
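
With two 1-D tensors of unequal length, torch.cat accepts them while torch.stack raises an error; a minimal sketch:

python
a = torch.tensor([1, 2])
b = torch.tensor([3, 4, 5])

# cat only requires matching sizes outside the concatenation dimension
print(torch.cat([a, b], dim=0))  # tensor([1, 2, 3, 4, 5])

# stack requires identical shapes, so this fails
try:
    torch.stack([a, b], dim=0)
except RuntimeError as err:
    print("torch.stack failed:", err)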

Advanced Examples

1. Building a Batched Tensor

Suppose we have several 2-D tensors and want to stack them into a single 3-D tensor:

python
t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])

result = torch.stack([t1, t2, t3], dim=0)
print(result)

Output:

tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]]])

Here, dim=0 adds a dimension at the outermost level, and the resulting shape is [3, 2, 2].
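
Indexing along the new dimension recovers the original inputs, which is exactly what makes this useful for batched data; a quick check with the result above:

python
print(result.shape)                # torch.Size([3, 2, 2])
print(torch.equal(result[1], t2))  # True: index 1 along dim 0 is the second input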

2. Splitting a Tensor and Re-stacking It
python
# Create a 3-D tensor
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

# Split along dim=0
split_tensors = torch.unbind(x, dim=0)

# Re-stack along a different dimension
result = torch.stack(split_tensors, dim=1)
print(result)

Output:

tensor([[[ 1,  2],
         [ 5,  6],
         [ 9, 10]],

        [[ 3,  4],
         [ 7,  8],
         [11, 12]]])

After splitting the tensor with torch.unbind, we can use torch.stack to reorganize its structure; here the result has shape [2, 3, 2].
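
In this particular case, splitting along dim=0 and re-stacking along dim=1 amounts to swapping the first two dimensions, so the result matches a plain transpose; a small sanity check:

python
print(result.shape)                            # torch.Size([2, 3, 2])
print(torch.equal(result, x.transpose(0, 1)))  # True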


Notes and Caveats

  1. The input tensors must have exactly the same shape

    If the shapes of the input tensors differ, an error is raised:

    python
    t1 = torch.tensor([1, 2])
    t2 = torch.tensor([3, 4, 5])
    torch.stack([t1, t2])  # raises an error
  2. Valid range of dim

    • dim must lie in [-(d+1), d], where d is the number of dimensions of the input tensors.
    • A value outside this range raises an error.
  3. Efficiency

    • torch.stack is essentially equivalent to adding a new dimension to each input tensor and then calling torch.cat. In large-scale operations, using torch.cat together with explicit dimension manipulation may therefore be more efficient.

Summary

torch.stack is a powerful and flexible function that provides a concise interface for working with tensors. By adding a new dimension, it makes many otherwise complex tensor operations far more intuitive. Whether you are building batched data, changing tensor dimensions, or performing advanced tensor transformations, torch.stack is an indispensable tool.

In practice, understanding what the parameters of torch.stack mean and how it differs from torch.cat helps us write more efficient and concise code. I hope this post helps you fully master torch.stack!

English Version

Understanding the torch.stack Function in PyTorch

In PyTorch, tensor manipulation is a crucial skill for deep learning practitioners. One of the commonly used functions for combining tensors is torch.stack. This blog will provide a detailed explanation of what torch.stack does, how to use it, and how it differs from similar functions like torch.cat.


What is torch.stack?

torch.stack is used to combine multiple tensors along a new dimension, creating a higher-dimensional tensor. Unlike torch.cat, which concatenates tensors along an existing dimension, torch.stack introduces a new dimension to hold the tensors.

Function Definition
python
torch.stack(tensors, dim=0) → Tensor
Parameters:
  • tensors: A sequence (list or tuple) of tensors to stack. All tensors must have the same shape.
  • dim: The dimension along which the tensors will be stacked. Default is 0.
Returns:

A new tensor with an additional dimension. The shape of the resulting tensor depends on the stacking dimension.


Basic Usage

Example 1: Stacking Along the First Dimension
python
import torch

# Create three tensors of the same shape
t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# Stack tensors along the first dimension (dim=0)
result = torch.stack([t1, t2, t3], dim=0)
print(result)

Output:

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

In this example, dim=0 adds a new dimension at the outermost level (rows). The resulting tensor has the shape [3, 3], with 3 rows and 3 columns.
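
You can verify the shape and row layout directly (a quick check on the result above):

python
print(result.shape)  # torch.Size([3, 3])
print(result[0])     # tensor([1, 2, 3]): each input occupies one row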

Example 2: Stacking Along a Different Dimension

By changing the dim parameter, you can stack tensors along other dimensions:

python
result = torch.stack([t1, t2, t3], dim=1)
print(result)

Output:

tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])

Here, dim=1 stacks the tensors along the second dimension (columns). The result has the same shape [3, 3], but the data arrangement is different.


Key Differences Between torch.stack and torch.cat

While both functions are used to combine tensors, there are notable differences:

  1. Adding a New Dimension:

    • torch.stack adds a new dimension to hold the tensors.
    • torch.cat does not create a new dimension; it concatenates tensors along an existing dimension.
  2. Shape Requirements:

    • torch.stack requires all tensors to have the exact same shape.
    • torch.cat requires tensors to have matching shapes in all dimensions except the concatenation dimension.
Example:
python
# Using torch.stack
stack_result = torch.stack([t1, t2, t3], dim=0)

# Using torch.cat with unsqueeze to match dimensions
cat_result = torch.cat([t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)], dim=0)

print(stack_result)
print(cat_result)

Both produce the same result, but with torch.cat, you need to explicitly add a dimension using unsqueeze.
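
Note that calling torch.cat on these 1-D tensors without unsqueeze does not fail; it simply produces a flat vector rather than a 2-D tensor, since no new dimension is created (a small sketch):

python
flat = torch.cat([t1, t2, t3], dim=0)
print(flat)        # tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
print(flat.shape)  # torch.Size([9]), not [3, 3]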


Advanced Examples

1. Stacking Higher-Dimensional Tensors

Consider stacking 2D tensors into a 3D tensor:

python
t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])

result = torch.stack([t1, t2, t3], dim=0)
print(result)

Output:

tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]]])

Here, dim=0 creates a new outermost dimension, resulting in a 3D tensor of shape [3, 2, 2].
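
This pattern is a common way to assemble a batch from individual samples. A minimal sketch, where the per-sample tensors are made up purely for illustration:

python
# Hypothetical per-sample feature tensors, each of shape [2, 2]
samples = [torch.randn(2, 2) for _ in range(8)]

# Stack into one batch with the batch dimension first
batch = torch.stack(samples, dim=0)
print(batch.shape)  # torch.Size([8, 2, 2])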

2. Rearranging Dimensions

You can manipulate dimensions by combining torch.unbind and torch.stack. For example:

python
# Original 3D tensor
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

# Split along the first dimension
split_tensors = torch.unbind(x, dim=0)

# Re-stack along a different dimension
result = torch.stack(split_tensors, dim=1)
print(result)

Output:

tensor([[[ 1,  2],
         [ 5,  6],
         [ 9, 10]],

        [[ 3,  4],
         [ 7,  8],
         [11, 12]]])

This demonstrates how you can use torch.stack to rearrange data across dimensions; the result has shape [2, 3, 2].
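
Re-stacking the same pieces along a different dim gives yet another layout; for example, dim=2 moves the original first dimension to the end (a small sketch):

python
result_last = torch.stack(split_tensors, dim=2)
print(result_last.shape)                             # torch.Size([2, 2, 3])
print(torch.equal(result_last, x.permute(1, 2, 0)))  # True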


Best Practices and Tips

  1. Ensure Consistent Shapes:

    All tensors must have the same shape. If their shapes differ, PyTorch will raise an error.

    python
    t1 = torch.tensor([1, 2])
    t2 = torch.tensor([3, 4, 5])
    torch.stack([t1, t2])  # This will raise a runtime error
  2. Dimension Bounds:

    The dim parameter can take values between -(d+1) and d, where d is the number of dimensions of the input tensors.

  3. Performance Considerations:

    Internally, torch.stack adds a dimension and then applies torch.cat. If you need full control over the process for performance optimization, consider using torch.cat with manual dimension manipulation, as sketched below.
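
The relationship can be made explicit; the sketch below shows that stacking is equivalent to unsqueezing each tensor at the target dimension and concatenating there (assuming all tensors share the same shape):

python
tensors = [torch.randn(4, 5) for _ in range(3)]
dim = 1

stacked = torch.stack(tensors, dim=dim)
concatenated = torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

print(torch.equal(stacked, concatenated))  # True, both have shape [4, 3, 5]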


Conclusion

torch.stack is a powerful and versatile function for combining tensors in PyTorch. It simplifies the process of adding new dimensions, making it especially useful for creating higher-dimensional tensors, organizing batch data, and performing tensor transformations. By understanding its parameters, usage, and differences from similar functions, you can effectively incorporate it into your PyTorch workflows.

When dealing with tensor operations, a clear understanding of torch.stack will save you time and effort, allowing you to focus more on building and fine-tuning your models!

Postscript

Written in Shanghai at 22:03 on December 12, 2024, with the assistance of the GPT-4o model.
