JVP (Jacobian-vector product) and VJP (vector-Jacobian product)
- [1. The Jacobian Matrix](#1. The Jacobian Matrix)
- [2. Vector-Jacobian product (VJP)](#2. Vector-Jacobian product (VJP))
- [3. Jacobian-vector product (JVP)](#3. Jacobian-vector product (JVP))
- [4. JVP (Jacobian-vector product) and VJP (vector-Jacobian product)](#4. JVP (Jacobian-vector product) and VJP (vector-Jacobian product))
- [5. A Concrete Example](#5. A Concrete Example)
    - [5.1. Vector Jacobian Product](#5.1. Vector Jacobian Product)
    - [5.2. Jacobian Vector Product](#5.2. Jacobian Vector Product)
- References
Of VJPs and JVPs
https://maximerobeyns.com/of_vjps_and_jvps
1. The Jacobian Matrix
Suppose that we have a function $f: \mathbb{R}^{n} \to \mathbb{R}^{m}$, which maps an $n$-dimensional input $\mathbf{x} \in \mathbb{R}^{n}$ to an $m$-dimensional output $\mathbf{y} \in \mathbb{R}^{m}$.
One way to view this function is as a column vector of $m$ scalar-valued functions stacked one on top of each other:
$$
f(\mathbf{x}) = \begin{bmatrix} y_{1}(x_1,\ldots,x_n) \\[1.2ex] y_{2}(x_1,\ldots,x_n) \\[1.2ex] \vdots \\[1.2ex] y_{m}(x_1,\ldots,x_n) \end{bmatrix}_{m \times 1}
$$
The Jacobian matrix $\mathbf{J}_{f}(\mathbf{x})$ of $f(\mathbf{x})$ is an $m \times n$ matrix whose $i$th row contains the gradient of the $i$th "scalar function" with respect to the inputs $\mathbf{x}$, namely $\nabla_{\mathbf{x}} y_{i}(\mathbf{x})$:
$$
\begin{aligned}
\mathbf{J}_{f}(\mathbf{x}) &= \begin{bmatrix} \nabla_{\mathbf{x}} y_{1}(\mathbf{x}) \\[1.2ex] \nabla_{\mathbf{x}} y_{2}(\mathbf{x}) \\[1.2ex] \vdots \\[1.2ex] \nabla_{\mathbf{x}} y_{m}(\mathbf{x}) \end{bmatrix} \\
&= \begin{bmatrix}
\frac{\partial y_{1}}{\partial x_{1}} & \frac{\partial y_{1}}{\partial x_{2}} & \cdots & \frac{\partial y_{1}}{\partial x_{n}} \\[1.2ex]
\frac{\partial y_{2}}{\partial x_{1}} & \frac{\partial y_{2}}{\partial x_{2}} & \cdots & \frac{\partial y_{2}}{\partial x_{n}} \\[1.2ex]
\vdots & \vdots & \ddots & \vdots \\[1.2ex]
\frac{\partial y_{m}}{\partial x_{1}} & \frac{\partial y_{m}}{\partial x_{2}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
\end{bmatrix}_{m \times n}
\end{aligned}
$$
Equivalently, the $(i, j)$th entry of $\mathbf{J}_{f}(\mathbf{x})$ is $\partial y_{i} / \partial x_{j}$, the partial derivative of the $i$th output with respect to the $j$th input.
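As a numerical sanity check on this index convention (row $i$ for output $y_i$, column $j$ for input $x_j$), here is a minimal pure-Python sketch that approximates $\mathbf{J}_f(\mathbf{x})$ with central finite differences. It is not from the original article: the helper name `jacobian_fd`, the step size `h`, and the test function `g` are illustrative assumptions, and autodiff libraries compute these derivatives exactly rather than by differencing.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def jacobian_fd(f: Callable[[Vector], Vector], x: Vector, h: float = 1e-6) -> Matrix:
    """Approximate the m x n Jacobian of f at x using central differences.

    Entry [i][j] approximates dy_i/dx_j, so row i is the gradient of the
    i-th scalar output y_i and column j is the effect of nudging x_j.
    """
    n = len(x)
    m = len(f(x))
    J = [[0.0] * n for _ in range(m)]
    for j in range(n):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += h
        x_minus[j] -= h
        f_plus, f_minus = f(x_plus), f(x_minus)
        for i in range(m):
            J[i][j] = (f_plus[i] - f_minus[i]) / (2 * h)
    return J


def g(x: Vector) -> Vector:
    # Hypothetical test function g: R^2 -> R^3.
    return [x[0] * x[1], x[0] + 3 * x[1], x[1] ** 2]


J = jacobian_fd(g, [2.0, 5.0])
print(len(J), "x", len(J[0]))  # 3 x 2: one row per output, one column per input
```

By hand, the Jacobian of `g` at $(2, 5)$ is $\begin{bmatrix} 5 & 2 \\ 1 & 3 \\ 0 & 10 \end{bmatrix}$, which the differences reproduce up to rounding.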
At the risk of stating the obvious, the vector-Jacobian product (VJP) is the left-multiplication of the Jacobian by some (transposed) vector, while the Jacobian-vector product (JVP) is the right-multiplication of the Jacobian by an appropriately shaped vector.
obvious [ˈɒbviəs]
adj. evident; apparent; easy to understand; generally recognized
2. Vector-Jacobian product (VJP)
Given a vector $\mathbf{v} \in \mathbb{R}^{m}$, the VJP is the following row vector:
$$
(\mathbf{v}^{\top})_{1 \times m} \,\text{@}\, \mathbf{J}_{m \times n} \in \mathbb{R}^{1 \times n}
$$
Being $n$-dimensional, the VJP has one element for each of the function inputs $\mathbf{x}$ (it is an "input-space" concept).
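It tells us, informally, "in what direction should each of the inputs change in order to push the outputs along $\mathbf{v}$?" More precisely, $\mathbf{v}^{\top}\mathbf{J}$ is the (transposed) gradient of the scalar $\mathbf{v}^{\top} f(\mathbf{x})$ with respect to $\mathbf{x}$: the direction of steepest ascent for that $\mathbf{v}$-weighted combination of the outputs.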
We might think of it as a "sensitivity map" over the inputs: if I want to increase the first output element $y_1$ by 0.5 (by setting $v_1 = 0.5$ and $v_{i \ne 1} = 0$), then the resulting $n$-dimensional VJP, $0.5\,\nabla_{\mathbf{x}} y_{1}(\mathbf{x})$, tells me in which direction I ought to perturb $\mathbf{x}$.
This is what we do during reverse-mode automatic differentiation.
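To see that reverse mode produces $\mathbf{v}^{\top}\mathbf{J}$ without ever building $\mathbf{J}$, here is a toy tape-based reverse-mode sketch in plain Python. It is an illustrative assumption, not how any particular library is implemented: the `Var` class and `vjp` function are hypothetical names, only `+`, `-` and `*` are supported, and the test function is the one from the concrete example in Section 5. The output adjoints are seeded with $\mathbf{v}$ and pushed backwards through the recorded operations with the chain rule.

```python
from typing import Callable, List, Sequence, Tuple


class Var:
    """A scalar node that remembers its parents and the local partial derivatives."""

    def __init__(self, value: float, parents: Tuple[Tuple["Var", float], ...] = ()):
        self.value = value
        self.parents = parents  # pairs of (parent node, d(self)/d(parent))
        self.grad = 0.0         # adjoint, accumulated during the backward pass

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __sub__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value - other.value, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

    __rmul__ = __mul__


def vjp(f: Callable[[List[Var]], List[Var]], x: Sequence[float], v: Sequence[float]) -> List[float]:
    """Return v^T @ J_f(x) using one forward pass and one backward pass."""
    inputs = [Var(xi) for xi in x]
    outputs = f(inputs)

    # Topologically order the graph so every node comes after its parents.
    order, seen = [], set()

    def visit(node: Var) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    for y in outputs:
        visit(y)

    # Seed the output adjoints with v, then apply the chain rule in reverse order.
    for y, vi in zip(outputs, v):
        y.grad += vi
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += node.grad * local
    return [xi.grad for xi in inputs]


def f(x):
    # Example function from Section 5: R^3 -> R^2.
    return [x[0] * x[0] + x[1] + x[2], x[0] - x[1] + x[2]]


print(vjp(f, [2.0, 3.0, 1.0], [1.0, 1.0]))  # [5.0, 0.0, 2.0]
```

Real libraries expose the same idea; in JAX, for instance, `jax.vjp(f, x)` returns the output together with a function that maps a cotangent $\mathbf{v}$ to $\mathbf{v}^{\top}\mathbf{J}$ (wrapped in a tuple, one entry per primal input).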
3. Jacobian-vector product (JVP)
Given a vector $\mathbf{v} \in \mathbb{R}^{n}$, the JVP is the following column vector:
$$
\mathbf{J}_{m \times n} \,\text{@}\, \mathbf{v}_{n \times 1} \in \mathbb{R}^{m \times 1}
$$
Being $m$-dimensional, the JVP has one element for each of the function outputs $\mathbf{y}$ (it is an "output-space" concept).
It tells us "in what direction do the outputs of $f(\mathbf{x})$ change if I make a perturbation $\mathbf{v}$ to the inputs?"
We might think of it as a directional derivative of $f(\mathbf{x})$ in the direction of $\mathbf{v}$: if I perturb the first input element $x_{1}$ by 0.5 (by setting $v_{1} = 0.5$ and $v_{i \ne 1} = 0$), then the resulting $m$-dimensional JVP will tell me (to first order) how much the output $\mathbf{y}$ will change.
This is what we do during forward-mode automatic differentiation.
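Forward mode can be sketched just as compactly with dual numbers: each input carries a tangent (its entry of $\mathbf{v}$), and the product rule propagates tangents alongside values, so a single forward pass yields $\mathbf{J}\mathbf{v}$. This is an illustrative assumption in plain Python, not a library API; `Dual` and `jvp` are hypothetical names, only `+`, `-` and `*` are implemented, and the test function is the Section 5 example.

```python
from typing import Callable, List, Sequence, Tuple


class Dual:
    """A value a + b*eps with eps**2 == 0; `tangent` holds the directional derivative."""

    def __init__(self, primal: float, tangent: float = 0.0):
        self.primal = primal
        self.tangent = tangent

    @staticmethod
    def lift(x) -> "Dual":
        # Constants have zero tangent.
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.primal - other.primal, self.tangent - other.tangent)

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (a + b eps)(c + d eps) = ac + (ad + bc) eps
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def jvp(f: Callable[[List[Dual]], List[Dual]], x: Sequence[float],
        v: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Return (f(x), J_f(x) @ v) from one forward pass, without forming J."""
    outputs = f([Dual(xi, vi) for xi, vi in zip(x, v)])
    return [y.primal for y in outputs], [y.tangent for y in outputs]


def f(x):
    # Example function from Section 5: R^3 -> R^2.
    return [x[0] * x[0] + x[1] + x[2], x[0] - x[1] + x[2]]


y, jv = jvp(f, [2.0, 3.0, 1.0], [1.0, 0.0, 0.0])
print(y, jv)  # [8.0, 0.0] [4.0, 1.0]
```

JAX offers the library version of this: `jax.jvp(f, (x,), (v,))` returns the pair of the primal output and $\mathbf{J}\mathbf{v}$.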
Note that the JVP is exactly the linear term of the first-order Taylor expansion of $f(\mathbf{x})$. If $f(\mathbf{x})$ is differentiable at some point $\mathbf{x}_{0}$, then we can write it as:
$$
f(\mathbf{x}) = f(\mathbf{x}_{0}) + \mathbf{J}_{f}(\mathbf{x}_{0})(\mathbf{x} - \mathbf{x}_{0}) + o(\Vert \mathbf{x} - \mathbf{x}_{0}\Vert),
$$
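where we use little-$o$ notation to denote that the remainder goes to $0$ faster than $\Vert \mathbf{x} - \mathbf{x}_{0}\Vert$ as $\mathbf{x} \to \mathbf{x}_{0}$. We often refer to this as the linearization of $f(\mathbf{x})$ around $\mathbf{x}_{0}$. Setting $\mathbf{x} = \mathbf{x}_{0} + t\mathbf{v}$ for a small scalar $t$ makes the connection explicit: the output changes by approximately $t\,\mathbf{J}\mathbf{v}$, i.e. the JVP scaled by $t$.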
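The little-$o$ claim is easy to check numerically. In this plain-Python sketch, the function and its hand-derived Jacobian are the Section 5 example and `linearization_error` is a hypothetical helper: we step a distance $t$ along $\mathbf{v}$ and measure how far $f(\mathbf{x}_0 + t\mathbf{v})$ lies from its linearization $f(\mathbf{x}_0) + t\,\mathbf{J}\mathbf{v}$. The ratio error$/t$ should shrink towards zero as $t$ does.

```python
import math
from typing import List

Vector = List[float]


def f(x: Vector) -> Vector:
    # Example function from Section 5: R^3 -> R^2.
    return [x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]]


def jacobian(x: Vector) -> List[Vector]:
    # Hand-derived Jacobian of f (see Section 5).
    return [[2 * x[0], 1.0, 1.0], [1.0, -1.0, 1.0]]


def linearization_error(x0: Vector, v: Vector, t: float) -> float:
    """Euclidean norm of f(x0 + t v) - [f(x0) + J(x0) (t v)]."""
    x = [a + t * b for a, b in zip(x0, v)]
    jv = [sum(Jij * vj for Jij, vj in zip(row, v)) for row in jacobian(x0)]  # the JVP
    linear = [fi + t * jvi for fi, jvi in zip(f(x0), jv)]
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(f(x), linear)))


x0, v = [2.0, 3.0, 1.0], [1.0, 1.0, 1.0]
for t in (1e-1, 1e-2, 1e-3):
    err = linearization_error(x0, v, t)
    # err / t shrinks with t: the remainder is o(||x - x0||).
    print(f"t={t:g}  error={err:.3e}  error/t={err / t:.3e}")
```

For this particular $f$ the only nonlinear term is $x_1^2$, so the remainder is exactly $t^2 v_1^2$ in the first output and zero in the second.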
4. JVP (Jacobian-vector product) and VJP (vector-Jacobian product)
mnemonic [nɪˈmɒnɪk]
adj. relating to memory; aiding memory
n. a phrase, rhyme, etc. that helps you remember something; a memory aid
One simple way to remember which is which is that for a simple linear map $f: \mathbb{R}^{n} \to \mathbb{R}^{m}$ defined as
$$
f(\mathbf{x}) = \mathbf{W}_{m \times n} \,\text{@}\, \mathbf{x},
$$
the Jacobian is simply $\mathbf{J} = \mathbf{W} \in \mathbb{R}^{m \times n}$. Hence,
- VJP: left-multiplying $\mathbf{J}_{m \times n}$ (or $\mathbf{W}_{m \times n}$) by a vector requires an $m$-dimensional, output-space perturbation, and hence returns the "sensitivity map" over the inputs.
- JVP: right-multiplying $\mathbf{J}_{m \times n}$ (or $\mathbf{W}_{m \times n}$) by a vector requires an $n$-dimensional, input-space perturbation, and gives us the directional derivative of $f(\mathbf{x})$.
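For the linear map, both products are ordinary matrix-vector products with $\mathbf{W}$, and the shapes alone tell you which space each vector lives in. A minimal plain-Python sketch; the helper names `jvp_linear` and `vjp_linear`, the $2 \times 3$ matrix `W`, and the vectors are made-up illustrative values.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def jvp_linear(W: Matrix, dx: Vector) -> Vector:
    """JVP of f(x) = W @ x: right-multiply W (m x n) by an input-space dx (length n)."""
    if len(dx) != len(W[0]):
        raise ValueError("dx must have length n")
    return [sum(w * d for w, d in zip(row, dx)) for row in W]


def vjp_linear(u: Vector, W: Matrix) -> Vector:
    """VJP of f(x) = W @ x: left-multiply W (m x n) by an output-space u (length m)."""
    if len(u) != len(W):
        raise ValueError("u must have length m")
    return [sum(u[i] * W[i][j] for i in range(len(W))) for j in range(len(W[0]))]


W = [[1.0, 2.0, 0.0],
     [0.0, -1.0, 3.0]]  # hypothetical weights, so f: R^3 -> R^2

# Nudging x_3 alone changes the outputs by W @ e_3 (length m = 2).
print(jvp_linear(W, [0.0, 0.0, 1.0]))  # [0.0, 3.0]
# Weighting only y_1 returns its sensitivity to each input (length n = 3).
print(vjp_linear([1.0, 0.0], W))       # [1.0, 2.0, 0.0]
```

Passing a vector of the wrong length raises an error, which is exactly the shape bookkeeping the mnemonic is meant to help with.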
Here is a rote mnemonic you can apply in times of confusion: just look at whether the V in the acronym comes before (input) or after (output) the Jacobian's J.
- VJP: This is the $n$-dimensional, input-space object: the sensitivity map telling us how each input affects the output.
- JVP: This is the $m$-dimensional, output-space object: the directional derivative telling us how the function output changes when the inputs are perturbed.
rote [rəʊt]
n. memorization by repetition
5. A Concrete Example
Let's consider a very simple function with $n = 3$ inputs and $m = 2$ outputs to illustrate what's going on:
$$
f(\mathbf{x}) = \begin{bmatrix} y_{1}(\mathbf{x}) \\ y_{2}(\mathbf{x}) \end{bmatrix} = \begin{bmatrix} x_{1}^{2} + x_{2} + x_{3} \\ x_{1} - x_{2} + x_{3} \end{bmatrix}_{2 \times 1}
$$
The Jacobian $\mathbf{J}_{f}(\mathbf{x}) = \nabla_{\mathbf{x}} f(\mathbf{x})$, evaluated at the example input point $\mathbf{x} = [x_{1}=2,\ x_{2}=3,\ x_{3}=1]^{\top}$, is the following $2 \times 3$ matrix:
$$
\begin{aligned}
\mathbf{J}_{f}(\mathbf{x}) &= \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \frac{\partial y_1}{\partial x_2} & \frac{\partial y_1}{\partial x_3} \\[1.2ex] \frac{\partial y_2}{\partial x_1} & \frac{\partial y_2}{\partial x_2} & \frac{\partial y_2}{\partial x_3} \end{bmatrix}_{2 \times 3} \\[1.2ex]
&= \begin{bmatrix} 2x_{1} & 1 & 1 \\[1.2ex] 1 & -1 & 1 \end{bmatrix}_{x_{1}=2,\ x_{2}=3,\ x_{3}=1} \\[1.2ex]
&= \begin{bmatrix} 4 & 1 & 1 \\[1.2ex] 1 & -1 & 1 \end{bmatrix}_{x_{1}=2,\ x_{2}=3,\ x_{3}=1}
\end{aligned}
$$
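The hand-derived entries can be cross-checked numerically. This plain-Python sketch evaluates the symbolic Jacobian at $\mathbf{x} = [2, 3, 1]^{\top}$ and compares it with central differences; `jacobian_analytic` and `jacobian_central_diff` are hypothetical helper names and the step size `h` is an arbitrary choice.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def f(x: Vector) -> Vector:
    return [x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]]


def jacobian_analytic(x: Vector) -> Matrix:
    # Row i holds the partial derivatives of y_i with respect to x_1, x_2, x_3.
    return [[2 * x[0], 1.0, 1.0],
            [1.0, -1.0, 1.0]]


def jacobian_central_diff(x: Vector, h: float = 1e-5) -> Matrix:
    cols = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        cols.append([(a - b) / (2 * h) for a, b in zip(f(xp), f(xm))])
    # cols[j] is column j; transpose to get the m x n layout.
    return [list(row) for row in zip(*cols)]


x = [2.0, 3.0, 1.0]
print(jacobian_analytic(x))  # [[4.0, 1.0, 1.0], [1.0, -1.0, 1.0]]
print(jacobian_central_diff(x))
```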
5.1. Vector Jacobian Product
For a VJP $(\mathbf{v}^{\top})_{1 \times m} \,\text{@}\, \mathbf{J}_{m \times n}$, we require an $m$-dimensional vector $\mathbf{v}$. Consider, to start, what happens if $\mathbf{v}$ is a one-hot vector; with $\mathbf{v} = [1, 0]^{\top}$ we get
$$
\mathbf{v}^{\top} \,\text{@}\, \mathbf{J} = \begin{bmatrix} 1 & 0 \end{bmatrix} \text{@} \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix} 4 & 1 & 1 \end{bmatrix}
$$
In other words, with some element $i \in \{1, 2, \ldots, m\}$ of $\mathbf{v}$ set to 1 and all others set to 0, we select the $i$th row of the Jacobian, which is exactly the gradient of the "scalar function" $y_{i}(\mathbf{x})$. Recall that this is an $n$-dimensional vector which tells us how the $i$th output component $y_{i}(\mathbf{x})$ depends on each of the $n$ input components: it describes the sensitivity of this $i$th output component to infinitesimal changes in each input component.
infinitesimal [ˌɪnfɪnɪ'tesɪml]
adj. extremely small
adv. infinitesimally
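A quick way to convince yourself of the row-selection property: loop over the one-hot vectors $\mathbf{e}_i \in \mathbb{R}^m$ and compare $\mathbf{e}_i^{\top}\mathbf{J}$ with row $i$. In this plain-Python sketch the helper `vjp` is a hypothetical name and `J` is the example Jacobian from above, hard-coded.

```python
from typing import List

J = [[4.0, 1.0, 1.0],
     [1.0, -1.0, 1.0]]  # Jacobian of the example f at x = (2, 3, 1)


def vjp(v: List[float], J: List[List[float]]) -> List[float]:
    """v^T @ J for an m x n matrix J and a length-m vector v."""
    return [sum(v[i] * J[i][j] for i in range(len(J))) for j in range(len(J[0]))]


m = len(J)
for i in range(m):
    e_i = [1.0 if k == i else 0.0 for k in range(m)]  # one-hot output-space vector
    row = vjp(e_i, J)
    # e_i^T @ J is row i of J: the gradient of the scalar output y_i.
    assert row == J[i]
    print(f"e_{i + 1}^T @ J = {row}")
```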
Now, if each of the $m$ elements of $\mathbf{v}$ were set to $1$, we would pay equal attention to the sensitivity of all the output components to changes in the input, doing an element-wise sum over the rows of the Jacobian:
$$
\mathbf{v}^{\top} \,\text{@}\, \mathbf{J} = \begin{bmatrix} 1 & 1 \end{bmatrix} \text{@} \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix} 5 & 0 & 2 \end{bmatrix}
$$
Depending on the importance of each output element, we could also re-weight each row (i.e. gradient vector) differently, giving us a weighted "sensitivity map" over the inputs.
$$
\mathbf{v}^{\top} \,\text{@}\, \mathbf{J} = \begin{bmatrix} 0.5 & 1 \end{bmatrix} \text{@} \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} = \begin{bmatrix} 3 & -0.5 & 1.5 \end{bmatrix}
$$
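Because the VJP is linear in $\mathbf{v}$, scaling $\mathbf{v}$ by a constant scales the VJP by the same constant without changing its direction. So if the elements of $\mathbf{v}$ don't sum to 1, this simply scales the VJP magnitude up or down, and it is trivial to normalise this again (e.g. by dividing $\mathbf{v}$ by the sum of its elements).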
trivial ['trɪviəl]
adj. unimportant; petty; insignificant
normalise ['nɔ:məlaɪz]
v. to normalize; to make standard
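One way to read the weighted VJPs above: $\mathbf{v}^{\top}\mathbf{J}$ is the gradient of the single scalar $\sum_i v_i\, y_i(\mathbf{x})$. This plain-Python sketch reproduces both weighted VJPs and checks them against a central-difference gradient of that weighted sum; `vjp` and `grad_weighted_sum` are hypothetical helper names and the step size `h` is arbitrary.

```python
from typing import List

Vector = List[float]

J = [[4.0, 1.0, 1.0],
     [1.0, -1.0, 1.0]]  # Jacobian of the example f at x0 = (2, 3, 1)
x0 = [2.0, 3.0, 1.0]


def f(x: Vector) -> Vector:
    # Example function from this section: R^3 -> R^2.
    return [x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]]


def vjp(v: Vector, J: List[Vector]) -> Vector:
    """v^T @ J: a v-weighted sum of the Jacobian's rows."""
    return [sum(v[i] * J[i][j] for i in range(len(J))) for j in range(len(J[0]))]


def grad_weighted_sum(v: Vector, x: Vector, h: float = 1e-5) -> Vector:
    """Central-difference gradient of the scalar L(x) = sum_i v_i * y_i(x)."""
    def L(z: Vector) -> float:
        return sum(vi * yi for vi, yi in zip(v, f(z)))

    grad = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        grad.append((L(xp) - L(xm)) / (2 * h))
    return grad


for v in ([1.0, 1.0], [0.5, 1.0]):
    print(v, vjp(v, J), grad_weighted_sum(v, x0))
```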
5.2. Jacobian Vector Product
For the JVP $\mathbf{J}_{m \times n} \,\text{@}\, \mathbf{v}_{n \times 1}$, we need an $n$-dimensional perturbation vector $\mathbf{v}$ which gives the direction along which we wish to calculate the derivative of $f(\mathbf{x})$.
perturbation [ˌpɜ:tə'beɪʃn]
n. (astronomical) perturbation; disturbance; small deviation; agitation or worry
First, let's once again consider a one-hot vector. Setting the first dimension to 1, we get
$$
\mathbf{J}_{m \times n} \,\text{@}\, \mathbf{v}_{n \times 1} = \begin{bmatrix} 4 & 1 & 1 \\ 1 & -1 & 1 \end{bmatrix} \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix} = \begin{bmatrix} 4 \\ 1 \end{bmatrix}.
$$
Generalising: by setting element $j \in \{1, 2, \ldots, n\}$ of $\mathbf{v}$ to 1 (and all others to 0), we select the $j$th column of the Jacobian. For $j = 1$, as above, this $m$-dimensional vector is the derivative of $f(\mathbf{x})$ along the first basis vector $\mathbf{e}_{1}$. Put otherwise, if we perturbed the first input value $x_{1}$ by a small amount $\epsilon$, the JVP tells us that, to first order, the function output would change by $\epsilon\,[4, 1]^{\top}$.
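As a check on the column-selection property, this plain-Python sketch computes $\mathbf{J}\mathbf{e}_j$ for each one-hot direction and compares it with a central-difference estimate of the directional derivative $\frac{d}{dt} f(\mathbf{x}_0 + t\mathbf{e}_j)\big|_{t=0}$. The helpers `jvp` and `directional_derivative` and the step size `h` are hypothetical, illustrative choices.

```python
from typing import List

Vector = List[float]

J = [[4.0, 1.0, 1.0],
     [1.0, -1.0, 1.0]]  # Jacobian of the example f at x0 = (2, 3, 1)
x0 = [2.0, 3.0, 1.0]


def f(x: Vector) -> Vector:
    return [x[0] ** 2 + x[1] + x[2], x[0] - x[1] + x[2]]


def jvp(J: List[Vector], v: Vector) -> Vector:
    """J @ v for an m x n matrix J and a length-n direction v."""
    return [sum(Jij * vj for Jij, vj in zip(row, v)) for row in J]


def directional_derivative(x: Vector, v: Vector, h: float = 1e-5) -> Vector:
    """Central-difference estimate of d/dt f(x + t v) at t = 0."""
    xp = [a + h * b for a, b in zip(x, v)]
    xm = [a - h * b for a, b in zip(x, v)]
    return [(p - q) / (2 * h) for p, q in zip(f(xp), f(xm))]


n = len(x0)
for j in range(n):
    e_j = [1.0 if k == j else 0.0 for k in range(n)]  # one-hot input-space direction
    # J @ e_j is column j of J: how each output responds to nudging x_j alone.
    print(f"J @ e_{j + 1} = {jvp(J, e_j)}", directional_derivative(x0, e_j))
```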
References
[1] Yongqiang Cheng (程永强),