Chinese Version
A Deep Dive into the torch.stack Function in PyTorch
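When developing deep learning models with PyTorch, you constantly need to manipulate and combine tensors. torch.stack is a very common and important function: it joins a sequence of tensors along a new dimension to form a single tensor. This article covers torch.stack in depth, including what it does, its parameters, common pitfalls, and practical examples.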
What is torch.stack?
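The main purpose of torch.stack is to stack multiple tensors along a new dimension. This is where it differs from torch.cat. torch.cat joins tensors along an existing dimension, whereas torch.stack creates a new dimension and stacks the tensors at the specified position.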
The official documentation defines it as follows:
```python
torch.stack(tensors, dim=0) → Tensor
```
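Parameters:
- tensors: the sequence of tensors to stack (a list or tuple). All tensors must have the same shape.
- dim: the index (position) at which the new dimension is inserted. Defaults to 0.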
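Returns:
A new tensor that contains the input tensors. It has one more dimension than each input. The new dimension has size len(tensors) and sits at position dim.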
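The following minimal sketch shows the shape rule for every allowed value of dim. The input shape (2, 4) and the use of zero tensors are arbitrary choices for illustration:

```python
import torch

# Three inputs with the same (arbitrary) shape (2, 4)
tensors = [torch.zeros(2, 4) for _ in range(3)]

# Inputs have 2 dims, so the output has 3 dims and dim may range over -3..2
shapes = {}
for dim in range(-3, 3):
    shapes[dim] = tuple(torch.stack(tensors, dim=dim).shape)
    print(dim, shapes[dim])
```

dim=-1 and dim=2 give the same shape. Negative values count from the end of the output shape, not the input shape.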
Usage Examples
Basic Usage
```python
import torch

# Create three tensors with the same shape
t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# Stack them along a new dimension
result = torch.stack([t1, t2, t3], dim=0)
print(result)
```
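Output:
```
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
```
In the code above, dim=0 adds a new outermost dimension, so each input becomes one row. The resulting tensor has shape [3, 3], that is, 3 rows and 3 columns.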
Changing the Dimension Position
If we set dim to 1:
```python
result = torch.stack([t1, t2, t3], dim=1)
print(result)
```
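Output:
```
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
```
Now dim=1 stacks along the column dimension, so each input becomes one column. The result still has shape [3, 3], but the data is arranged differently.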
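For 1D inputs, the two results can be compared directly. This sketch reuses the t1, t2, t3 values from above:

```python
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

rows = torch.stack([t1, t2, t3], dim=0)  # each input becomes a row
cols = torch.stack([t1, t2, t3], dim=1)  # each input becomes a column

# For 1D inputs, stacking along dim=1 is the transpose of stacking along dim=0
print(torch.equal(cols, rows.T))  # True
# Column 0 of the dim=1 result is exactly the first input
print(torch.equal(cols[:, 0], t1))  # True
```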
Comparison with torch.cat
Both torch.cat and torch.stack can combine tensors, but they behave differently and produce different results.
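Difference 1: Whether a New Dimension Is Created
- torch.cat: concatenates along an existing dimension only and does not create a new dimension.
- torch.stack: creates a new dimension.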
For example:
```python
# Using torch.cat
result_cat = torch.cat([t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)], dim=0)
print(result_cat)

# Using torch.stack
result_stack = torch.stack([t1, t2, t3], dim=0)
print(result_stack)
```
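Output (both print calls show the same tensor):
```
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
```
The two results are identical. torch.cat needs the extra dimension added by hand (with unsqueeze), while torch.stack handles this automatically.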
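The equivalence holds for any valid dim, not only dim=0. The sketch below checks it for every dim on a few 2D tensors. stack_via_cat is a hypothetical helper written for this illustration, not a PyTorch function:

```python
import torch

def stack_via_cat(tensors, dim=0):
    # Hypothetical helper: insert a size-1 axis at `dim` in every input,
    # then concatenate along that axis
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

# Four 2D tensors of shape (2, 3) with distinct values
tensors = [torch.arange(6).reshape(2, 3) + 10 * i for i in range(4)]

# Check the equivalence for every valid dim, including negative ones
for dim in range(-3, 3):
    assert torch.equal(torch.stack(tensors, dim=dim), stack_via_cat(tensors, dim=dim))
print("torch.stack matches torch.cat + unsqueeze for all dims")
```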
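Difference 2: Shape Requirements
- torch.cat: input tensors must match in every dimension except the concatenation dimension.
- torch.stack: input tensors must have exactly the same shape.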
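The following sketch contrasts the two requirements on a pair of tensors that differ only in their first dimension. The shapes are arbitrary:

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)

# torch.cat: sizes may differ along the concatenation dimension (dim 0 here)
joined = torch.cat([a, b], dim=0)
print(joined.shape)  # torch.Size([6, 3])

# torch.stack: every input must have exactly the same shape
try:
    torch.stack([a, b], dim=0)
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("torch.stack failed:", err)
```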
Advanced Examples
1. Building a Batched Tensor
Suppose we have several 2D tensors and want to stack them into a 3D tensor:
```python
t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])
result = torch.stack([t1, t2, t3], dim=0)
print(result)
```
Output:
```
tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]]])
```
Here dim=0 adds a new outermost dimension, so the result has shape [3, 2, 2].
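A common real-world case is turning a list of individual samples into a batch. In this sketch, the five random 3x8x8 tensors are made-up stand-ins for images. Stacking equally shaped samples along a new dim 0 is roughly what the default collate function of torch.utils.data.DataLoader does for tensor samples:

```python
import torch

# Hypothetical samples: five 3-channel 8x8 "images", as a dataset might return them one at a time
samples = [torch.randn(3, 8, 8) for _ in range(5)]

# Stack along a new leading batch dimension
batch = torch.stack(samples, dim=0)
print(batch.shape)  # torch.Size([5, 3, 8, 8])

# Indexing the batch dimension recovers each original sample
print(all(torch.equal(batch[i], s) for i, s in enumerate(samples)))  # True
```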
2. Splitting a Tensor and Stacking It Again
```python
# Create a 3D tensor
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

# Split it along dim=0
split_tensors = torch.unbind(x, dim=0)

# Stack the pieces again
result = torch.stack(split_tensors, dim=1)
print(result)
```
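Output:
```
tensor([[[ 1,  2],
         [ 5,  6],
         [ 9, 10]],

        [[ 3,  4],
         [ 7,  8],
         [11, 12]]])
```
After splitting the tensor with torch.unbind, we can use torch.stack to reorganize its structure. Re-stacking the three [2, 2] pieces along dim=1 gives shape [2, 3, 2], which is the same as swapping the first two dimensions of x.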
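The unbind-and-restack step can be checked against a plain transpose. This sketch rebuilds the same x with torch.arange and compares the results:

```python
import torch

x = torch.arange(1, 13).reshape(3, 2, 2)  # same values as the tensor above

parts = torch.unbind(x, dim=0)        # three tensors of shape (2, 2)
swapped = torch.stack(parts, dim=1)   # shape (2, 3, 2)

# Re-stacking along dim=1 is the same as swapping dims 0 and 1
print(torch.equal(swapped, x.transpose(0, 1)))  # True
print(swapped[0])  # rows [1, 2], [5, 6], [9, 10]
```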
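Caveats
- Input tensors must have exactly the same shape. If their shapes differ, an error is raised:
```python
import torch
t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4, 5])
torch.stack([t1, t2])  # raises a RuntimeError because the shapes differ
```
- Range of dim: dim must lie in [-(d+1), d], where d is the number of dimensions of the input tensors. A value outside this range raises an error.
- Efficiency: conceptually, torch.stack adds a dimension to each input tensor and then calls torch.cat to do the stacking. Writing torch.cat with explicit dimension handling yourself therefore gives the same result. It is worth considering when you need finer control over the dimensions in large-scale code, but measure any speed gain rather than assume it.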
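The dim range rule can be probed directly. This sketch tries every dim from -4 to 3 on two 1D tensors (d = 1) and records which calls succeed. It catches both IndexError and RuntimeError because the exact exception type raised for an out-of-range dim is an assumption here:

```python
import torch

t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4])

# Inputs are 1D (d = 1), so valid dims should be -2, -1, 0, 1
valid, invalid = [], []
for dim in range(-4, 4):
    try:
        torch.stack([t1, t2], dim=dim)
        valid.append(dim)
    except (IndexError, RuntimeError):
        # Out-of-range dim; the exact exception class is assumed, so both are caught
        invalid.append(dim)
print("valid:", valid)      # [-2, -1, 0, 1]
print("invalid:", invalid)  # [-4, -3, 2, 3]
```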
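Conclusion
torch.stack is a powerful and flexible function that offers a concise interface for working with tensors. Because it adds a new dimension, many complex tensor operations become more intuitive. Whether you are building batches of data, changing tensor dimensions, or performing advanced tensor transformations, torch.stack is an indispensable tool.
In practice, understanding what the parameters of torch.stack mean and how it differs from torch.cat helps us write more efficient and concise code. I hope this blog post helps you fully master torch.stack!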
English Version
Understanding the torch.stack Function in PyTorch
In PyTorch, tensor manipulation is a crucial skill for deep learning practitioners. One of the most commonly used functions for combining tensors is torch.stack. This blog post explains in detail what torch.stack does, how to use it, and how it differs from similar functions like torch.cat.
What is torch.stack?
torch.stack combines multiple tensors along a new dimension, creating a higher-dimensional tensor. Unlike torch.cat, which concatenates tensors along an existing dimension, torch.stack introduces a new dimension to hold the tensors.
Function Definition
```python
torch.stack(tensors, dim=0) → Tensor
```
Parameters:
- tensors: A sequence (list or tuple) of tensors to stack. All tensors must have the same shape.
- dim: The index at which the new dimension is inserted. Default is 0.
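Returns:
A new tensor with one more dimension than the inputs. The new dimension has size len(tensors) and is inserted at position dim. For example, stacking three tensors of shape [2, 4] with dim=1 gives shape [2, 3, 4].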
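The shape rule can be written as a tiny function and checked against PyTorch. expected_stack_shape below is a hypothetical helper for illustration only. It maps a negative dim to dim + (d + 1), where d is the number of input dimensions:

```python
import torch

def expected_stack_shape(shape, num_tensors, dim):
    # Hypothetical helper: insert `num_tensors` into `shape` at position `dim`
    ndim_out = len(shape) + 1
    if not -ndim_out <= dim < ndim_out:
        raise IndexError(f"dim {dim} out of range for a {ndim_out}-D output")
    pos = dim % ndim_out  # map negative dims to their non-negative position
    return tuple(shape[:pos]) + (num_tensors,) + tuple(shape[pos:])

tensors = [torch.rand(4, 5) for _ in range(2)]
for dim in (0, 1, 2, -1):
    actual = tuple(torch.stack(tensors, dim=dim).shape)
    assert actual == expected_stack_shape((4, 5), len(tensors), dim)
    print(dim, actual)
```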
Basic Usage
Example 1: Stacking Along the First Dimension
```python
import torch

# Create three tensors of the same shape
t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# Stack tensors along the first dimension (dim=0)
result = torch.stack([t1, t2, t3], dim=0)
print(result)
```
Output:
```
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
```
In this example, dim=0 adds a new dimension at the outermost level, so each input becomes a row. The resulting tensor has shape [3, 3], with 3 rows and 3 columns.
Example 2: Stacking Along a Different Dimension
By changing the dim parameter, you can stack tensors along other dimensions:
```python
result = torch.stack([t1, t2, t3], dim=1)
print(result)
```
Output:
```
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
```
Here, dim=1 stacks the tensors along the second dimension, so each input becomes a column. The result has the same shape [3, 3], but the data is arranged differently.
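Negative values of dim are also accepted and count from the end of the output shape. For 1D inputs the output has two dimensions, so dim=-1 matches dim=1 and dim=-2 matches dim=0. This sketch checks both:

```python
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# Negative dims index the *output* shape, which has 2 dims here
last = torch.stack([t1, t2, t3], dim=-1)   # same as dim=1
first = torch.stack([t1, t2, t3], dim=-2)  # same as dim=0

print(torch.equal(last, torch.stack([t1, t2, t3], dim=1)))   # True
print(torch.equal(first, torch.stack([t1, t2, t3], dim=0)))  # True
```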
Key Differences Between torch.stack and torch.cat
While both functions are used to combine tensors, there are notable differences:
- Adding a New Dimension: torch.stack adds a new dimension to hold the tensors. torch.cat does not create a new dimension; it concatenates tensors along an existing dimension.
- Shape Requirements: torch.stack requires all tensors to have exactly the same shape. torch.cat requires tensors to have matching shapes in all dimensions except the concatenation dimension.
Example:
```python
# Using torch.stack
stack_result = torch.stack([t1, t2, t3], dim=0)

# Using torch.cat with unsqueeze to match dimensions
cat_result = torch.cat([t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)], dim=0)

print(stack_result)
print(cat_result)
```
Both produce the same result, but with torch.cat you need to add the dimension explicitly using unsqueeze.
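The next sketch shows why the unsqueeze calls matter by comparing torch.cat on the raw 1D tensors with torch.stack. It reuses the same t1, t2, t3 values:

```python
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# Without unsqueeze, torch.cat simply extends the existing dimension
flat = torch.cat([t1, t2, t3], dim=0)
print(flat.shape)  # torch.Size([9])

# torch.stack adds a new dimension instead
stacked = torch.stack([t1, t2, t3], dim=0)
print(stacked.shape)  # torch.Size([3, 3])

# Same values in the same order; only the shape differs
print(torch.equal(stacked.reshape(-1), flat))  # True
```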
Advanced Examples
1. Stacking Higher-Dimensional Tensors
Consider stacking 2D tensors into a 3D tensor:
```python
t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])
result = torch.stack([t1, t2, t3], dim=0)
print(result)
```
Output:
```
tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]]])
```
Here, dim=0 creates a new outermost dimension, resulting in a 3D tensor of shape [3, 2, 2].
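The new dimension does not have to be the outermost one. This sketch stacks the same 2D tensors along dim=2, so each input becomes one slice of the last dimension:

```python
import torch

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])

# Stack along the last position instead: the shape becomes (2, 2, 3)
result = torch.stack([t1, t2, t3], dim=2)
print(result.shape)  # torch.Size([2, 2, 3])

# Slicing the new last dimension recovers each input
print(torch.equal(result[..., 1], t2))  # True
print(result[0, 0])  # tensor([1, 5, 9])
```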
2. Rearranging Dimensions
You can manipulate dimensions by combining torch.unbind and torch.stack. For example:
```python
# Original 3D tensor
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

# Split along the first dimension
split_tensors = torch.unbind(x, dim=0)

# Re-stack along a different dimension
result = torch.stack(split_tensors, dim=1)
print(result)
```
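Output:
```
tensor([[[ 1,  2],
         [ 5,  6],
         [ 9, 10]],

        [[ 3,  4],
         [ 7,  8],
         [11, 12]]])
```
Re-stacking the three [2, 2] pieces along dim=1 gives shape [2, 3, 2], which swaps the first two dimensions of x. This demonstrates how you can use torch.stack to rearrange data across dimensions.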
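Conversely, unbinding and re-stacking along the same dimension is a round trip. This sketch checks that for every dimension of the example tensor, rebuilt here with torch.arange:

```python
import torch

x = torch.arange(1, 13).reshape(3, 2, 2)  # same values as the tensor above

# For every dimension d, unbinding along d and stacking back along d restores x
for d in range(x.dim()):
    pieces = torch.unbind(x, dim=d)
    assert len(pieces) == x.shape[d]
    assert torch.equal(torch.stack(pieces, dim=d), x)
print("round trip OK for all dims")
```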
Best Practices and Tips
- Ensure Consistent Shapes: All tensors must have the same shape. If their shapes differ, PyTorch will raise an error.
```python
import torch
t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4, 5])
torch.stack([t1, t2])  # This will raise a RuntimeError
```
- Dimension Bounds: The dim parameter can take values between -(d+1) and d (inclusive), where d is the number of dimensions of the input tensors. Values outside this range raise an error.
- Performance Considerations: Conceptually, torch.stack adds a dimension to each input and then applies torch.cat. Writing torch.cat with manual dimension manipulation yourself gives the same result. Consider it when you need full control over how the dimensions are arranged, and measure before assuming it is faster.
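When shape mismatches are likely, for example with inputs of varying length, it can help to validate shapes before stacking so the error names the offending entries. safe_stack below is a hypothetical wrapper written for this illustration, not part of PyTorch:

```python
import torch

def safe_stack(tensors, dim=0):
    # Hypothetical wrapper: check shapes up front and report every mismatching entry
    shapes = [tuple(t.shape) for t in tensors]
    bad = [i for i, s in enumerate(shapes) if s != shapes[0]]
    if bad:
        raise ValueError(
            f"cannot stack: entry 0 has shape {shapes[0]}, "
            f"mismatching entries: {[(i, shapes[i]) for i in bad]}"
        )
    return torch.stack(tensors, dim=dim)

ok = safe_stack([torch.zeros(2), torch.ones(2)])
print(ok.shape)  # torch.Size([2, 2])

try:
    safe_stack([torch.tensor([1, 2]), torch.tensor([3, 4, 5])])
except ValueError as err:
    print(err)
```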
Conclusion
torch.stack is a powerful and versatile function for combining tensors in PyTorch. It simplifies adding new dimensions, which makes it especially useful for creating higher-dimensional tensors, organizing batch data, and performing tensor transformations. By understanding its parameters, usage, and differences from similar functions, you can effectively incorporate it into your PyTorch workflows.
When dealing with tensor operations, a clear understanding of torch.stack will save you time and effort, allowing you to focus more on building and fine-tuning your models!
Afterword
Completed in Shanghai at 22:03 on December 12, 2024, with the assistance of the GPT-4o large language model.