昇思MindSpore进阶教程--黑塞矩阵

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。

技术上主攻前端开发、鸿蒙开发和AI算法研究。

努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

黑塞矩阵

在介绍MindSpore提供的计算黑塞矩阵的方法之前,首先对黑塞矩阵进行介绍。

黑塞矩阵可以由梯度操作 ∇ \nabla ∇和广度梯度操作 ∂ \partial ∂的复合得到,即
∇ ∘ ∂ : F n 1 ⟶ F n n ⟶ F n × n n \nabla \circ \partial: F_{n}^{1} \longrightarrow F_{n}^{n} \longrightarrow F_{n \times n}^{n} ∇∘∂:Fn1⟶Fnn⟶Fn×nn

将该复合操作用于f,得到,
f ⟼ ∇ f ⟼ J ∇ f f \longmapsto \nabla f \longmapsto J_{\nabla f} f⟼∇f⟼J∇f

可以得到黑塞矩阵,
H f = [ ∂ ( ∇ 1 f ) ∂ x 1 ∂ ( ∇ 1 f ) ∂ x 2 ... ∂ ( ∇ 1 f ) ∂ x n ∂ ( ∇ 2 f ) ∂ x 1 ∂ ( ∇ 2 f ) ∂ x 2 ... ∂ ( ∇ 2 f ) ∂ x n ⋮ ⋮ ⋱ ⋮ ∂ ( ∇ n f ) ∂ x 1 ∂ ( ∇ n f ) ∂ x 2 ... ∂ ( ∇ n f ) ∂ x n ] = [ ∂ 2 f ∂ x 1 2 ∂ 2 f ∂ x 2 ∂ x 1 ... ∂ 2 f ∂ x n ∂ x 1 ∂ 2 f ∂ x 1 ∂ x 2 ∂ 2 f ∂ x 2 2 ... ∂ 2 f ∂ x n ∂ x 2 ⋮ ⋮ ⋱ ⋮ ∂ 2 f ∂ x 1 ∂ x n ∂ 2 f ∂ x 2 ∂ x n ... ∂ 2 f ∂ x n 2 ] \begin{split}H_{f} = \begin{bmatrix} \frac{\partial (\nabla {1}f)}{\partial x{1}} &\frac{\partial (\nabla {1}f)}{\partial x{2}} &\dots &\frac{\partial (\nabla {1}f)}{\partial x{n}} \\ \frac{\partial (\nabla {2}f)}{\partial x{1}} &\frac{\partial (\nabla {2}f)}{\partial x{2}} &\dots &\frac{\partial (\nabla {2}f)}{\partial x{n}} \\ \vdots &\vdots &\ddots &\vdots \\ \frac{\partial (\nabla {n}f)}{\partial x{1}} &\frac{\partial (\nabla {n}f)}{\partial x{2}} &\dots &\frac{\partial (\nabla {n}f)}{\partial x{n}} \end{bmatrix} = \begin{bmatrix} \frac{\partial ^2 f}{\partial x_{1}^{2}} &\frac{\partial ^2 f}{\partial x_{2} \partial x_{1}} &\dots &\frac{\partial ^2 f}{\partial x_{n} \partial x_{1}} \\ \frac{\partial ^2 f}{\partial x_{1} \partial x_{2}} &\frac{\partial ^2 f}{\partial x_{2}^{2}} &\dots &\frac{\partial ^2 f}{\partial x_{n} \partial x_{2}} \\ \vdots &\vdots &\ddots &\vdots \\ \frac{\partial ^2 f}{\partial x_{1} \partial x_{n}} &\frac{\partial ^2 f}{\partial x_{2} \partial x_{n}} &\dots &\frac{\partial ^2 f}{\partial x_{n}^{2}} \end{bmatrix}\end{split} Hf= ∂x1∂(∇1f)∂x1∂(∇2f)⋮∂x1∂(∇nf)∂x2∂(∇1f)∂x2∂(∇2f)⋮∂x2∂(∇nf)......⋱...∂xn∂(∇1f)∂xn∂(∇2f)⋮∂xn∂(∇nf) = ∂x12∂2f∂x1∂x2∂2f⋮∂x1∂xn∂2f∂x2∂x1∂2f∂x22∂2f⋮∂x2∂xn∂2f......⋱...∂xn∂x1∂2f∂xn∂x2∂2f⋮∂xn2∂2f

