PyTorch 的 F.scaled_dot_product_attention 返回Nan


"为什么 PyTorch 的 scaled_dot_product_attention 会输出 NaN?如何正确构造 Attention Mask"


引言:看似正常的 mask,为什么会引发 NaN?

在使用 F.scaled_dot_product_attention 构建跨模态或多源注意力时,我们常通过 attention_mask 控制每个 query 位置能看到哪些 key。但如果不小心构造出某些 query 对所有 key 都不可见的情况,就会在 softmax 中触发 NaN,进而让模型 loss 崩溃。

这个问题隐蔽却常见,且 PyTorch 不会自动容错,需要我们显式处理。


问题复现:全 -inf 行将导致 NaN

在 PyTorch 的 scaled attention 中:

python 复制代码
output = scaled_dot_product_attention(query, key, value, attn_mask)

其中 attn_maskadditive mask,即:

  • 0.0: 表示该位置可见;
  • -inf: 表示该位置被屏蔽,不可 attend。

当某个 query 行的 mask 全为 -inf 时,softmax 输入类似于:

python 复制代码
softmax([-inf, -inf, ..., -inf]) → [NaN, NaN, ..., NaN]

这将污染整个计算图,最终导致 loss 为 NaN。


产生这种情况的常见原因

这种情况经常发生在任务中存在大量 query(例如图像 patch、token、时间步)本身就不应该 attend 到任何 key,例如背景区域或 padding 区域。

因此,虽然逻辑合理,但仍然在数学上不合法


解决方案:fallback 解锁最后一个 key

为避免 NaN,可在转换 bool mask → float mask 时引入一个 fallback:

python 复制代码
# attention_mask: [B, Q, K],bool 类型,True 表示"可以 attend"
attention_mask_float = torch.full_like(attention_mask, float('-inf'), dtype=query.dtype)
attention_mask_float.masked_fill_(attention_mask, 0.0)

# fallback:避免某些 query 全为 -inf
all_inf_rows = (attention_mask_float == float('-inf')).all(dim=-1, keepdim=True)  # [B, Q, 1]
if all_inf_rows.any():
    last_key_idx = attention_mask_float.size(-1) - 1
    fix_mask = torch.arange(attention_mask_float.size(-1), device=attention_mask.device) == last_key_idx
    fix_mask = fix_mask.view(1, 1, -1)  # reshape for broadcast
    attention_mask_float = attention_mask_float.masked_fill(all_inf_rows & fix_mask, 0.0)

这样即便某个 query 原本完全不可见,也能保证 softmax 至少有一个有效分布。


可视化建议

可以使用 matplotlib.imshow 直接可视化 [Q, K] 的 mask 分布:

python 复制代码
# 黑色:可见(0.0),白色:被 mask(-inf)
vis_mask = (attn_mask == 0.0).astype(np.uint8)
plt.imshow(vis_mask, cmap='Greys', aspect='auto')

可视化能帮助你快速定位全白 query 行,即潜在 NaN 风险点。


总结

条目 建议
是否允许 query 全被屏蔽 语义上允许,数学上不合法(需处理)
PyTorch 是否兜底 否,需用户自己容错
是否应解锁一个 dummy key 是,最安全的 fallback 机制
可否通过可视化排查 是,黑白图可快速识别空行

相关推荐
hrrrrb18 小时前
【Python】文件处理(二)
开发语言·python
catchadmin19 小时前
PHP 快速集成 ChatGPT 用 AI 让你的应用更聪明
人工智能·后端·chatgpt·php
万粉变现经纪人21 小时前
如何解决 pip install 安装报错 ModuleNotFoundError: No module named ‘tokenizers’ 问题
python·selenium·测试工具·scrapy·beautifulsoup·fastapi·pip
编程武士1 天前
从50ms到30ms:YOLOv10部署中图像预处理的性能优化实践
人工智能·python·yolo·性能优化
我的xiaodoujiao1 天前
Windows系统Web UI自动化测试学习系列2--环境搭建--Python-PyCharm-Selenium
开发语言·python·测试工具
max5006001 天前
基于Meta Llama的二语习得学习者行为预测计算模型
人工智能·算法·机器学习·分类·数据挖掘·llama
月疯1 天前
OPENCV摄像头读取视频
人工智能·opencv·音视频
极客天成ScaleFlash1 天前
极客天成让统一存储从云原生‘进化’到 AI 原生: 不是版本升级,而是基因重组
人工智能·云原生
王哥儿聊AI1 天前
Lynx:新一代个性化视频生成模型,单图即可生成视频,重新定义身份一致性与视觉质量
人工智能·算法·安全·机器学习·音视频·软件工程
_pinnacle_1 天前
打开神经网络的黑箱(三) 卷积神经网络(CNN)的模型逻辑
人工智能·神经网络·cnn·黑箱·卷积网络