pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill() 详解

pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill()详解

  • [PyTorch `masked_fill_()` vs. `masked_fill()` 详解](#PyTorch masked_fill_() vs. masked_fill() 详解)
    • [1️⃣ `masked_fill()` 和 `masked_fill_()` 的作用](#1️⃣ masked_fill()masked_fill_() 的作用)
    • [2️⃣ `masked_fill()` vs. `masked_fill_()` 示例](#2️⃣ masked_fill() vs. masked_fill_() 示例)
    • [3️⃣ 输出结果](#3️⃣ 输出结果)
    • [4️⃣ `masked_fill()` vs. `masked_fill_()` 区别](#4️⃣ masked_fill() vs. masked_fill_() 区别)
    • [5️⃣ `masked_fill()` 和 `masked_fill_()` 的实际应用](#5️⃣ masked_fill()masked_fill_() 的实际应用)
    • [6️⃣ `masked_fill()` 在 Transformer 自注意力中的应用](#6️⃣ masked_fill() 在 Transformer 自注意力中的应用)
    • [7️⃣ `masked_fill_()` 在梯度计算中的应用](#7️⃣ masked_fill_() 在梯度计算中的应用)
    • [8️⃣ 总结](#8️⃣ 总结)
      • [💡 实际应用](#💡 实际应用)

PyTorch masked_fill_() vs. masked_fill() 详解

在 PyTorch 中,masked_fill_()masked_fill() 主要用于 根据掩码(mask)填充张量(tensor)中的特定元素 ,但它们的关键区别在于 是否修改原张量(in-place 操作)


1️⃣ masked_fill()masked_fill_() 的作用

  • masked_fill(mask, value)
    • 不会修改原张量,而是返回一个新的张量。
  • masked_fill_(mask, value)
    • 会直接修改原张量(in-place 操作),节省内存。

两者的作用

  • mask 是一个 布尔张量True 代表需要填充的元素)。
  • value 是要填充的数值。

2️⃣ masked_fill() vs. masked_fill_() 示例

python 复制代码
import torch

# 创建一个 3×3 的张量
tensor = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

# 创建一个掩码:True 代表要被替换的元素
mask = torch.tensor([
    [False, True, False],
    [True, False, False],
    [False, False, True]
])

print("原张量 tensor:\n", tensor)

# 使用 masked_fill()(不会修改原张量)
new_tensor = tensor.masked_fill(mask, -1)
print("\n新张量 new_tensor(使用 masked_fill()):\n", new_tensor)
print("\n原张量 tensor(未修改):\n", tensor)

# 使用 masked_fill_()(会修改原张量)
tensor.masked_fill_(mask, -1)
print("\n原张量 tensor(被修改):\n", tensor)

3️⃣ 输出结果

复制代码
原张量 tensor:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

新张量 new_tensor(使用 masked_fill()):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

原张量 tensor(未修改):
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

原张量 tensor(被修改):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

4️⃣ masked_fill() vs. masked_fill_() 区别

函数 是否修改原张量? 返回值
masked_fill(mask, value) ❌ 不修改 返回新的张量
masked_fill_(mask, value) ✅ 直接修改 修改后的原张量

总结

  • 如果你希望创建一个新张量,而不修改原数据 ,用 masked_fill()
  • 如果你希望节省内存并直接修改原张量 ,用 masked_fill_()

5️⃣ masked_fill()masked_fill_() 的实际应用

自然语言处理(NLP)深度学习模型 中,这两个函数经常用于 掩码(masking)操作,例如:

  • 屏蔽填充(Padding Mask) :防止模型处理填充的 PAD 词(如 Transformer)。
  • 屏蔽未来信息(Future Mask):用于自回归模型(如 GPT),确保预测不会使用未来的信息。

6️⃣ masked_fill() 在 Transformer 自注意力中的应用

python 复制代码
import torch

# 假设有一个 4×4 的注意力得分矩阵
attn_scores = torch.tensor([
    [0.5, 0.7, 0.8, 0.9],
    [0.6, 0.5, 0.4, 0.8],
    [0.2, 0.4, 0.5, 0.7],
    [0.3, 0.5, 0.6, 0.8]
])

# 创建一个掩码(模拟未来时间步的屏蔽)
mask = torch.tensor([
    [False, False, False, True],
    [False, False, True, True],
    [False, True, True, True],
    [True, True, True, True]
])

# 用 -inf 屏蔽掩码位置
masked_scores = attn_scores.masked_fill(mask, float('-inf'))
print("\n注意力得分(masked_fill()):\n", masked_scores)

示例输出:

复制代码
注意力得分(masked_fill()):
tensor([[ 0.5000,  0.7000,  0.8000,    -inf],
        [ 0.6000,  0.5000,    -inf,    -inf],
        [ 0.2000,    -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf]])

📌 解释

  • False 的位置保留原始数值。
  • True 的位置填充 -inf,在 softmax 计算时会被归零,不影响其他数值。

7️⃣ masked_fill_() 在梯度计算中的应用

在 PyTorch 训练过程中,如果你想直接修改梯度计算中的变量 ,可以使用 masked_fill_() 进行 in-place 操作

python 复制代码
import torch

# 创建一个需要计算梯度的张量
x = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)

# 创建掩码
mask = torch.tensor([False, True, False, True])

# 直接修改 x
x.masked_fill_(mask, 0.0)

print("\n被修改后的 x(masked_fill_()):\n", x)

示例输出:

复制代码
被修改后的 x(masked_fill_()):
tensor([0.1000, 0.0000, 0.3000, 0.0000], requires_grad=True)

📌 解释

  • 通过 masked_fill_() 直接在计算图中修改 x,避免创建新张量。

8️⃣ 总结

  • masked_fill()masked_fill_() 都用于按掩码填充张量中的特定元素
  • 主要区别
    • masked_fill() 不修改原张量,返回新的张量。
    • masked_fill_() 直接修改原张量(in-place 操作)。
  • 适用场景
    • masked_fill() :适用于需要 保持原张量不变 的情况,如 Transformer 掩码处理
    • masked_fill_() :适用于需要 节省内存直接修改张量 的情况,如 梯度计算

💡 实际应用

场景 推荐使用
创建新张量,不修改原数据 masked_fill()
直接修改原数据,减少内存占用 masked_fill_()
Transformer 自注意力掩码 masked_fill()
梯度计算,避免额外的计算图创建 masked_fill_()

🚀 合理使用 masked_fill()masked_fill_(),可以优化你的 PyTorch 代码,提高计算效率! 🎯

相关推荐
liufangshun1 小时前
【DeepSeekR1】怎样清除mssql的日志文件?
数据库·人工智能·sqlserver
深圳市快瞳科技有限公司1 小时前
AI鸟类识别技术革新生态监测:快瞳科技如何用“智慧之眼”守护自然?
人工智能·科技
ModelWhale1 小时前
和鲸科技受邀赴中国气象局气象干部培训学院湖南分院开展 DeepSeek 趋势下的人工智能技术应用专题培训
人工智能·科技
Fansv5871 小时前
深度学习框架PyTorch——从入门到精通(3)数据集和数据加载器
人工智能·pytorch·深度学习
Sunday_ding4 小时前
NLP 与常见的nlp应用
人工智能·自然语言处理
一ge科研小菜鸡4 小时前
当下主流 AI 模型对比:ChatGPT、DeepSeek、Grok 及其他前沿技术
人工智能
ai产品老杨5 小时前
全流程数字化管理的智慧物流开源了。
前端·javascript·vue.js·人工智能·安全
mzgong5 小时前
图像分割的mask有空洞怎么修补
人工智能·opencv·计算机视觉
背着代码的蜗牛5 小时前
Pycharm远程开发注意事项
ide·python·pycharm·ssh·远程工作