PyTorch 的 F.scaled_dot_product_attention 返回Nan


"为什么 PyTorch 的 scaled_dot_product_attention 会输出 NaN?如何正确构造 Attention Mask"


引言:看似正常的 mask,为什么会引发 NaN?

在使用 F.scaled_dot_product_attention 构建跨模态或多源注意力时,我们常通过 attention_mask 控制每个 query 位置能看到哪些 key。但如果不小心构造出某些 query 对所有 key 都不可见的情况,就会在 softmax 中触发 NaN,进而让模型 loss 崩溃。

这个问题隐蔽却常见,且 PyTorch 不会自动容错,需要我们显式处理。


问题复现:全 -inf 行将导致 NaN

在 PyTorch 的 scaled attention 中:

python 复制代码
output = scaled_dot_product_attention(query, key, value, attn_mask)

其中 attn_maskadditive mask,即:

  • 0.0: 表示该位置可见;
  • -inf: 表示该位置被屏蔽,不可 attend。

当某个 query 行的 mask 全为 -inf 时,softmax 输入类似于:

python 复制代码
softmax([-inf, -inf, ..., -inf]) → [NaN, NaN, ..., NaN]

这将污染整个计算图,最终导致 loss 为 NaN。


产生这种情况的常见原因

这种情况经常发生在任务中存在大量 query(例如图像 patch、token、时间步)本身就不应该 attend 到任何 key,例如背景区域或 padding 区域。

因此,虽然逻辑合理,但仍然在数学上不合法


解决方案:fallback 解锁最后一个 key

为避免 NaN,可在转换 bool mask → float mask 时引入一个 fallback:

python 复制代码
# attention_mask: [B, Q, K],bool 类型,True 表示"可以 attend"
attention_mask_float = torch.full_like(attention_mask, float('-inf'), dtype=query.dtype)
attention_mask_float.masked_fill_(attention_mask, 0.0)

# fallback:避免某些 query 全为 -inf
all_inf_rows = (attention_mask_float == float('-inf')).all(dim=-1, keepdim=True)  # [B, Q, 1]
if all_inf_rows.any():
    last_key_idx = attention_mask_float.size(-1) - 1
    fix_mask = torch.arange(attention_mask_float.size(-1), device=attention_mask.device) == last_key_idx
    fix_mask = fix_mask.view(1, 1, -1)  # reshape for broadcast
    attention_mask_float = attention_mask_float.masked_fill(all_inf_rows & fix_mask, 0.0)

这样即便某个 query 原本完全不可见,也能保证 softmax 至少有一个有效分布。


可视化建议

可以使用 matplotlib.imshow 直接可视化 [Q, K] 的 mask 分布:

python 复制代码
# 黑色:可见(0.0),白色:被 mask(-inf)
vis_mask = (attn_mask == 0.0).astype(np.uint8)
plt.imshow(vis_mask, cmap='Greys', aspect='auto')

可视化能帮助你快速定位全白 query 行,即潜在 NaN 风险点。


总结

条目 建议
是否允许 query 全被屏蔽 语义上允许,数学上不合法(需处理)
PyTorch 是否兜底 否,需用户自己容错
是否应解锁一个 dummy key 是,最安全的 fallback 机制
可否通过可视化排查 是,黑白图可快速识别空行

相关推荐
老胖闲聊1 小时前
Python Copilot【代码辅助工具】 简介
开发语言·python·copilot
Blossom.1181 小时前
使用Python和Scikit-Learn实现机器学习模型调优
开发语言·人工智能·python·深度学习·目标检测·机器学习·scikit-learn
曹勖之2 小时前
基于ROS2,撰写python脚本,根据给定的舵-桨动力学模型实现动力学更新
开发语言·python·机器人·ros2
scdifsn2 小时前
动手学深度学习12.7. 参数服务器-笔记&练习(PyTorch)
pytorch·笔记·深度学习·分布式计算·数据并行·参数服务器
DFminer2 小时前
【LLM】fast-api 流式生成测试
人工智能·机器人
lyaihao3 小时前
使用python实现奔跑的线条效果
python·绘图
郄堃Deep Traffic3 小时前
机器学习+城市规划第十四期:利用半参数地理加权回归来实现区域带宽不同的规划任务
人工智能·机器学习·回归·城市规划
ai大师3 小时前
(附代码及图示)Multi-Query 多查询策略详解
python·langchain·中转api·apikey·中转apikey·免费apikey·claude4
海盗儿3 小时前
Attention Is All You Need (Transformer) 以及Transformer pytorch实现
pytorch·深度学习·transformer
GIS小天3 小时前
AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年6月7日第101弹
人工智能·算法·机器学习·彩票