引言: 深度学习中的显式层
现代深度学习方法的核心是层的概念。深度学习模型传统上通过堆叠许多这样的层来构建,以创建为解决特定任务而设计的架构。例如,卷积网络由卷积层组成,通常后面跟着逐元素的非线性激活如 ReLU,以及归一化或 dropout 等额外操作,并且可能通过多种不同方式连接在一起,形成诸如残差层之类的结构。同样,Transformer 网络等架构由所谓的自注意力层和全连接层的组合构成,同样以产生模型最终形式的方式堆叠在一起。
一个常见的共同特征(目前已经非常标准以至于从业者常常注意不到)是,现代深度学习中的绝大多数层都是显式 定义的。也就是说,它们由从输入到输出的精确操作序列指定。以缩放自注意力层为例 [Vaswani et al.](https://arxiv.org/abs/1706.03762)。该层是从三个矩阵 K , Q , V ∈ R T × n K,Q,V \in \mathbb{R}^{T \times n} K,Q,V∈RT×n 到输出 Z ∈ R T × n Z \in \mathbb{R}^{T \times n} Z∈RT×n 的映射,由以下操作定义
Z = S e l f A t t e n t i o n ( K , Q , V ) ≡ s o f t m a x ( K Q T n ) V Z = \mathrm{SelfAttention}(K,Q,V) \equiv \mathrm{softmax}\left (\frac{K Q^T}{\sqrt{n}} \right) V Z=SelfAttention(K,Q,V)≡softmax(n KQT)V
(这是一个简化版本的自注意力,仅用于说明,没有掩码或多头结构)。我们可以将这个层写成一个简单的 Python 函数(同样,仅用于说明,并不是实际编写 softmax 操作的方式,更不用说你可能更想使用自动微分库而不是纯 numpy 来编写这些函数)。
python
import numpy as np
def self_attention(K,Q,V):
A = np.exp(K @ Q.T) / np.sqrt(K.shape[1])
return (A / np.sum(A,1)) @ V
K, Q, V = np.random.randn(3, 5, 4)
print(self_attention(K, Q, V))
[[ 0.81840935 -0.60257186 -0.47946598 -0.176331 ]
[-0.13068278 0.13096843 -0.60390036 1.69345066]
[ 0.99015942 -0.8193751 -0.35065811 -1.26840087]
[ 4.97929178 -3.82622622 -1.54894005 -5.19986435]
[ 0.98142343 -0.44840791 -0.26860421 -0.53614219]]
自然地,随着我们向层本身添加更多功能,在自动微分库中实现它们等,事情开始变得稍微复杂一些,但典型层的这种显式形式贯穿始终:层在很大程度上像典型的计算机程序一样构建,我们直接编写代码来生成作为输入函数的层输出。这可能已经如此根深蒂固,以至于很难想象存在一种完全不同的方式来定义层,即通过隐式层。
隐式层
隐式层的核心(正如我们在本文档中所使用的术语)在于,我们不指定如何从输入计算层的输出,而是指定我们希望层输出满足的条件 。也就是说,如果我们将显式层(输入 x ∈ X x \in \mathcal{X} x∈X,输出 z ∈ Z z \in \mathcal{Z} z∈Z)写为某个显式函数 f : X → Z f : \mathcal{X} \rightarrow \mathcal{Z} f:X→Z 的应用
z = f ( x ) z = f(x) z=f(x)
那么隐式层则通过一个函数 g : X × Z → R n g : \mathcal{X} \times \mathcal{Z} \rightarrow \mathbb{R}^n g:X×Z→Rn 来定义,该函数是 x x x 和 z z z 的联合函数,其中层的输出 z z z 需要满足某个约束 ,例如求方程的根:
求 z 使得 g ( x , z ) = 0. \text{求 z 使得 } g(x,z) = 0. 求 z 使得 g(x,z)=0.
这里的符号可能表明 g ( x , z ) g(x,z) g(x,z) 是一个简单的代数方程,但在实践中,同样的形式可以涵盖代数方程和不动点(引出循环反向传播模型或深度平衡模型)、微分方程(引出神经常微分方程),或优化问题的最优性条件(引出可微优化方法)。
在进入具体示例之前,我们应该强调,最初采用这种隐式公式可能看起来像是一个微不足道的点。毕竟,为了实际实现这样的层,我们需要指定某种实际计算 方程 g g g 的根的方法。但正如我们很快将看到的,考虑层的隐式形式有许多实际优势。
最基本的是,隐式形式的层将层的求解过程与层的定义本身分离开来 。这种模块化程度在许多领域已被证明极其有用。例如,微分方程求解器试图找到常微分方程的数值解,可以实现各种自适应步长、对所谓"刚性"方程的校正等,所有这些都旨在为微分方程找到低误差的解。再比如,优化求解器通常涉及非常复杂的启发式方法来求解某些类型的问题,但它们的共同目标是找到优化任务的最小目标解。事实上,由于我们很少找到代数或微分方程等的精确解,不同的求解方法可以根据它们满足层试图满足的条件的程度来客观地相互评估。
这种将层的目标与其求解方法分离本身就是够理想的,但隐式层的第二个优势特别出现在深度学习和自动微分的背景下。机器学习中自动微分(AD)的传统方法是在自动微分框架(如 PyTorch、Tensorflow 或 JAX)中实现所有层,这立即让我们可以将这些层包含在需要梯度来拟合数据的深度模型中。然而,直接在 AD 库中实现求解过程,特别是涉及迭代更新的求解过程(如标准微分方程或优化求解器),意味着我们需要存储完整求解过程的计算图,以及在此求解过程中创建的临时迭代值。这需要在内存中存储大量信息,这通常可能成为训练大型深度学习模型时的瓶颈。幸运的是,如下文所述,并在本教程中多次强调,隐式层具有一个显著优势,即我们可以使用隐函数定理直接在这些方程的解点处计算梯度,而无需沿途存储任何中间变量。这极大地改善了内存消耗,通常也提高了这些方法的数值精度,为深度学习中的隐式模型提供了另一个显著的好处。
应用与示例
由于本章的剩余部分将专注于一个极其简单的演示,旨在作为教学说明(而非最新技术的展示),我们想简要强调一下已经使用隐式层解决的广泛应用。以下只是少量示例(作者恰好最熟悉的几个实例),但它们希望展示当前隐式层研究所涉及方法的广度。我们将在文本中深入探讨其中一些内容,但大多数情况下,你需要关注该领域当前的研究以跟上这些方法所涉及的所有应用领域。
隐式层已被用于:
- 以可微方式求解任意结构化凸问题(使用
cvxpy库)。 - 求解组合优化问题的平滑松弛,如图割、可满足性等。
- 将微分方程作为深度网络中的层进行积分(其本身就有众多应用,例如整合连续时间观测,或近似传统残差网络的连续版本)。
- 创建用于高效表示平滑密度的架构,用于生成建模等领域。
- 在语言建模任务上实现与最先进 Transformer 模型相当的性能(在相同参数数量下),并在分类和语义分割等任务上与最先进的计算机视觉架构相当。
你的第一个隐式模型:不动点迭代
在深入探讨数学细节和多种不同形式的隐式模型之前,让我们从一个特别简单的例子开始:一个由不动点迭代定义的网络层。如上所述,这种类型的层可以追溯到循环反向传播的一些原始公式,也是我们稍后将讨论的深度平衡模型的基础。
不动点迭代层
尽管稍后我们将采用将该层视为某个方程的根的观点,但为了介绍该层,假设我们有相同维度的输入和输出 x , z ∈ R n x,z \in \mathbb{R}^n x,z∈Rn,并考虑以下计算输出 z z z 作为 x x x 函数的方法
z : = 0 重复直到收敛: z : = tanh ( W z + x ) \begin{aligned} & z := 0 \\ & \text{重复直到收敛:} \\ & \quad z := \tanh(Wz + x) \end{aligned} z:=0重复直到收敛:z:=tanh(Wz+x)
其中网络参数 W ∈ R n × n W \in \mathbb{R}^{n \times n} W∈Rn×n。这是一个不动点迭代 的实例:在特定条件下,该过程将收敛到某个固定输出 z ⋆ z^\star z⋆,当然它具有性质
z ⋆ = tanh ( W z ⋆ + x ) . z^\star = \tanh(W z^\star + x). z⋆=tanh(Wz⋆+x).
我们暂时推迟讨论 为什么 这对于层来说可能是一种特别好的形式,但简要地说,这种类型的层可以被解释为一个简单的循环网络,其中 z z z 是隐藏层,我们重复将网络应用于相同的 输入 x x x。这样的层也可以获得"深度"神经网络的一些好处(因为它涉及重复应用非线性),同时只拥有"单层"的参数 W W W。但我们稍后将讨论这些优势,现在先专注于使用这样的层。
注意,这当然可以写成上述形式的隐式层,即 z ⋆ z^\star z⋆ 是寻根方程的解
求 z 使得 g ( x , z ) = 0 , 其中 g ( x , z ) ≡ z − tanh ( W z + x ) . \text{求 z 使得 } g(x,z) = 0, \quad \text{其中 } g(x,z) \equiv z - \tanh(W z + x). 求 z 使得 g(x,z)=0,其中 g(x,z)≡z−tanh(Wz+x).
注意这个迭代不一定收敛:虽然 tanh \tanh tanh 激活函数会强制 z z z 的值永远不会离开 − 1 , + 1 -1,+1 −1,+1 范围,但取决于 W W W 的值,可能会发生值无限循环而永远达不到不动点的情况。另一方面,如果 W = 0 W=0 W=0,则迭代在一次迭代后达到"不动点" z ⋆ = tanh ( x ) z^\star = \tanh(x) z⋆=tanh(x)。我们这里要说明的是,对于"典型"的 W W W 值(即大多数深度学习库使用的线性层的默认值,加上它们在优化过程中达到的值),这个迭代确实会收敛,我们将在后面讨论不动点存在性和唯一性的问题。
实现不动点迭代层
在像 PyTorch 或 JAX 这样的自动微分库中实现隐式层肯定比传统层需要更多的努力。但实际的核心实现仍然相当直接,并且通过这些工具变得更加容易。
首先,让我们考虑实现这样一个层的最简单方式,它简单地重复不动点迭代直到收敛,全部通过库的正常 autograd 功能实现(即我们只是"展开"不动点计算)。由于这是通过正常的 autograd 机制发生的,每个中间迭代值都必须存储在内存中,反向传播必须以相反的顺序在相同的迭代上类似地进行。现在,我们将创建一个单独的层,实现上述 tanh \tanh tanh 和线性层的组合(加上其他简化技巧,如存储最近的迭代次数和误差),但在后续章节中,我们将使其更加模块化,以便我们可以找到使用相同库实现的通用层的类似不动点。
python
import torch
import torch.nn as nn
class TanhFixedPointLayer(nn.Module):
def __init__(self, out_features, tol = 1e-4, max_iter=50):
super().__init__()
self.linear = nn.Linear(out_features, out_features, bias=False)
self.tol = tol
self.max_iter = max_iter
def forward(self, x):
# 将输出 z 初始化为零
z = torch.zeros_like(x)
self.iterations = 0
# 迭代直到收敛
while self.iterations < self.max_iter:
z_next = torch.tanh(self.linear(z) + x)
self.err = torch.norm(z - z_next)
z = z_next
self.iterations += 1
if self.err < self.tol:
break
return z
我们可以在随机输出上运行这个层,以确认它确实达到了不动点。
python
layer = TanhFixedPointLayer(50)
X = torch.randn(10,50)
Z = layer(X)
print(f"Terminated after {layer.iterations} iterations with error {layer.err}")
Terminated after 14 iterations with error 5.921083356952295e-05
虽然在随机数据上运行该层可能并没有提供那么多信息,但如果我们能看到该层在实际模型中使用,就会更有趣一些。因此,我们将在下面展示一个在 MNIST 数据集上训练的简单模型,使用一个单独的不动点层(在不动点层之前有一个额外的线性输入层,在不动点层之后有一个线性层)。这个模型并不打算打破任何记录,但它提供了一个比单独运行该层更有用的实验基础。
python
# 导入 MNIST 数据集和数据加载器
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
mnist_train = datasets.MNIST(".", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST(".", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
python
# 构建带有不动点层的简单模型
import torch.optim as optim
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(),
nn.Linear(784, 100),
TanhFixedPointLayer(100, max_iter=200),
nn.Linear(100, 10)
).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)
python
# 运行单个 epoch 的通用函数(训练或评估)
def epoch(loader, model, opt=None, monitor=None):
total_loss, total_err, total_monitor = 0.,0.,0.
model.eval() if opt is None else model.train()
for X,y in loader:
X,y = X.to(device), y.to(device)
yp = model(X)
loss = nn.CrossEntropyLoss()(yp,y)
if opt:
opt.zero_grad()
loss.backward()
if sum(torch.sum(torch.isnan(p.grad)) for p in model.parameters()) == 0:
opt.step()
total_err += (yp.max(dim=1)[1] != y).sum().item()
total_loss += loss.item() * X.shape[0]
if monitor is not None:
total_monitor += monitor(model)
return total_err / len(loader.dataset), total_loss / len(loader.dataset), total_monitor / len(loader)
最后让我们训练模型 10 个 epoch。除了训练/测试误差和损失之外,我们还将打印层收敛到不动点所需的平均不动点迭代次数。
python
for i in range(10):
if i == 5:
opt.param_groups[0]["lr"] = 1e-2
train_err, train_loss, train_fpiter = epoch(train_loader, model, opt, lambda x : x[2].iterations)
test_err, test_loss, test_fpiter = epoch(test_loader, model, monitor = lambda x : x[2].iterations)
print(f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, FP Iters: {train_fpiter:.2f} | " +
f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, FP Iters: {test_fpiter:.2f}")
再次强调,这里并没有打破任何记录(单隐藏层网络以更快的执行/训练速度达到了相同的性能),但至少网络使用这个层进行了训练。然而,有几点需要注意。第一点是,为了收敛到 10 − 4 10^{-4} 10−4 精度的不动点,我们最终运行了相当多的不动点迭代次数。如果你查看每个小批量所需的单独迭代次数,你会发现其中一些甚至在 200 步后也没有达到这个容差,而是在较低的精度水平退出(在训练过程中某些时刻,不动点迭代甚至可能变得不稳定,如果没有适当的错误处理,这通常会显著降低模型性能)。这似乎是一个相当大的缺点。我们实际上在实践中运行了一个 50-80 层的网络,但没有看到比标准 MLP 更多的优势(注意,因为我们在每次迭代 z : = tanh ( W z + x ) z := \tanh(Wz + x) z:=tanh(Wz+x) 中重新添加了输入,这与传统相同深度的 MLP 不同,后者会遭受梯度消失/爆炸的问题)。
要真正看到这些层的潜在优势,我们需要引入更多的想法。
替代寻根技术
回想一下,隐式层的一个好处是它们提供了层的计算内容与计算方式之间的分离。在上面的例子中,不动点迭代的目标是找到某个 z z z 使得
z = tanh ( W z + x ) . z = \tanh(W z + x). z=tanh(Wz+x).
实现这一点的一种方法是简单地迭代这个方程,但这绝不是唯一的方法。或者,我们可以使用更快的寻根方法,如牛顿法,来更高效地找到这个解。
牛顿法是一种通用的求根技术。对于某个函数 g : R n → R n g : \mathbb{R}^n \rightarrow \mathbb{R}^n g:Rn→Rn,如果我们希望找到根 g ( z ) = 0 g(z) = 0 g(z)=0,那么牛顿法重复更新
z : = z − ( ∂ g ∂ z ) − 1 g ( z ) z := z - \left ( \frac{\partial g}{\partial z} \right ) ^{-1} g(z) z:=z−(∂z∂g)−1g(z)
其中 ∂ g ∂ z \frac{\partial g}{\partial z} ∂z∂g 表示 g g g 对 z z z 的 Jacobian(实践中通常需要"受保护"的更新,采取更小的步长以确保残差 ∥ g ( z ) ∥ \|g(z)\| ∥g(z)∥ 充分下降,但我们这里不考虑这一点)。虽然我们可以求助于自动微分来计算 Jacobian(在后面的章节中,当我们在不动点迭代中使用更通用的层时,我们需要这样做),但对于我们的 tanh \tanh tanh 加线性层的情况,很容易以闭形式计算 Jacobian。具体来说,我们试图找到方程 g ( x , z ) = 0 g(x,z) = 0 g(x,z)=0 的根(回到本节前面的符号,我们显式地依赖于层输入 x x x),其中
g ( x , z ) = z − tanh ( W z + x ) . g(x,z) = z - \tanh(Wz + x). g(x,z)=z−tanh(Wz+x).
那么我们的 Jacobian 为
∂ g ∂ z = I − d i a g ( tanh ′ ( W z + x ) ) W \frac{\partial g}{\partial z} = I - \mathrm{diag}(\tanh'(Wz + x)) W ∂z∂g=I−diag(tanh′(Wz+x))W
其中 tanh ′ \tanh' tanh′ 表示 tanh \tanh tanh 函数的导数,由下式给出
tanh ′ ( x ) = s e c h 2 ( x ) . \tanh'(x) = \mathrm{sech}^2(x). tanh′(x)=sech2(x).
让我们看看牛顿法在代码中的实现。该实现比简单的不动点迭代稍微复杂一些,因为需要计算牛顿步,但这只需要额外的几行代码。
python
class TanhNewtonLayer(nn.Module):
def __init__(self, out_features, tol = 1e-4, max_iter=50):
super().__init__()
self.linear = nn.Linear(out_features, out_features, bias=False)
self.tol = tol
self.max_iter = max_iter
def forward(self, x):
# 将输出 z 初始化为零
z = torch.tanh(x)
self.iterations = 0
# 迭代直到收敛
while self.iterations < self.max_iter:
z_linear = self.linear(z) + x
g = z - torch.tanh(z_linear)
self.err = torch.norm(g)
if self.err < self.tol:
break
# 牛顿步
J = torch.eye(z.shape[1])[None,:,:] - (1 / torch.cosh(z_linear)**2)[:,:,None]*self.linear.weight[None,:,:]
z = z - torch.solve(g[:,:,None], J)[0][:,:,0]
self.iterations += 1
g = z - torch.tanh(self.linear(z) + x)
z[torch.norm(g,dim=1) > self.tol,:] = 0
return z
python
layer = TanhNewtonLayer(50)
X = torch.randn(10,50)
Z = layer(X)
print(f"Terminated after {layer.iterations} iterations with error {layer.err}")
Terminated after 3 iterations with error 1.2266763178558904e-06
该方法能够比固定点迭代更快地收敛,但存在一个(主要的)注意事项:我们必须在每次迭代中求解一个线性系统。同样地,由于我们使用自动微分来实现整个过程,这意味着在反向传播过程中也需要通过求解器进行梯度回传。不过我们可以像之前一样将其直接嵌入相同的训练流程中。
python
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(),
nn.Linear(784, 100),
TanhNewtonLayer(100, max_iter=40),
nn.Linear(100, 10)
).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)
for i in range(8):
if i == 5:
opt.param_groups[0]["lr"] = 1e-2
train_err, train_loss, train_fpiter = epoch(train_loader, model, opt, lambda x : x[2].iterations)
test_err, test_loss, test_fpiter = epoch(test_loader, model, monitor = lambda x : x[2].iterations)
print(f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, Newton Iters: {train_fpiter:.2f} | " +
f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, Newton Iters: {test_fpiter:.2f}")
Train Error: 0.1156, Loss: 0.4131, FP Iters: 6.46 | Test Error: 0.0633, Loss: 0.2117, FP Iters: 7.03
Train Error: 0.0610, Loss: 0.2065, FP Iters: 6.98 | Test Error: 0.0490, Loss: 0.1681, FP Iters: 6.78
Train Error: 0.0463, Loss: 0.1555, FP Iters: 6.73 | Test Error: 0.0438, Loss: 0.1483, FP Iters: 6.04
Train Error: 0.0380, Loss: 0.1280, FP Iters: 7.01 | Test Error: 0.0373, Loss: 0.1276, FP Iters: 6.28
Train Error: 0.0328, Loss: 0.1087, FP Iters: 7.29 | Test Error: 0.0332, Loss: 0.1167, FP Iters: 7.06
Train Error: 0.0238, Loss: 0.0817, FP Iters: 7.41 | Test Error: 0.0323, Loss: 0.1134, FP Iters: 7.63
Train Error: 0.0217, Loss: 0.0764, FP Iters: 8.07 | Test Error: 0.0311, Loss: 0.1064, FP Iters: 7.74
Train Error: 0.0201, Loss: 0.0719, FP Iters: 8.69 | Test Error: 0.0298, Loss: 0.1057, FP Iters: 8.28
同样,该方法工作得相当好。然而,这个实现存在一些显著的问题。首先,如果你运行代码会立即注意到,该方法明显慢于 上面更简单的不动点迭代方法。尽管所需的迭代次数远少于不动点迭代,但每次单独迭代也慢得多,因为它涉及为小批量中的每个样本构建和求逆一个单独的(在这里是 100 × 100 100 \times 100 100×100 的)Jacobian 矩阵。对于更大的隐藏单元规模(特别是对于卷积网络),求逆甚至存储这些矩阵将很快变得不可行。实际上,精确的牛顿法很少使用,而是采用拟牛顿法来改善标准不动点迭代的收敛性,同时改善实际运行时间。
这种方法的第二个问题更微妙,但实际上是一个更大的问题。因为我们在自动微分工具包中直接实现了牛顿法,所以这种方法存在几个大的缺点。首先,与不动点迭代一样,自动微分工具需要保存隐藏单元的中间迭代值;但在这里,这意味着我们还 需要在内存中存储 Jacobian 项的中间迭代值,这极大地增加了内存消耗,即使在我们能够 存储和求逆完整 Jacobian 的情况下也是如此。此外,通过重复求逆进行反向传播可能是一个数值不稳定的过程:如果逆矩阵接近奇异,那么即使前向传播正确收敛,反向传播仍然可能在梯度中产生数值误差。事实上,你会注意到我们在 epoch() 中包含了一个"NaN 检查"。如果我们不 这样做,那么对于牛顿法,该方法会立即失败:如果你检查,你会发现大约 5% 的更新在梯度中有 NaN 值,这是由于 Jacobian 的条件数导致的,这也是导致该方法实际上比不动点迭代版本收敛更慢的原因。
这为求解隐式模型的"高效"方法描绘了一幅相当暗淡的图景。幸运的是,有一种更好的方式来实现这些层,这要归功于隐函数定理。
隐式层中的微分
到目前为止,我们以与实现任何其他层完全相同的方式实现了隐式层的求解器,并让自动微分库处理反向传播。然而,有一种更好的方式来计算关于隐藏层不动点的导数。为了理解如何做到这一点,让我们考虑隐式层的通用形式,即给定 x x x,找到某个 z z z 使得
g ( x , z ) = 0. g(x,z) = 0. g(x,z)=0.
记 z ⋆ ( x ) z^\star(x) z⋆(x) 为求解该不动点的值,这样写是为了强调隐式层的输出当然仍然是输入的(隐式)函数。
现在考虑如何计算这个输出关于输入的 Jacobian
∂ z ⋆ ( x ) ∂ x . \frac{\partial z^\star(x)}{\partial x}. ∂x∂z⋆(x).
与你熟悉的传统函数(我们有一个显式形式来计算从输入到输出的结果)不同,如何确定这样的 Jacobian 可能并不明显。但实际上使用隐式微分 来计算这个项是非常直接的,这是一种可追溯至几个世纪前的微积分技术。具体来说,为了推导这个 Jacobian 的表达式,我们从我们知道对 z ⋆ ( x ) z^\star(x) z⋆(x) 成立的不动点条件开始,并对两边关于 x x x 求导:
∂ g ( x , z ⋆ ( x ) ) ∂ x = 0. \frac{\partial g(x,z^\star(x))}{\partial x} = 0. ∂x∂g(x,z⋆(x))=0.
现在我们使用链式法则展开这个偏导数:由于 g g g 是两个变量的函数,将有一个涉及关于每个变量的导数的项:
∂ g ( x , z ⋆ ) ∂ x + ∂ g ( x , z ⋆ ) ∂ z ⋆ ∂ z ⋆ ( x ) ∂ x = 0 \frac{\partial g(x,z^\star)}{\partial x} + \frac{\partial g(x,z^\star)}{\partial z^\star}\frac{\partial z^\star(x)}{\partial x} = 0 ∂x∂g(x,z⋆)+∂z⋆∂g(x,z⋆)∂x∂z⋆(x)=0
其中符号 z ⋆ z^\star z⋆(未表示为 x x x 的函数)仅表示我们在这里将 z ⋆ z^\star z⋆ 视为一个固定 值(即 Jacobian ∂ g ( x , z ⋆ ) ∂ x \frac{\partial g(x,z^\star)}{\partial x} ∂x∂g(x,z⋆) 只是 g g g 对 x x x 的 Jacobian,在点 ( x , z ⋆ ) (x,z^\star) (x,z⋆) 处求值)。因此,这个项以及 ∂ g ( x , z ⋆ ) ∂ z ⋆ \frac{\partial g(x,z^\star)}{\partial z^\star} ∂z⋆∂g(x,z⋆) 项本身可以使用普通的自动微分库计算。最后,我们重写这个方程,用我们已知的表达式给出我们想要的表达式:
∂ z ⋆ ( x ) ∂ x = − ( ∂ g ( x , z ⋆ ) ∂ z ⋆ ) − 1 ∂ g ( x , z ⋆ ) ∂ x . \frac{\partial z^\star(x)}{\partial x} = - \left ( \frac{\partial g(x,z^\star)}{\partial z^\star} \right )^{-1} \frac{\partial g(x,z^\star)}{\partial x}. ∂x∂z⋆(x)=−(∂z⋆∂g(x,z⋆))−1∂x∂g(x,z⋆).
从技术上讲,为了确保我们能实际应用这个定理,需要满足某些条件,以便隐函数 z ⋆ ( x ) z^\star(x) z⋆(x) 保证存在:这些条件反映在所谓的隐函数定理中,这将在下一章讨论。此外,就像牛顿法与拟牛顿法的情况一样,在实践中通常无法直接计算这个逆矩阵,而是需要迭代过程。我们将在下一章中更多地讨论数学细节和形式化内容,但对于我们实际需要推导的大多数目的来说,这种"非正式"推导几乎就是你所需要的全部。最后,尽管我们上面写的是关于 x x x 的 Jacobian 公式,但当 g g g 也是某个参数 θ \theta θ(如权重和偏置)的函数时,完全相同的推导也适用于求关于这些参数的 Jacobian。
从对这个公式的详细推导中回过头来看,隐函数定理带来了一个非常实际的后果。即,该公式给出了必要 Jacobian 的形式,而无需反向传播通过用于获得不动点的方法 。换句话说,我们如何计算函数的零点(无论是通过不动点迭代、牛顿法还是拟牛顿法)完全不重要。重要的只是找到不动点(使用你想要的任何技术),然后我们可以直接使用这个解析形式计算必要的 Jacobian(或者更准确地说,计算反向传播,在实践中通常不需要显式计算 Jacobian)。不需要在内存中存储用于计算不动点的迭代方法的中间项(使方法更节省内存),也不需要在自动微分层中展开前向计算。
实现隐式微分
让我们看看在实践中如何实现隐式微分。首先,再次考虑我们的 tanh 加线性层,其中 g ( x , z ) g(x,z) g(x,z) 函数为
g ( x , z ) = z − tanh ( W z + x ) g(x,z) = z - \tanh(Wz + x) g(x,z)=z−tanh(Wz+x)
在这种情况下,隐式微分所需的 Jacobian ∂ g ∂ z ⋆ \frac{\partial g}{\partial z^\star} ∂z⋆∂g 为
∂ g ∂ z ⋆ = I − d i a g ( tanh ′ ( W z ⋆ + x ) ) W . \frac{\partial g}{\partial z^\star} = I - \mathrm{diag}(\tanh'(Wz^\star+x)) W. ∂z⋆∂g=I−diag(tanh′(Wz⋆+x))W.
你可能注意到这与我们使用牛顿法求解不动点时形成的 Jacobian 完全相同 。这并非偶然:事实上,在牛顿法中求根所需的 Jacobian 项与通过隐式微分计算反向传播所需的 Jacobian 项完全相同。这产生了一个非常好的性质:当我们通过牛顿法(或任何计算并求逆 Jacobian 的方法)找到方程的根时,通过牛顿法计算反向传播实际上是"免费"的(至少相对于求解不动点的复杂度而言):我们可以简单地重用在前向传播中计算的 Jacobian(及其逆矩阵)。当然,由于在实践中我们通常使用拟牛顿法或一阶方法来寻找隐式层的不动点,这并不像看起来那么大的优势。但是,无论如何,在我们确实在前向传播中计算了 Jacobian 的近似值的情况下,利用这个计算来进行反向传播也是有益的。
在我们继续实现之前,应该强调实际的隐式微分过程在反向传播(即反向模式自动微分)中是如何工作的。在反向传播中,我们实际上不需要计算网络中中间层的完整 Jacobian。相反,反向传播的目标是计算关于某个标量 损失函数的梯度。如果我们用上面的梯度将其写出来,它看起来像
∂ ℓ ∂ x = ∂ ℓ ∂ z ⋆ ∂ z ⋆ ∂ x = − ∂ ℓ ∂ z ⋆ ( ∂ g ∂ z ⋆ ) − 1 ∂ g ∂ x \frac{\partial \ell}{\partial x} = \frac{\partial \ell}{\partial z^\star} \frac{\partial z^\star}{\partial x} = - \frac{\partial \ell}{\partial z^\star} \left (\frac{\partial g}{\partial z^\star} \right )^{-1} \frac{\partial g}{\partial x} ∂x∂ℓ=∂z⋆∂ℓ∂x∂z⋆=−∂z⋆∂ℓ(∂z⋆∂g)−1∂x∂g
其中我们在最后一个等式中应用了上面的隐式微分公式。在反向传播中,这个项是从左到右计算的,这意味着我们实际上不需要计算完整的 Jacobian ∂ z ⋆ ∂ x \frac{\partial z^\star}{\partial x} ∂x∂z⋆,只需要计算上面所示的向量-Jacobian 积 。按照惯例,大多数自动微分框架用对梯度 的操作来表述(标量值函数的 Jacobian 的转置)
∇ z ⋆ ℓ = ( ∂ ℓ ∂ z ⋆ ) T \nabla_{z^\star} \ell = \left ( \frac{\partial \ell}{\partial z^\star} \right )^T ∇z⋆ℓ=(∂z⋆∂ℓ)T
所以我们需要乘以 Jacobian 的转置
∇ x ℓ = ( ∂ g ∂ x ) T ( ∂ g ∂ z ⋆ ) − T ∇ z ⋆ ℓ . \nabla_x \ell = \left (\frac{\partial g}{\partial x} \right )^T \left (\frac{\partial g}{\partial z^\star} \right )^{-T} \nabla_{z^\star} \ell. ∇xℓ=(∂x∂g)T(∂z⋆∂g)−T∇z⋆ℓ.
再次强调,我们实际上不需要存储和计算实际的逆矩阵 ( ∂ g ∂ z ⋆ ) − T \left (\frac{\partial g}{\partial z^\star} \right )^{-T} (∂z⋆∂g)−T,只需要能够求解该公式中出现的(线性)方程。
最后,让我们讨论如何在自动微分工具包中实现这样的公式。具体细节当然因框架而异,但由于我们最终是在实现一种新类型 的函数(即在前向传播中在任何自动微分之外计算不动点,然后计算"自定义"反向传播),你可能倾向于使用像 autograd.Function 接口这样的功能(如果你在 PyTorch 中实现的话),它允许你在库的正常自动微分过程之外指定前向和反向传播。但这在实践中会有点繁琐:毕竟,自动微分的好处之一是我们可能能够在同一个自动微分库中实现函数 g g g(无论它使用卷积、自注意力还是任何其他功能),并且我们希望自动包含所有这些梯度,而无需为我们想要实现的每个特定函数 g g g 编写新函数。幸运的是,有一种相当直接但微妙的方式来处理这个问题。我们将在后续章节中回到高效隐式微分的几个示例,它们各自有自己的实现技巧,但对于像这样的简单示例,一个常用的范式包括以下三个步骤:
- 在自动微分带之外 ,求解隐式层 g ( x , z ⋆ ) = 0 g(x,z^\star) = 0 g(x,z⋆)=0 的根。
- 通过在自动微分带内 运行以下赋值来"重新接合"自动微分:
z : = z ⋆ − g ( x , z ⋆ ) . z := z^\star - g(x,z^\star). z:=z⋆−g(x,z⋆).
这具有将偏导数 − ∂ g ∂ x -\frac{\partial g}{\partial x} −∂x∂g "重新插入"到 autograd tape 的效果(并且在 z z z 的值方面是空操作,因为 g ( x , z ⋆ ) = 0 g(x,z^\star) = 0 g(x,z⋆)=0)。 - 在反向传播中添加一个"后向钩子",用于乘以 ( ∂ g ∂ z ⋆ ) − T (\frac{\partial g}{\partial z^\star})^{-T} (∂z⋆∂g)−T。这将修正反向传播,使其根据隐函数定理正确实现梯度。
对于之前的 tanh 加线性层,这产生了如下实现。注意该层与我们之前实现的版本基本相同,只是牛顿法在 torch.no_grad(): 块内运行,并且我们通过 register_hook 函数添加了反向传播钩子。对于上面的第二步,鉴于前面强调的 g g g 函数,赋值就是
z : = z ⋆ − g ( x , z ⋆ ) = z ⋆ − z ⋆ + tanh ( W z + x ) = tanh ( W z ⋆ + x ) z := z^\star - g(x,z^\star) = z^\star - z^\star + \tanh(Wz + x) = \tanh(Wz^\star + x) z:=z⋆−g(x,z⋆)=z⋆−z⋆+tanh(Wz+x)=tanh(Wz⋆+x)
即,在找到不动点后,我们在自动微分带内运行一次不动点迭代。
python
class TanhNewtonImplicitLayer(nn.Module):
def __init__(self, out_features, tol = 1e-4, max_iter=50):
super().__init__()
self.linear = nn.Linear(out_features, out_features, bias=False)
self.tol = tol
self.max_iter = max_iter
def forward(self, x):
# 在 autograd 框架外运行牛顿法
with torch.no_grad():
z = torch.tanh(x)
self.iterations = 0
while self.iterations < self.max_iter:
z_linear = self.linear(z) + x
g = z - torch.tanh(z_linear)
self.err = torch.norm(g)
if self.err < self.tol:
break
# 牛顿步
J = torch.eye(z.shape[1])[None,:,:] - (1 / torch.cosh(z_linear)**2)[:,:,None]*self.linear.weight[None,:,:]
z = z - torch.solve(g[:,:,None], J)[0][:,:,0]
self.iterations += 1
# 重新接合 autograd 并添加梯度钩子
z = torch.tanh(self.linear(z) + x)
z.register_hook(lambda grad : torch.solve(grad[:,:,None], J.transpose(1,2))[0][:,:,0])
return z
注意这是一个相当非标准的实现:我们在正常的自动微分带之外 实现了前向传播的一个元素,然后添加一个后向钩子来"修正"梯度。我们可以使用内置的 gradcheck 命令验证这个层的正确性。注意这个实现不适用于双重反向传播(即 gradgradcheck 不会工作),但这可以通过稍微复杂的方法来解决,并且实践中通常不需要,所以我们暂时忽略它。
python
from torch.autograd import gradcheck
layer = TanhNewtonImplicitLayer(5, tol=1e-10).double()
gradcheck(layer, torch.randn(3, 5, requires_grad=True, dtype=torch.double), check_undefined_grad=False)
最后,再次为了演示,我们将使用这种新的隐式层变体训练我们的 MNIST 网络。正如所希望的,该方法确实比以前实现的牛顿法更快 且更稳定。虽然我们再次强调,在这种情况下使用精确牛顿法通常不是一个合理的方法,但类似的方法在后续章节讨论可微优化时实际上会非常有用。
python
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(),
nn.Linear(784, 100),
TanhNewtonImplicitLayer(100, max_iter=40),
nn.Linear(100, 10)
).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)
for i in range(10):
if i == 5:
opt.param_groups[0]["lr"] = 1e-2
train_err, train_loss, train_fpiter = epoch(train_loader, model, opt, lambda x : x[2].iterations)
test_err, test_loss, test_fpiter = epoch(test_loader, model, monitor = lambda x : x[2].iterations)
print(f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, Newton Iters: {train_fpiter:.2f} | " +
f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, Newton Iters: {test_fpiter:.2f}")
Train Error: 0.1130, Loss: 0.4061, Newton Iters: 6.48 | Test Error: 0.0605, Loss: 0.2040, Newton Iters: 6.92
Train Error: 0.0577, Loss: 0.1949, Newton Iters: 6.95 | Test Error: 0.0487, Loss: 0.1660, Newton Iters: 6.67
Train Error: 0.0449, Loss: 0.1500, Newton Iters: 6.76 | Test Error: 0.0412, Loss: 0.1398, Newton Iters: 5.79
Train Error: 0.0369, Loss: 0.1236, Newton Iters: 6.68 | Test Error: 0.0374, Loss: 0.1238, Newton Iters: 6.63
Train Error: 0.0316, Loss: 0.1056, Newton Iters: 7.25 | Test Error: 0.0346, Loss: 0.1145, Newton Iters: 7.12
Train Error: 0.0214, Loss: 0.0736, Newton Iters: 7.43 | Test Error: 0.0308, Loss: 0.1018, Newton Iters: 7.30
Train Error: 0.0191, Loss: 0.0682, Newton Iters: 7.84 | Test Error: 0.0285, Loss: 0.0989, Newton Iters: 7.79
Train Error: 0.0182, Loss: 0.0651, Newton Iters: 8.32 | Test Error: 0.0288, Loss: 0.0992, Newton Iters: 8.12
Train Error: 0.0179, Loss: 0.0626, Newton Iters: 9.07 | Test Error: 0.0288, Loss: 0.1001, Newton Iters: 8.64
Train Error: 0.0171, Loss: 0.0607, Newton Iters: 9.76 | Test Error: 0.0277, Loss: 0.0990, Newton Iters: 9.29
章节结束语
在深入研究更实际和多样化的实际隐式模型世界之前,我们想强调一下到目前为止我们所取得的成就。使用比"传统"深度模型非常少的额外代码(并且肯定不比传统循环模型多多少),我们能够编写一个层,它能够 1) 通过牛顿法求解非线性寻根问题,等价于找到无限深度网络的不动点,并且 2) 轻松集成到自动微分工具中。一旦你克服了隐式微分的一些数学符号,这些方法的相对简便性确实是在深度学习中整体使用隐式层更引人注目的因素之一。
在本教程的其余部分,我们将为你提供所需工具和背景知识,以将隐式层应用于各种问题和场景,并附有代码示例。我们希望这将使读者能够快速整合并在这一新方向上取得进展。