PyTorch 提供了一系列强大的索引操作函数,用于高效地操作张量数据。以下是常用的 PyTorch 索引操作函数及其用途:
1. 基础索引操作
这些操作类似于 NumPy 的基本索引方式。
1.1 方括号索引 ([]
)
-
支持 切片 、布尔索引 、高级索引。
-
示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
print(x[0, 1]) # 取第 0 行第 1 列,输出 20
print(x[:, 1]) # 取所有行的第 1 列,输出 tensor([20, 50])
print(x[1]) # 取第 1 行,输出 tensor([40, 50, 60])
2. 索引操作函数
2.1 torch.index_select
按指定维度的索引提取元素。
-
用法:
torch.index_select(input, dim, index, *, out=None)
-
参数 :
input
:输入张量。dim
:指定的维度。index
:索引张量,必须是一维张量。
-
示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2])
result = torch.index_select(x, dim=1, index=index)
print(result) # 输出 tensor([[10, 30], [40, 60]])
2.2 torch.gather
根据索引张量,从输入张量中收集数据。
-
用法:
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
-
参数 :
input
:输入张量。dim
:指定的维度。index
:索引张量,必须与输入张量形状兼容。
-
示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([[0, 2], [1, 0]])
result = torch.gather(x, dim=1, index=index)
print(result) # 输出 tensor([[10, 30], [50, 40]])
2.3 torch.scatter
将源张量的值按索引写入目标张量。
-
用法:
torch.scatter(input, dim, index, src, *, reduce=None)
示例:
x = torch.zeros(2, 3)
index = torch.tensor([[0, 1, 2], [0, 1, 2]])
src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
result = torch.scatter(x, dim=1, index=index, src=src)
print(result) # 输出 tensor([[10., 20., 30.], [40., 50., 60.]])
2.4 torch.take
按照扁平化索引从张量中提取元素。
-
用法:
torch.take(input, index)
示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2, 5])
result = torch.take(x, index)
print(result) # 输出 tensor([10, 30, 60])
torch.gather和 torch.scatter区别
特性 | torch.gather |
torch.scatter |
---|---|---|
作用 | 从指定位置提取数据 | 向指定位置写入数据 |
索引意义 | 定义要从 input 中提取数据的位置 |
定义要向 input 中写入数据的位置 |
输出形状 | 与 index 形状相同 |
与 input 形状相同 |
操作方向 | 从 input 提取值 |
将 src 的值写入到 input |
使用场景 | 数据提取(如选择性采样) | 数据写入(如更新某些特定位置的值) |
2.5 torch.masked_select
根据布尔掩码从张量中提取元素。
-
用法:
torch.masked_select(input, mask, *, out=None)
示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
mask = x > 30
result = torch.masked_select(x, mask)
print(result) # 输出 tensor([40, 50, 60])
2.6 torch.nonzero
返回所有非零元素的索引。
-
用法:
torch.nonzero(input, *, as_tuple=False)
示例:
x = torch.tensor([[0, 1, 0], [2, 0, 3]])
result = torch.nonzero(x)
print(result) # 输出 tensor([[0, 1], [1, 0], [1, 2]])
2.7 torch.where
根据条件选择值。
-
用法:
torch.where(condition, x, y)
示例:
x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
condition = x > 1
result = torch.where(condition, x, y)
print(result) # 输出 tensor([10, 2, 3])
2.8 torch.advanced_indexing
-
PyTorch 支持 布尔张量索引 和 整形张量索引。
-
示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([[0, 1], [1, 0]])
result = x[index]
print(result)
输出:
tensor([[[10, 20, 30],
[40, 50, 60]],
[[40, 50, 60],
[10, 20, 30]]])
注:第0维度操作
3. 总结
函数 | 功能 |
---|---|
torch.index_select |
按指定维度和索引提取元素 |
torch.gather |
根据索引张量从输入中收集元素 |
torch.scatter |
按索引将源张量的值写入目标张量 |
torch.take |
按扁平化索引从张量中提取元素 |
torch.masked_select |
根据布尔掩码从张量中提取元素 |
torch.nonzero |
找出非零元素的索引 |
torch.where |
根据条件选择元素 |
这些索引操作函数可以帮助简化张量操作,提高代码效率。