在 PyTorch 中,为何定义一个继承自 nn.Module 的自定义类并实现 forward 方法后,直接调用模型实例时,便会自动调用其 forward 方法?例如使用 output = model(x) 这种形式。
因为自定义的神经网络类所继承的 nn.Module
类对 __call__
方法进行了重写。在 nn.Module
类内部实现的 __call__
方法里,会对用户定义的 forward
方法进行调用。因此,当我们像调用函数一样调用继承自 nn.Module
的自定义神经网络类的实例时,实际上会触发 __call__
方法,进而执行 forward
方法完成前向传播过程。
- 在 Python 中,
__call__
方法允许一个类的实例像函数一样被调用。 - 当你调用一个对象时,Python 会自动查找并调用该对象的
__call__
方法。 - 在
nn.Module
类中,__call__
方法的实现会做一些额外的操作,比如钩子(hook)的处理、梯度计算的设置等,然后调用用户自定义的forward
方法。