在 PyTorch 中,masked_fill
、masked_select
和 masked_scatter
是三种常用的掩码(mask)操作方法,它们通过布尔类型的掩码张量(mask
)对原始张量进行条件筛选或修改。以下是它们的详细解释和对比:
1. masked_fill
作用 :将原始张量中 mask
为 True
的位置用指定值填充,其余位置保持不变。
参数 :
• mask
(BoolTensor):与原始张量形状相同的布尔掩码。
• value
(标量):用于填充的值。
特点 :
• 原地操作 :会直接修改原始张量(除非使用 masked_fill_
的 in-place 版本)。
• 保持形状:输出张量形状与输入张量一致。
示例:
python
import torch
x = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[False, True], [True, False]])
y = x.masked_fill(mask, -1)
print(y)
# 输出:
# tensor([[ 1, -1],
# [-1, 4]])
典型应用 :
• 在 Transformer 的注意力机制中,用 -inf
填充 padding 或未来的位置,使 softmax 后概率为 0。
• 数据清洗时屏蔽无效值(如 NaN)。
2. masked_select
作用 :从原始张量中提取 mask
为 True
的元素,返回一维张量。
参数 :
• mask
(BoolTensor):与原始张量形状相同的布尔掩码。
特点 :
• 返回一维张量 :输出会丢失原始张量的维度信息。
• 非原地操作:生成新的张量。
示例:
python
x = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[False, True], [True, False]])
y = x.masked_select(mask)
print(y) # tensor([2, 3])
典型应用 :
• 提取满足条件的元素(如分类任务中筛选正样本)。
• 统计掩码区域的值(如计算非零元素均值)。
3. masked_scatter
作用 :将另一个张量(source
)中的值按顺序填充到原始张量中 mask
为 True
的位置。
参数 :
• mask
(BoolTensor):与原始张量形状相同的布尔掩码。
• source
(Tensor):提供填充值的源张量。
特点 :
• 按顺序填充 :source
中的值按行优先顺序填充到 mask
为 True
的位置。
• source
的长度必须 ≥ mask
中 True
的数量。
示例:
python
x = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[False, True], [True, False]])
source = torch.tensor([10, 20])
y = x.masked_scatter(mask, source)
print(y)
# 输出:
# tensor([[ 1, 10],
# [20, 4]])
典型应用 :
• 动态替换张量中的部分值(如用随机噪声替换特定区域)。
• 批量更新参数时选择性地替换某些位置。
对比总结
方法 | 输入张量形状 | 输出形状 | 是否修改原张量 | 核心功能 |
---|---|---|---|---|
masked_fill |
保留原形状 | 与原张量相同 | 是(可选) | 用标量填充掩码区域 |
masked_select |
保留原形状 | 一维张量 | 否 | 提取掩码区域的元素 |
masked_scatter |
保留原形状 | 与原张量相同 | 是(可选) | 用另一张量填充掩码区域 |
关键注意事项
- 掩码形状匹配 :
mask
必须与原始张量形状严格一致,否则会报错。 - 数据类型 :
mask
必须是布尔类型(BoolTensor
)。 - 梯度传播 :所有操作均支持自动求导,但填充的值(如
masked_fill
中的value
)需是浮点数以避免类型错误。 - 性能:对大规模张量频繁使用这些操作可能影响性能,建议优先使用向量化操作。
选择方法指南
• 需要保持形状并填充标量 → masked_fill
• 需要提取元素并丢弃形状 → masked_select
• 需要按顺序替换为另一张量的值 → masked_scatter
通过合理使用这些方法,可以高效实现条件筛选、数据清洗、动态修改等任务。