nn.TransformerEncoder的输出为NaN值的原因及解决方法

问题描述:

当使用nn.TransformerEncoder时,即如下情况:

复制代码
实例化一个TransformerEncoder 
 self.encoder_layer = nn.TransformerEncoderLayer(d_model=encoder_in_dim, nhead=encoder_head,
                                                           dim_feedforward=encoder_ffnn_dim,
                                                           batch_first=batch_first)
 self.model = nn.TransformerEncoder(self.pre_encoder_layer, num_layers=pre_encoder_layer_num)
调用:
transformer_features =  self.model(embeddings, src_key_padding_mask=src_padding_mask)

transformer_features的值为NaN

原因在于src_padding_mask的传入出现均为0/False的情况!即attention---mask出现了全1/True行

由于我们在使用MultiheadAttention做self-attention时因为batch内序列长度不一致,难免需要使用mask

以pytorch自带的torch.nn.TransformerEncoder方法为例,其forward函数如下

复制代码
forward(src, mask=None, src_key_padding_mask=None)

这里的mask会送到torch.nn.TransformerEncoderLayer的forward函数:

python 复制代码
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]

之后送到MultiheadAttention 的forward函数的attn_mask参数,而这里做的是一个self attention。

此时若是attn_mask出现一整行都是True的情况,通过如下源码中的实现mask的方法可以看到:

python 复制代码
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

把权重矩阵中需要mask的位置置为负无穷,然后再按行做softmax,问题就在这里,把一个元素全是是负无穷的tensor送给softmax,就会得到一个全是NaN的tensor。然后NaN和任何数运算都是NaN,NaN会传染,再经过一轮self attention,输出的tensor就全成NaN了。

解决方法:避免attention mask中存在全1/True的行

相关推荐
青瓷程序设计7 小时前
植物识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
AI即插即用7 小时前
即插即用系列 | CVPR 2025 WPFormer:用于表面缺陷检测的查询式Transformer
人工智能·深度学习·yolo·目标检测·cnn·视觉检测·transformer
T0uken8 小时前
【Python】UV:境内的深度学习环境搭建
人工智能·深度学习·uv
AI即插即用8 小时前
即插即用系列 | 2025 MambaNeXt-YOLO 炸裂登场!YOLO 激吻 Mamba,打造实时检测新霸主
人工智能·pytorch·深度学习·yolo·目标检测·计算机视觉·视觉检测
studytosky12 小时前
深度学习理论与实战:MNIST 手写数字分类实战
人工智能·pytorch·python·深度学习·机器学习·分类·matplotlib
哥布林学者12 小时前
吴恩达深度学习课程三: 结构化机器学习项目 第一周:机器学习策略(二)数据集设置
深度学习·ai
【建模先锋】14 小时前
精品数据分享 | 锂电池数据集(四)PINN+锂离子电池退化稳定性建模和预测
深度学习·预测模型·pinn·锂电池剩余寿命预测·锂电池数据集·剩余寿命
九年义务漏网鲨鱼14 小时前
【大模型学习】现代大模型架构(二):旋转位置编码和SwiGLU
深度学习·学习·大模型·智能体
CoovallyAIHub14 小时前
破局红外小目标检测:异常感知Anomaly-Aware YOLO以“俭”驭“繁”
深度学习·算法·计算机视觉
云雾J视界14 小时前
AI芯片设计实战:用Verilog高级综合技术优化神经网络加速器功耗与性能
深度学习·神经网络·verilog·nvidia·ai芯片·卷积加速器