易见,黑塞矩阵是一个实对称矩阵。

黑塞矩阵的应用:利用黑塞矩阵,我们可以探索神经网络在某点处的曲率,为训练是否收敛提供数值依据。

计算黑塞矩阵

在MindSpore中,我们可以通过jacfwd和jacrev的任意组合来计算黑塞矩阵。

python 复制代码
Din = 32
Dout = 16
weight = ops.randn(Dout, Din)
bias = ops.randn(Dout)
x = ops.randn(Din)

hess1 = jacfwd(jacfwd(forecast, grad_position=2), grad_position=2)(weight, bias, x)
hess2 = jacfwd(jacrev(forecast, grad_position=2), grad_position=2)(weight, bias, x)
hess3 = jacrev(jacfwd(forecast, grad_position=2), grad_position=2)(weight, bias, x)
hess4 = jacrev(jacrev(forecast, grad_position=2), grad_position=2)(weight, bias, x)

np.allclose(hess1.asnumpy(), hess2.asnumpy())
np.allclose(hess2.asnumpy(), hess3.asnumpy())
np.allclose(hess3.asnumpy(), hess4.asnumpy())

计算黑塞-向量积

计算黑塞-向量积(Hessian-vector product, hvp)的最直接的方法计算一个完整的黑塞矩阵,并将其与向量进行点积运算。但MindSpore提供了更好的方法,使得不需要计算一个完整的黑塞矩阵,便可以计算黑塞-向量积。下面我们介绍计算黑塞-向量积的两种方法。

  • 将反向模式自动微分与反向模式自动微分组合。

  • 将反向模式自动微分与前向模式自动微分组合。

下面先介绍,在MindSpore中,如何使用反向模式自动微分与前向模式自动微分组合的方式计算黑塞-向量积,

python 复制代码
def hvp_revfwd(f, inputs, vector):
    return jvp(grad(f), inputs, vector)[1]

def f(x):
    return x.sin().sum()

inputs = ops.randn(128)
vector = ops.randn(128)

result_hvp_revfwd = hvp_revfwd(f, inputs, vector)
print(result_hvp_revfwd.shape)

如果前向自动微分不能满足要求,我们可以使用反向模式自动微分与反向模式自动微分组合的方式来计算黑塞-向量积,

python 复制代码
def hvp_revrev(f, inputs, vector):
    _, vjp_fn = vjp(grad(f), *inputs)
    return vjp_fn(*vector)

result_hvp_revrev = hvp_revrev(f, (inputs,), (vector,))
print(result_hvp_revrev[0].shape)
相关推荐
云樱梦海6 分钟前
OpenAI 推出 Canvas 工具,助力用户与 ChatGPT 协作写作和编程
人工智能·chatgpt·canvas
小白熊_XBX19 分钟前
机器学习可视化教程——混淆矩阵与回归图
人工智能·python·机器学习·矩阵·回归·sklearn
_.Switch38 分钟前
自动机器学习(AutoML):实战项目中的应用与实现
人工智能·python·机器学习·自然语言处理·架构·scikit-learn
肖遥Janic1 小时前
Stable Diffusion绘画 | 插件-Deforum:动态视频生成(终篇)
人工智能·ai·ai作画·stable diffusion
念啊啊啊啊丶1 小时前
【AIGC】2021-arXiv-LoRA:大型语言模型的低秩自适应
人工智能·深度学习·神经网络·机器学习·自然语言处理
Mr_Happy_Li1 小时前
利用GPU进行训练
python·深度学习·神经网络·机器学习·计算机视觉
柠檬少少开发1 小时前
基于matlab的语音信号处理
人工智能·语音识别
要养家的程序猿2 小时前
上海AI Lab视频生成大模型书生.筑梦环境搭建&推理测试
人工智能·音视频
凭栏落花侧2 小时前
数据挖掘中的常见误区与注意事项
人工智能·数据挖掘
ζั͡ޓއއއ坏尐絯2 小时前
深度学习(7):RNN实战之人名的国籍预测
人工智能·rnn·深度学习