pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill() 详解

pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill()详解

  • [PyTorch `masked_fill_()` vs. `masked_fill()` 详解](#PyTorch masked_fill_() vs. masked_fill() 详解)
    • [1️⃣ `masked_fill()` 和 `masked_fill_()` 的作用](#1️⃣ masked_fill()masked_fill_() 的作用)
    • [2️⃣ `masked_fill()` vs. `masked_fill_()` 示例](#2️⃣ masked_fill() vs. masked_fill_() 示例)
    • [3️⃣ 输出结果](#3️⃣ 输出结果)
    • [4️⃣ `masked_fill()` vs. `masked_fill_()` 区别](#4️⃣ masked_fill() vs. masked_fill_() 区别)
    • [5️⃣ `masked_fill()` 和 `masked_fill_()` 的实际应用](#5️⃣ masked_fill()masked_fill_() 的实际应用)
    • [6️⃣ `masked_fill()` 在 Transformer 自注意力中的应用](#6️⃣ masked_fill() 在 Transformer 自注意力中的应用)
    • [7️⃣ `masked_fill_()` 在梯度计算中的应用](#7️⃣ masked_fill_() 在梯度计算中的应用)
    • [8️⃣ 总结](#8️⃣ 总结)
      • [💡 实际应用](#💡 实际应用)

PyTorch masked_fill_() vs. masked_fill() 详解

在 PyTorch 中,masked_fill_()masked_fill() 主要用于 根据掩码(mask)填充张量(tensor)中的特定元素 ,但它们的关键区别在于 是否修改原张量(in-place 操作)


1️⃣ masked_fill()masked_fill_() 的作用

  • masked_fill(mask, value)
    • 不会修改原张量,而是返回一个新的张量。
  • masked_fill_(mask, value)
    • 会直接修改原张量(in-place 操作),节省内存。

两者的作用

  • mask 是一个 布尔张量True 代表需要填充的元素)。
  • value 是要填充的数值。

2️⃣ masked_fill() vs. masked_fill_() 示例

python 复制代码
import torch

# 创建一个 3×3 的张量
tensor = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

# 创建一个掩码:True 代表要被替换的元素
mask = torch.tensor([
    [False, True, False],
    [True, False, False],
    [False, False, True]
])

print("原张量 tensor:\n", tensor)

# 使用 masked_fill()(不会修改原张量)
new_tensor = tensor.masked_fill(mask, -1)
print("\n新张量 new_tensor(使用 masked_fill()):\n", new_tensor)
print("\n原张量 tensor(未修改):\n", tensor)

# 使用 masked_fill_()(会修改原张量)
tensor.masked_fill_(mask, -1)
print("\n原张量 tensor(被修改):\n", tensor)

3️⃣ 输出结果

复制代码
原张量 tensor:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

新张量 new_tensor(使用 masked_fill()):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

原张量 tensor(未修改):
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

原张量 tensor(被修改):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

4️⃣ masked_fill() vs. masked_fill_() 区别

函数 是否修改原张量? 返回值
masked_fill(mask, value) ❌ 不修改 返回新的张量
masked_fill_(mask, value) ✅ 直接修改 修改后的原张量

总结

  • 如果你希望创建一个新张量,而不修改原数据 ,用 masked_fill()
  • 如果你希望节省内存并直接修改原张量 ,用 masked_fill_()

5️⃣ masked_fill()masked_fill_() 的实际应用

自然语言处理(NLP)深度学习模型 中,这两个函数经常用于 掩码(masking)操作,例如:

  • 屏蔽填充(Padding Mask) :防止模型处理填充的 PAD 词(如 Transformer)。
  • 屏蔽未来信息(Future Mask):用于自回归模型(如 GPT),确保预测不会使用未来的信息。

6️⃣ masked_fill() 在 Transformer 自注意力中的应用

python 复制代码
import torch

# 假设有一个 4×4 的注意力得分矩阵
attn_scores = torch.tensor([
    [0.5, 0.7, 0.8, 0.9],
    [0.6, 0.5, 0.4, 0.8],
    [0.2, 0.4, 0.5, 0.7],
    [0.3, 0.5, 0.6, 0.8]
])

# 创建一个掩码(模拟未来时间步的屏蔽)
mask = torch.tensor([
    [False, False, False, True],
    [False, False, True, True],
    [False, True, True, True],
    [True, True, True, True]
])

# 用 -inf 屏蔽掩码位置
masked_scores = attn_scores.masked_fill(mask, float('-inf'))
print("\n注意力得分(masked_fill()):\n", masked_scores)

示例输出:

复制代码
注意力得分(masked_fill()):
tensor([[ 0.5000,  0.7000,  0.8000,    -inf],
        [ 0.6000,  0.5000,    -inf,    -inf],
        [ 0.2000,    -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf]])

📌 解释

  • False 的位置保留原始数值。
  • True 的位置填充 -inf,在 softmax 计算时会被归零,不影响其他数值。

7️⃣ masked_fill_() 在梯度计算中的应用

在 PyTorch 训练过程中,如果你想直接修改梯度计算中的变量 ,可以使用 masked_fill_() 进行 in-place 操作

python 复制代码
import torch

# 创建一个需要计算梯度的张量
x = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)

# 创建掩码
mask = torch.tensor([False, True, False, True])

# 直接修改 x
x.masked_fill_(mask, 0.0)

print("\n被修改后的 x(masked_fill_()):\n", x)

示例输出:

复制代码
被修改后的 x(masked_fill_()):
tensor([0.1000, 0.0000, 0.3000, 0.0000], requires_grad=True)

📌 解释

  • 通过 masked_fill_() 直接在计算图中修改 x,避免创建新张量。

8️⃣ 总结

  • masked_fill()masked_fill_() 都用于按掩码填充张量中的特定元素
  • 主要区别
    • masked_fill() 不修改原张量,返回新的张量。
    • masked_fill_() 直接修改原张量(in-place 操作)。
  • 适用场景
    • masked_fill() :适用于需要 保持原张量不变 的情况,如 Transformer 掩码处理
    • masked_fill_() :适用于需要 节省内存直接修改张量 的情况,如 梯度计算

💡 实际应用

场景 推荐使用
创建新张量,不修改原数据 masked_fill()
直接修改原数据,减少内存占用 masked_fill_()
Transformer 自注意力掩码 masked_fill()
梯度计算,避免额外的计算图创建 masked_fill_()

🚀 合理使用 masked_fill()masked_fill_(),可以优化你的 PyTorch 代码,提高计算效率! 🎯

相关推荐
大数据魔法师12 分钟前
豆瓣图书数据采集与可视化分析
python·数据分析·数据可视化
批量小王子12 分钟前
第1个小脚本:英语单语按字母个数进行升序排序
python
AmazingKO32 分钟前
制作像素风《饥荒》类游戏的整体蓝图和流程
人工智能·python·游戏·docker·visual studio code·竹相左边
CV-杨帆34 分钟前
trl的安装与单GPU多GPU测试
人工智能
_一条咸鱼_1 小时前
AI 大模型的 Prompt Engineering 原理
人工智能·深度学习·面试
趣谈AI1 小时前
使用Trae编辑器开发Python Api (FastApi 框架)
python·编辑器·fastapi
carpell1 小时前
二叉树实战篇2
python·二叉树·数据结构与算法
python_chai1 小时前
Python网络编程从入门到精通:Socket核心技术+TCP/UDP实战详解
网络·python·tcp/ip·udp·socket
一只名叫Me的猫1 小时前
conda 创建、激活、退出、删除环境命令
python·conda
huang_xiaoen1 小时前
试一下阿里云新出的mcp服务
人工智能·阿里云·ai·云计算·mcp