在使用tensorboard可视化,经常会将模型通过save_graph方法保存下来,方便查看结构。在使用save_graph经常会遇到错误(至少我经常遇到),对于我,最常见的一个错误为
Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
Graph diff:
.....
First diverging operator:
Node diff:
...
我是在模型中用了 pytorch 自带的 nn.MultiheadAttention 发生了这个错误,一个简单的解决方法是将原本的
python
self.attn = nn.MultiheadAttention(128, 8, 0.1, batch_first=True)
中的 batch_first = True 删去,修改之后为
python
self.attn = nn.MultiheadAttention(128, 8, 0.1)
注意删除 batch_first = True 后, 输入格式需要改为 (seq, batch, feature)。