JVP (Jacobian-vector product) and VJP (vector-Jacobian product)

  • 1. The Jacobian Matrix
  • 2. Vector-Jacobian product (VJP)
  • 3. Jacobian-vector product (JVP)
  • 4. JVP (Jacobian-vector product) and VJP (vector-Jacobian product)
  • 5. A Concrete Example
    • 5.1. Vector Jacobian Product
    • 5.2. Jacobian Vector Product
  • References

Of VJPs and JVPs
https://maximerobeyns.com/of_vjps_and_jvps

1. The Jacobian Matrix

Suppose that we have a function $f: \mathbb{R}^{n} \to \mathbb{R}^{m}$, which maps an $n$-dimensional input $\mathbf{x} \in \mathbb{R}^{n}$ to an $m$-dimensional output $\mathbf{y} \in \mathbb{R}^{m}$.

One way to view this function is as a column vector of $m$ scalar-valued functions stacked on top of one another:

$$
f(\mathbf{x}) = \begin{bmatrix} y_{1}(x_1, \ldots, x_n) \\[1.2ex] y_{2}(x_1, \ldots, x_n) \\[1.2ex] \vdots \\[1.2ex] y_{m}(x_1, \ldots, x_n) \end{bmatrix}_{m \times 1}
$$

The Jacobian matrix $\mathbf{J}_{f}(\mathbf{x})$ of $f(\mathbf{x})$ is an $m \times n$ matrix, where each row contains the gradient of the $i$th "scalar function" with respect to the inputs $\mathbf{x}$, $\nabla_{\mathbf{x}} y_{i}(\mathbf{x})$:

$$
\begin{aligned}
\mathbf{J}_{f}(\mathbf{x})
&= \begin{bmatrix}
\nabla_{\mathbf{x}} y_{1}(\mathbf{x}) \\[1.2ex]
\nabla_{\mathbf{x}} y_{2}(\mathbf{x}) \\[1.2ex]
\vdots \\[1.2ex]
\nabla_{\mathbf{x}} y_{m}(\mathbf{x})
\end{bmatrix} \\
&= \begin{bmatrix}
\frac{\partial y_{1}}{\partial x_{1}} & \frac{\partial y_{1}}{\partial x_{2}} & \cdots & \frac{\partial y_{1}}{\partial x_{n}} \\[1.2ex]
\frac{\partial y_{2}}{\partial x_{1}} & \frac{\partial y_{2}}{\partial x_{2}} & \cdots & \frac{\partial y_{2}}{\partial x_{n}} \\[1.2ex]
\vdots & \vdots & \ddots & \vdots \\[1.2ex]
\frac{\partial y_{m}}{\partial x_{1}} & \frac{\partial y_{m}}{\partial x_{2}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
\end{bmatrix}_{m \times n}
\end{aligned}
$$

More generally, the $(i, j)$-th entry of $\mathbf{J}_{f}(\mathbf{x})$ contains the partial derivative of the $i$th output with respect to the $j$th input.
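
For instance, a minimal sketch (assuming JAX is available; the function `f` below is a made-up example with $n = 3$ inputs and $m = 2$ outputs) that materialises this Jacobian and confirms its $m \times n$ shape:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example: f maps R^3 -> R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
J = jax.jacobian(f)(x)   # shape (m, n) = (2, 3)
print(J.shape)           # row i is the gradient of y_i with respect to x
```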

At risk of stating the obvious, the vector-Jacobian product (VJP) is the left-multiplication of the Jacobian by some vector, while the Jacobian-vector product (JVP) is the right-multiplication of the Jacobian by an appropriately shaped vector.

2. Vector-Jacobian product (VJP)

Given a vector $\mathbf{v} \in \mathbb{R}^{m}$, the VJP is the following row vector:

$$
(\mathbf{v}^{\top})_{1 \times m} \ \text{@} \ \mathbf{J}_{m \times n} \in \mathbb{R}^{1 \times n}
$$

Being $n$-dimensional, we have one VJP element for each of the function inputs, $\mathbf{x}$ (it is an "input space" concept).

It tells us "in what direction should each of the inputs change, in order to get (as close as possible to) a change of $\mathbf{v}$ in the outputs?"

We might think of it as a "sensitivity map" over the inputs: if I want to increase the first output element $y_1$ by 0.5 (by setting $v_1 = 0.5$ and $v_{i \ne 1} = 0$), then the resulting $n$-dimensional VJP will tell me how I ought to perturb $\mathbf{x}$.

This is what we do during reverse-mode automatic differentiation.
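
In JAX, for example, reverse mode is exposed through `jax.vjp`. A minimal sketch (the function `f`, the point `x`, and the cotangent vector `v` below are hypothetical placeholders):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical f : R^3 -> R^2.
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0])
y, vjp_fn = jax.vjp(f, x)    # forward pass; vjp_fn closes over the linearisation at x
v = jnp.array([0.5, 0.0])    # output-space vector: "how much do I care about each output?"
(vjp,) = vjp_fn(v)           # v^T @ J, an n-dimensional sensitivity map over the inputs
print(vjp)                   # same shape as x
```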

3. Jacobian-vector product (JVP)

Given a vector $\mathbf{v} \in \mathbb{R}^{n}$, the JVP is the following column vector:

$$
\mathbf{J}_{m \times n} \ \text{@} \ \mathbf{v}_{n \times 1} \in \mathbb{R}^{m \times 1}
$$

Being $m$-dimensional, we have one JVP element for each of the function outputs, $\mathbf{y}$ (it is an "output space" concept).

It tells us "in what direction do the outputs of $f(\mathbf{x})$ change if I make a perturbation $\mathbf{v}$ to the inputs?"

We might think of it as a directional derivative of $f(\mathbf{x})$ in the direction of $\mathbf{v}$: if I perturb the first input element $x_{1}$ by 0.5 (by setting $v_{1} = 0.5$ and $v_{i \ne 1} = 0$), then the resulting $m$-dimensional JVP will tell me how much the output $\mathbf{y}$ will change.

This is what we do during forward-mode automatic differentiation.
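
The forward-mode counterpart in JAX is `jax.jvp`. A minimal sketch (again with a hypothetical `f`, point `x`, and tangent `v`):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical f : R^3 -> R^2.
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])      # input-space perturbation direction
y, jvp = jax.jvp(f, (x,), (v,))     # returns (f(x), J @ v) in a single forward pass
print(y, jvp)                       # jvp has the same shape as the output y
```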

Note that the JVP really just corresponds to a first-order Taylor approximation to the function $f(\mathbf{x})$. If $f(\mathbf{x})$ is differentiable at some point $\mathbf{x}_{0}$, then we can approximate it as:

$$
f(\mathbf{x}) = f(\mathbf{x}_{0}) + \mathbf{J}(\mathbf{x} - \mathbf{x}_{0}) + o(\Vert \mathbf{x} - \mathbf{x}_{0} \Vert),
$$

where we use the little-$o$ notation to denote that the remainder goes to $0$ faster than $\Vert \mathbf{x} - \mathbf{x}_{0} \Vert$ as $\mathbf{x} \to \mathbf{x}_{0}$. We often refer to this as the linearization of $f(\mathbf{x})$.
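
A quick numerical sketch of this view (assuming JAX; `f`, the point `x0`, the direction `v`, and the step size `eps` are all made up): stepping a small distance `eps` along `v`, the linearisation built from the JVP should match the true function value up to an $o(\varepsilon)$ remainder.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical f : R^3 -> R^2.
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x0 = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.3, -0.1, 0.2])
eps = 1e-3

_, jvp = jax.jvp(f, (x0,), (v,))          # J @ v at x0
approx = f(x0) + eps * jvp                # first-order Taylor / linearisation of f at x0
exact = f(x0 + eps * v)
print(jnp.max(jnp.abs(exact - approx)))   # remainder is o(eps), so this should be tiny
```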

4. JVP (Jacobian-vector product) and VJP (vector-Jacobian product)


One simple way to remember which is which is that for a simple affine transformation $f: \mathbb{R}^{n} \to \mathbb{R}^{m}$ defined as

$$
f(\mathbf{x}) = \mathbf{W}_{m \times n} \ \text{@} \ \mathbf{x},
$$

the Jacobian is simply $\mathbf{J} = \mathbf{W} \in \mathbb{R}^{m \times n}$. Hence,

  • VJP: left-multiplying $\mathbf{J}_{m \times n}$ or $\mathbf{W}_{m \times n}$ by a vector must correspond to an output-space perturbation, and hence returns the "sensitivity map" over the inputs.
  • JVP: right-multiplying $\mathbf{J}_{m \times n}$ or $\mathbf{W}_{m \times n}$ by a vector must correspond to an input-space perturbation, and gives us the directional derivative of $f(\mathbf{x})$ (see the sketch after this list).
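
A small sketch of this (assuming JAX; the matrix `W` and the vectors below are made-up numbers), verifying that for the affine map the VJP is $\mathbf{v}^{\top} \ \text{@} \ \mathbf{W}$ and the JVP is $\mathbf{W} \ \text{@} \ \mathbf{v}$:

```python
import jax
import jax.numpy as jnp

W = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])               # m x n = 2 x 3
f = lambda x: W @ x

x = jnp.array([1.0, -1.0, 2.0])
v_out = jnp.array([1.0, 0.0])                  # output-space vector (VJP)
v_in = jnp.array([0.0, 1.0, 0.0])              # input-space vector (JVP)

_, vjp_fn = jax.vjp(f, x)
print(vjp_fn(v_out)[0], v_out @ W)             # both are the VJP: v^T @ W
print(jax.jvp(f, (x,), (v_in,))[1], W @ v_in)  # both are the JVP: W @ v
```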

Here is a rote mnemonic you can apply in times of confusion: just look at whether the V in the acronym comes before (input) or after (output) the Jacobian's J.

**V**JP: This is the $n$-dimensional, input-space object: the sensitivity map telling us how the input affects the output.

J**V**P: This is the $m$-dimensional, output-space object: the directional derivative telling us how the function output changes when the inputs are perturbed.


5. A Concrete Example

Let's consider a very simple function with $n = 3$ inputs and $m = 2$ outputs to illustrate what's going on:

$$
f(\mathbf{x}) = \begin{bmatrix} y_{1}(\mathbf{x}) \\ y_{2}(\mathbf{x}) \end{bmatrix}
= \begin{bmatrix} x_{1}^{2} + x_{2} + x_{3} \\ x_{1} - x_{2} + x_{3} \end{bmatrix}_{2 \times 1}
$$

The Jacobian $\nabla_{\mathbf{x}} f(\mathbf{x})$, evaluated at the example input point $\mathbf{x} = [x_{1}=2, \ x_{2}=3, \ x_{3}=1]^{\top}$, is the following $2 \times 3$ matrix:

$$
\begin{aligned}
\mathbf{J}_{f}(\mathbf{x})
&= \begin{bmatrix}
\frac{\partial y_{1}}{\partial x_{1}} & \frac{\partial y_{1}}{\partial x_{2}} & \frac{\partial y_{1}}{\partial x_{3}} \\[1.2ex]
\frac{\partial y_{2}}{\partial x_{1}} & \frac{\partial y_{2}}{\partial x_{2}} & \frac{\partial y_{2}}{\partial x_{3}}
\end{bmatrix}_{2 \times 3} \\[1.2ex]
&= \begin{bmatrix}
2x_{1} & 1 & 1 \\
1 & -1 & 1
\end{bmatrix}_{x_{1}=2, \ x_{2}=3, \ x_{3}=1} \\[1.2ex]
&= \begin{bmatrix}
4 & 1 & 1 \\
1 & -1 & 1
\end{bmatrix}
\end{aligned}
$$
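
As a sanity check, here is a short sketch (assuming JAX is available) that evaluates the same Jacobian automatically:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2 + x[1] + x[2],
                      x[0] - x[1] + x[2]])

x = jnp.array([2.0, 3.0, 1.0])
print(jax.jacobian(f)(x))
# Expected:
# [[ 4.  1.  1.]
#  [ 1. -1.  1.]]
```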

5.1. Vector Jacobian Product

For a VJP $(\mathbf{v}^{\top})_{1 \times m} \ \text{@} \ \mathbf{J}_{m \times n}$, we will require an $m$-dimensional vector $\mathbf{v}$. Consider, to start, what happens if $\mathbf{v}$ is a one-hot vector; we get

$$
\mathbf{v}^{\top} \ \text{@} \ \mathbf{J} = \begin{bmatrix} 1 & 0 \end{bmatrix} \ \text{@} \ \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix} 4 & 1 & 1 \end{bmatrix}
$$

In other words, with some element $i \in \{1, 2, \ldots, m\}$ of $\mathbf{v}$ set to 1, and all others set to 0, we have selected the $i$th row of the Jacobian, which is exactly the gradient of the "scalar function" $y_{i}(\mathbf{x})$. Recall that this is an $n$-dimensional vector, which tells us how the $i$th output component $y_{i}(\mathbf{x})$ depends on each of the $n$ input components: it describes the sensitivity of this $i$th output component to infinitesimal changes in each input component.


Now, if each of the $m$ elements of $\mathbf{v}$ were set to $1$, we would pay equal attention to the sensitivity of all the output components to changes in the input, doing an element-wise sum over the rows of the Jacobian:

$$
\mathbf{v}^{\top} \ \text{@} \ \mathbf{J} = \begin{bmatrix} 1 & 1 \end{bmatrix} \ \text{@} \ \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix} 5 & 0 & 2 \end{bmatrix}
$$

Depending on the importance of each output element, we could also re-weight each row (i.e. gradient vector) differently, giving us a weighted "sensitivity map" over the inputs:

$$
\mathbf{v}^{\top} \ \text{@} \ \mathbf{J} = \begin{bmatrix} 0.5 & 1 \end{bmatrix} \ \text{@} \ \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix} 3 & -0.5 & 1.5 \end{bmatrix}
$$

If the elements of $\mathbf{v}$ don't sum to 1, then this simply scales the VJP magnitude up or down, and it is trivial to normalise this again.
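
The three VJPs above can be reproduced with `jax.vjp`, reusing the same `f` and evaluation point (a sketch, assuming JAX is available):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2 + x[1] + x[2],
                      x[0] - x[1] + x[2]])

x = jnp.array([2.0, 3.0, 1.0])
_, vjp_fn = jax.vjp(f, x)

for v in (jnp.array([1.0, 0.0]),    # one-hot: selects the gradient of y_1 -> [4, 1, 1]
          jnp.array([1.0, 1.0]),    # all ones: sums the rows of J        -> [5, 0, 2]
          jnp.array([0.5, 1.0])):   # weighted rows                       -> [3, -0.5, 1.5]
    print(vjp_fn(v)[0])
```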


5.2. Jacobian Vector Product

For the JVP $\mathbf{J}_{m \times n} \ \text{@} \ \mathbf{v}_{n \times 1}$, we need an $n$-dimensional perturbation vector which gives the direction along which we wish to calculate the derivative of $f(\mathbf{x})$.


First, let's once again consider a one-hot vector. Setting the first dimension to 1, we get

$$
\mathbf{J}_{m \times n} \ \text{@} \ \mathbf{v}_{n \times 1} = \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix} = \begin{bmatrix} 4 \\ 1 \end{bmatrix}.
$$

Generalising: by setting element $j \in \{1, 2, \ldots, n\}$ of $\mathbf{v}$ to 1, we select the $j$th column of the Jacobian. This $m$-dimensional vector tells us the derivative of $f(\mathbf{x})$ along the direction of the basis vector $\mathbf{e}_{j}$. Put otherwise, if we had perturbed the $j$th input value $x_{j}$ by 1, the JVP tells us how the function output would change in response.
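
The same one-hot JVP, computed with `jax.jvp` (a sketch assuming JAX, with the same `f` and evaluation point as above):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2 + x[1] + x[2],
                      x[0] - x[1] + x[2]])

x = jnp.array([2.0, 3.0, 1.0])
v = jnp.array([1.0, 0.0, 0.0])     # perturb only x_1
_, jvp = jax.jvp(f, (x,), (v,))
print(jvp)                         # first column of J: [4, 1]
```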

References

[1] Yongqiang Cheng (程永强)
[2] Of VJPs and JVPs, https://maximerobeyns.com/of_vjps_and_jvps
