详细分析Pytorch中的masked_fill基本知识(附Demo)

目录

  • [1. 基本知识](#1. 基本知识)
  • [2. Demo](#2. Demo)

1. 基本知识

基本的原理知识如下:

  1. 输入张量和掩码

    masked_fill 接受两个主要参数:一个输入张量和一个布尔掩码

    掩码的形状必须与输入张量相同,True 表示需要填充的位置,False 表示保持原值

  2. 掩码操作

    在执行 masked_fill 操作时,函数会检查掩码中每个元素的值

    如果掩码对应的位置为 True,则在输出张量中填充指定的值;

    如果为 False,则保留输入张量中对应位置的值

  3. 输出结果

    最终生成的新张量包含了在掩码位置上被替换的值,其余位置保持原样


在代码逻辑上

  1. 创建掩码
    mask 是一个布尔张量,标识了哪些位置需要填充:
python 复制代码
[[False, True, False],
 [True, False, True],
 [False, False, True]]
  1. 执行 masked_fill
    当调用 tensor.masked_fill(mask, -1) 时,PyTorch 会遍历掩码中的每个元素:对于 mask 中的每个 True 值,tensor 在对应位置的值会被替换为 -1,对于 False 值,保持原值不变

masked_fill 操作是基于 C/C++ 的实现,因此在处理大规模数据时性能较高。常用于深度学习模型中的数据预处理,比如在填充序列、处理缺失值或标记特定条件的数据时

2. Demo

Demo 1: 基本用法

python 复制代码
import torch

# 创建一个 3x3 的张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# 创建一个掩码,标记要填充的位置
mask = torch.tensor([[False, True, False],
                     [True, False, True],
                     [False, False, True]])

# 使用 masked_fill 填充掩码位置为 -1
result = tensor.masked_fill(mask, -1)

print("原始张量:")
print(tensor)
print("\n填充后的张量:")
print(result)

截图如下:

Demo 2: 与条件结合使用

python 复制代码
import torch

# 创建一个随机张量
tensor = torch.randn(3, 3)

# 创建掩码:标记负值的位置
mask = tensor < 0

# 将负值位置填充为 0
result = tensor.masked_fill(mask, 0)

print("原始张量:")
print(tensor)
print("\n填充后的张量 (负值填充为 0):")
print(result)

截图如下:

Demo 3: 结合计算

python 复制代码
import torch

# 创建一个张量
tensor = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])

# 创建掩码:标记大于 50 的位置
mask = tensor > 50

# 用 999 填充大于 50 的位置
result = tensor.masked_fill(mask, 999)

print("原始张量:")
print(tensor)
print("\n填充后的张量 (大于 50 的位置填充为 999):")
print(result)

截图如下:

相关推荐
Source.Liu几秒前
【Python基础】 19 Rust 与 Python if 语句对比笔记
笔记·python·rust
一颗20217 分钟前
深度解读:PSPNet(Pyramid Scene Parsing Network) — 用金字塔池化把“场景理解”装进分割网络
人工智能·深度学习·计算机视觉
奋进的电子工程师10 分钟前
汽车软件研发智能化:AI在CI/CD中的实践
人工智能·ci/cd·汽车·软件工程·软件构建·代码规范
摘星编程14 分钟前
Cursor Pair Programming:在前端项目里用 AI 快速迭代 UI 组件
前端·人工智能·ui·typescript·前端开发·cursorai
工业互联网专业23 分钟前
基于Spark的新冠肺炎疫情实时监控系统_django+spider
python·spark·django·vue·毕业设计·源码·课程设计
ZHOU_WUYI25 分钟前
门控MLP(Qwen3MLP)与稀疏混合专家(Qwen3MoeSparseMoeBlock)模块解析
人工智能·llm
Yh87020327 分钟前
2025年经济学专业女生必考证书指南:打造差异化竞争力
python
黄焖鸡能干四碗34 分钟前
信息系统安全保护措施文件方案
大数据·开发语言·人工智能·web安全·制造
BYSJMG35 分钟前
大数据毕业设计推荐:基于Spark的零售时尚精品店销售数据分析系统【Hadoop+python+spark】
大数据·hadoop·python·spark·django·课程设计
hallo12836 分钟前
学习机器学习能看哪些书籍
人工智能·深度学习·机器学习