PyTorch 的 F.scaled_dot_product_attention 返回Nan


"为什么 PyTorch 的 scaled_dot_product_attention 会输出 NaN?如何正确构造 Attention Mask"


引言:看似正常的 mask,为什么会引发 NaN?

在使用 F.scaled_dot_product_attention 构建跨模态或多源注意力时,我们常通过 attention_mask 控制每个 query 位置能看到哪些 key。但如果不小心构造出某些 query 对所有 key 都不可见的情况,就会在 softmax 中触发 NaN,进而让模型 loss 崩溃。

这个问题隐蔽却常见,且 PyTorch 不会自动容错,需要我们显式处理。


问题复现:全 -inf 行将导致 NaN

在 PyTorch 的 scaled attention 中:

python 复制代码
output = scaled_dot_product_attention(query, key, value, attn_mask)

其中 attn_maskadditive mask,即:

  • 0.0: 表示该位置可见;
  • -inf: 表示该位置被屏蔽,不可 attend。

当某个 query 行的 mask 全为 -inf 时,softmax 输入类似于:

python 复制代码
softmax([-inf, -inf, ..., -inf]) → [NaN, NaN, ..., NaN]

这将污染整个计算图,最终导致 loss 为 NaN。


产生这种情况的常见原因

这种情况经常发生在任务中存在大量 query(例如图像 patch、token、时间步)本身就不应该 attend 到任何 key,例如背景区域或 padding 区域。

因此,虽然逻辑合理,但仍然在数学上不合法


解决方案:fallback 解锁最后一个 key

为避免 NaN,可在转换 bool mask → float mask 时引入一个 fallback:

python 复制代码
# attention_mask: [B, Q, K],bool 类型,True 表示"可以 attend"
attention_mask_float = torch.full_like(attention_mask, float('-inf'), dtype=query.dtype)
attention_mask_float.masked_fill_(attention_mask, 0.0)

# fallback:避免某些 query 全为 -inf
all_inf_rows = (attention_mask_float == float('-inf')).all(dim=-1, keepdim=True)  # [B, Q, 1]
if all_inf_rows.any():
    last_key_idx = attention_mask_float.size(-1) - 1
    fix_mask = torch.arange(attention_mask_float.size(-1), device=attention_mask.device) == last_key_idx
    fix_mask = fix_mask.view(1, 1, -1)  # reshape for broadcast
    attention_mask_float = attention_mask_float.masked_fill(all_inf_rows & fix_mask, 0.0)

这样即便某个 query 原本完全不可见,也能保证 softmax 至少有一个有效分布。


可视化建议

可以使用 matplotlib.imshow 直接可视化 [Q, K] 的 mask 分布:

python 复制代码
# 黑色:可见(0.0),白色:被 mask(-inf)
vis_mask = (attn_mask == 0.0).astype(np.uint8)
plt.imshow(vis_mask, cmap='Greys', aspect='auto')

可视化能帮助你快速定位全白 query 行,即潜在 NaN 风险点。


总结

条目 建议
是否允许 query 全被屏蔽 语义上允许,数学上不合法(需处理)
PyTorch 是否兜底 否,需用户自己容错
是否应解锁一个 dummy key 是,最安全的 fallback 机制
可否通过可视化排查 是,黑白图可快速识别空行

相关推荐
Python×CATIA工业智造12 小时前
Python函数包装技术详解:从基础装饰器到高级应用
python·pycharm
落羽的落羽12 小时前
【Linux系统】从零掌握make与Makefile:高效自动化构建项目的工具
linux·服务器·开发语言·c++·人工智能·机器学习·1024程序员节
应用市场12 小时前
VSCode + AI Agent实现直接编译调试:告别Visual Studio的原理与实践
人工智能·vscode·visual studio
GIS数据转换器12 小时前
城市基础设施安全运行监管平台
大数据·运维·人工智能·物联网·安全·无人机·1024程序员节
快秃头的码农13 小时前
LazyLLM,(万象应用开发平台 AppStudio)商汤大装置
python
遇雪长安13 小时前
深度学习YOLO实战:4、模型的三要素:任务、类别与规模
人工智能·深度学习·yolo
搞科研的小刘选手13 小时前
【云计算专题会议】第二届云计算与大数据国际学术会议(ICCBD 2025)
大数据·人工智能·物联网·5g·云计算·6g·智能通信
电商软件开发 小银13 小时前
微信生态新机遇:视频号推客模式助力商家突围
大数据·人工智能·twitter·系统开发·实体店转型·数字化经济·视频号推客模式
综合热讯13 小时前
湖南粒界教育科技有限公司:专注影视职业教育,AI辅助教学提升学习实效
人工智能·科技·学习
深兰科技13 小时前
深兰科技法务大模型亮相,推动律所文书处理智能化
人工智能·scrapy·beautifulsoup·scikit-learn·pyqt·fastapi·深兰科技