Gradients of Matrix-Matrix Multiplication in Deep Learning


  • 1. Matrix multiplication
  • 2. Derivation of the gradients
    • 2.1. Dimensions of the gradients
    • 2.2. The chain rule
    • 2.3. Derivation of the gradient $\frac{\partial L}{\partial \boldsymbol{A}}$
    • 2.4. Derivation of the gradient $\frac{\partial L}{\partial \boldsymbol{B}}$
  • 3. Custom implementations and validation
  • 4. Summary
  • References

Understanding Artificial Neural Networks with Hands-on Experience - Part 1. Matrix Multiplication, Its Gradients and Custom Implementations
https://coolgpu.github.io/coolgpu_blog/github/pages/2020/09/22/matrixmultiplication.html

coolgpu_blog
https://github.com/coolgpu/coolgpu_blog/tree/master/_posts

We discuss matrix multiplication because Convolution and ConvTranspose operations can be expressed as matrix multiplications, which makes them simpler to handle. In this post, we focus on deriving the gradients of matrix multiplication. Although the derivation may look complex, the final results take a simple form and are easy to remember.

1. Matrix multiplication

The definition of matrix multiplication can be found in any linear algebra book. Let's use the definition from Wikipedia. Given an $m \times k$ matrix $\boldsymbol{A}$ and a $k \times n$ matrix $\boldsymbol{B}$

$$
\boldsymbol{A}=\begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1k} \\ a_{21} & a_{22} & \cdots & a_{2k} \\ \vdots & \vdots & \ddots & \vdots \\ a_{m1} & a_{m2} & \cdots & a_{mk} \end{bmatrix} \tag{1}
$$

and

$$
\boldsymbol{B}=\begin{bmatrix} b_{11} & b_{12} & \cdots & b_{1n} \\ b_{21} & b_{22} & \cdots & b_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ b_{k1} & b_{k2} & \cdots & b_{kn} \end{bmatrix} \tag{1}
$$

their matrix product $\boldsymbol{C} = \boldsymbol{A}\boldsymbol{B}$ is defined as

$$
\boldsymbol{C}=\begin{bmatrix} c_{11} & c_{12} & \cdots & c_{1n} \\ c_{21} & c_{22} & \cdots & c_{2n} \\ \vdots & \vdots & c_{ij} & \vdots \\ c_{m1} & c_{m2} & \cdots & c_{mn} \end{bmatrix} \tag{2}
$$

where its element $c_{ij}$ is given by

$$
c_{ij} = \sum_{t=1}^{k} a_{it} b_{tj} \tag{3}
$$

for $i = 1, \ldots, m$ and $j = 1, \ldots, n$. In other words, $c_{ij}$ is the dot product of the $i$th row of $\boldsymbol{A}$ and the $j$th column of $\boldsymbol{B}$.
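Equation (3) translates directly into a triple loop. The sketch below is a minimal, unoptimized illustration in plain Python (the helper name `matmul` is ours, not from the original post); it uses 0-based indices, so `C[i][j]` corresponds to $c_{(i+1)(j+1)}$. In practice you would call `torch.mm` or the `@` operator instead.

```python
def matmul(A, B):
    """Naive matrix product following Equation (3): c_ij = sum_t a_it * b_tj."""
    m, k = len(A), len(A[0])
    k2, n = len(B), len(B[0])
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    C = [[0] * n for _ in range(m)]
    for i in range(m):          # row index of A (and of C)
        for j in range(n):      # column index of B (and of C)
            # dot product of the i-th row of A and the j-th column of B
            C[i][j] = sum(A[i][t] * B[t][j] for t in range(k))
    return C


A = [[1, 2, 3],
     [4, 5, 6]]          # 2 x 3
B = [[7, 8],
     [9, 10],
     [11, 12]]           # 3 x 2
print(matmul(A, B))      # [[58, 64], [139, 154]]
```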

2. Derivation of the gradients

2.1. Dimensions of the gradients

If we consider an isolated matrix multiplication, the partial derivative of matrix $\boldsymbol{C}$ with respect to either matrix $\boldsymbol{A}$ or matrix $\boldsymbol{B}$ is a 4-D array, known as the Jacobian matrix. For example, $\frac{\partial \boldsymbol{C}}{\partial \boldsymbol{A}}$ holds one entry $\frac{\partial c_{ij}}{\partial a_{pq}}$ for every combination of $i, j, p, q$, so it has $m \times n \times m \times k$ entries. Many of these entries are zero because, as shown in Equation (3), $c_{ij}$ is a function of only the $i$th row of $\boldsymbol{A}$ and the $j$th column of $\boldsymbol{B}$, and is independent of the other rows of $\boldsymbol{A}$ and the other columns of $\boldsymbol{B}$.

Jacobian matrix and determinant
https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant
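To make the 4-D Jacobian concrete, the sketch below builds $\frac{\partial \boldsymbol{C}}{\partial \boldsymbol{A}}$ explicitly from Equation (3) for a small example, assuming $m = 2$, $k = 4$, $n = 3$ and an arbitrary $\boldsymbol{B}$ with no zero entries. The helper `jacobian_C_wrt_A` is hypothetical; entry `J[i][j][p][q]` holds $\frac{\partial c_{ij}}{\partial a_{pq}}$, which equals $b_{qj}$ when $p = i$ and 0 otherwise. Note that it does not depend on $\boldsymbol{A}$ at all, because $\boldsymbol{C}$ is linear in $\boldsymbol{A}$.

```python
def jacobian_C_wrt_A(B, m):
    """4-D Jacobian J[i][j][p][q] = d c_ij / d a_pq for C = A B, with A of shape m x k.

    From Equation (3), c_ij = sum_t a_it * b_tj, so the derivative is b_qj
    when p == i and 0 otherwise: c_ij ignores every row of A except row i.
    """
    k, n = len(B), len(B[0])
    return [[[[B[q][j] if p == i else 0 for q in range(k)]
              for p in range(m)]
             for j in range(n)]
            for i in range(m)]


m, k, n = 2, 4, 3
B = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9],
     [10, 11, 12]]           # k x n, no zero entries
J = jacobian_C_wrt_A(B, m)   # shape m x n x m x k
entries = [J[i][j][p][q] for i in range(m) for j in range(n)
           for p in range(m) for q in range(k)]
print(len(entries), sum(e == 0 for e in entries))   # 48 24
```

Here half of the 48 entries are zero; in general a fraction $(m-1)/m$ of them is, which is why frameworks never materialize this Jacobian.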


What we consider here is not an isolated matrix multiplication, but a matrix multiplication inside a neural network that ends with a scalar loss function. For example, consider a simple case where the loss $L$ is the mean of matrix $\boldsymbol{C}$:

$$
L = \frac{1}{m \times n} \sum_{i=1}^{m} \sum_{j=1}^{n} c_{ij} \tag{4}
$$

Our focus is the partial derivatives of the scalar $L$ w.r.t. the input matrices $\boldsymbol{A}$ and $\boldsymbol{B}$, namely $\frac{\partial L}{\partial \boldsymbol{A}}$ and $\frac{\partial L}{\partial \boldsymbol{B}}$. Because $L$ is a scalar, there is exactly one partial derivative per input element. Therefore, $\frac{\partial L}{\partial \boldsymbol{A}}$ has the same dimensions as $\boldsymbol{A}$, i.e., it is another $m \times k$ matrix, and $\frac{\partial L}{\partial \boldsymbol{B}}$ has the same dimensions as $\boldsymbol{B}$, i.e., it is another $k \times n$ matrix.
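A quick numerical check of this dimension argument, as a sketch with made-up integer matrices: estimating every $\frac{\partial L}{\partial a_{pq}}$ for the mean loss of Equation (4) by perturbing one element of $\boldsymbol{A}$ at a time yields exactly one number per element, i.e., an $m \times k$ result. Because $L$ is linear in each $a_{pq}$, a unit step gives the exact derivative, and `fractions.Fraction` keeps the arithmetic exact.

```python
from fractions import Fraction


def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def mean_loss(A, B):
    """Equation (4): L is the mean of all elements of C = A B."""
    C = matmul(A, B)
    m, n = len(C), len(C[0])
    return Fraction(sum(map(sum, C)), m * n)


def numerical_grad_A(A, B, h=1):
    """One partial derivative per element of A, via forward differences.

    L is linear in each a_pq, so a step of h = 1 gives the exact derivative.
    """
    base = mean_loss(A, B)
    grad = []
    for p in range(len(A)):
        row = []
        for q in range(len(A[0])):
            A_step = [r[:] for r in A]   # copy A, then perturb one element
            A_step[p][q] += h
            row.append((mean_loss(A_step, B) - base) / h)
        grad.append(row)
    return grad


A = [[1, 2, 3, 4], [5, 6, 7, 8]]                    # m x k = 2 x 4
B = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [2, 2, 2]]    # k x n = 4 x 3
grad_A = numerical_grad_A(A, B)
print(len(grad_A), len(grad_A[0]))   # 2 4, the same shape as A
```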

2.2. The chain rule

We will use the chain rule to backpropagate gradients. Since it is such an important tool in neural networks, it doesn't hurt to briefly summarize it once more, as in the previous post. Given a function $L(x_1, x_2, \ldots, x_N)$ of the form

$$
L(x_1, \ldots, x_N) = L\big(f_1(x_1, \ldots, x_N), f_2(x_1, \ldots, x_N), \ldots, f_M(x_1, \ldots, x_N)\big) \tag{5}
$$

Then the gradient of $L$ w.r.t. $x_i$ can be computed as

$$
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial f_1}\frac{\partial f_1}{\partial x_i} + \frac{\partial L}{\partial f_2}\frac{\partial f_2}{\partial x_i} + \cdots + \frac{\partial L}{\partial f_M}\frac{\partial f_M}{\partial x_i} = \sum_{m=1}^{M} \frac{\partial L}{\partial f_m}\frac{\partial f_m}{\partial x_i} \tag{6}
$$

Equation (6) can be understood from two perspectives:

  • Summation means that all possible paths through which $x_i$ contributes to $L$ must be included.
  • Product means that, along each path $m$, the gradient equals the upstream gradient passed in, $\frac{\partial L}{\partial f_m}$, times the local gradient, $\frac{\partial f_m}{\partial x_i}$.
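As a small worked example of Equation (6), consider the made-up functions $f_1 = x_1 x_2$, $f_2 = x_1 + x_2$ and $L = f_1 f_2$, so $x_1$ reaches $L$ through $M = 2$ paths. The chain rule gives $\frac{\partial L}{\partial x_1} = f_2 \cdot x_2 + f_1 \cdot 1$, which is $5 \cdot 3 + 6 = 21$ at $(x_1, x_2) = (2, 3)$; the sketch below compares this with a central finite difference.

```python
def f1(x1, x2):
    return x1 * x2


def f2(x1, x2):
    return x1 + x2


def loss(x1, x2):
    # L depends on x1 through two paths: via f1 and via f2 (M = 2 in Equation (6))
    return f1(x1, x2) * f2(x1, x2)


def dL_dx1_chain_rule(x1, x2):
    a, b = f1(x1, x2), f2(x1, x2)
    dL_df1, dL_df2 = b, a          # upstream gradients, since L = f1 * f2
    df1_dx1, df2_dx1 = x2, 1.0     # local gradients
    # sum over paths of (upstream gradient x local gradient)
    return dL_df1 * df1_dx1 + dL_df2 * df2_dx1


x1, x2, h = 2.0, 3.0, 1e-5
numeric = (loss(x1 + h, x2) - loss(x1 - h, x2)) / (2 * h)
print(dL_dx1_chain_rule(x1, x2), round(numeric, 6))   # 21.0 21.0
```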

2.3. Derivation of the gradient $\frac{\partial L}{\partial \boldsymbol{A}}$

In this section, we use a $2 \times 4$ matrix $\boldsymbol{A}$ and a $4 \times 3$ matrix $\boldsymbol{B}$ as an example to derive $\frac{\partial L}{\partial \boldsymbol{A}}$ step by step. Note that the same derivation can be performed for a general $m \times k$ matrix $\boldsymbol{A}$ and $k \times n$ matrix $\boldsymbol{B}$; a specific example is used here only to make the derivation easier to follow.

Let's start by writing the matrices $\boldsymbol{A}$, $\boldsymbol{B}$ and their product $\boldsymbol{C} = \boldsymbol{A}\boldsymbol{B}$ in expanded form.


$$
\boldsymbol{A} = \begin{bmatrix} a_{11} & a_{12} & a_{13} & a_{14} \\ a_{21} & a_{22} & {\color{red} a_{23}} & a_{24} \end{bmatrix} \tag{7}
$$

and

$$
\boldsymbol{B} = \begin{bmatrix} b_{11} & b_{12} & b_{13} \\ b_{21} & b_{22} & b_{23} \\ b_{31} & b_{32} & b_{33} \\ b_{41} & b_{42} & b_{43} \end{bmatrix} \tag{7}
$$

$$
\begin{aligned}
\boldsymbol{C} &= \begin{bmatrix} c_{11} & c_{12} & c_{13} \\ c_{21} & c_{22} & c_{23} \end{bmatrix}
= \begin{bmatrix} a_{11} & a_{12} & a_{13} & a_{14} \\ a_{21} & a_{22} & {\color{red} a_{23}} & a_{24} \end{bmatrix}
\begin{bmatrix} b_{11} & b_{12} & b_{13} \\ b_{21} & b_{22} & b_{23} \\ b_{31} & b_{32} & b_{33} \\ b_{41} & b_{42} & b_{43} \end{bmatrix} \\
&= \begin{bmatrix}
a_{11}b_{11} + a_{12}b_{21} + a_{13}b_{31} + a_{14}b_{41} & a_{11}b_{12} + a_{12}b_{22} + a_{13}b_{32} + a_{14}b_{42} & a_{11}b_{13} + a_{12}b_{23} + a_{13}b_{33} + a_{14}b_{43} \\
a_{21}b_{11} + a_{22}b_{21} + {\color{red} a_{23}}b_{31} + a_{24}b_{41} & a_{21}b_{12} + a_{22}b_{22} + {\color{red} a_{23}}b_{32} + a_{24}b_{42} & a_{21}b_{13} + a_{22}b_{23} + {\color{red} a_{23}}b_{33} + a_{24}b_{43}
\end{bmatrix}
\end{aligned} \tag{8}
$$

Schematically, the chain rule reads

$$
\frac{\partial L}{\partial \boldsymbol{A}} = \frac{\partial L}{\partial \boldsymbol{C}}\frac{\partial \boldsymbol{C}}{\partial \boldsymbol{A}}
$$

where the "product" means summing over all elements of $\boldsymbol{C}$ as in Equation (6), not an ordinary matrix product. To carry it out element by element, consider an arbitrary element of $\boldsymbol{A}$, for example ${\color{red} a_{23}}$. Based on Equation (8), the local partial derivatives of the elements of $\boldsymbol{C}$ w.r.t. ${\color{red} a_{23}}$ are

$$
\begin{aligned}
\frac{\partial c_{11}}{\partial {\color{red} a_{23}}} &= 0 \\
\frac{\partial c_{12}}{\partial {\color{red} a_{23}}} &= 0 \\
\frac{\partial c_{13}}{\partial {\color{red} a_{23}}} &= 0 \\
\frac{\partial c_{21}}{\partial {\color{red} a_{23}}} &= \frac{\partial}{\partial {\color{red} a_{23}}}\left(a_{21}b_{11} + a_{22}b_{21} + {\color{red} a_{23}}b_{31} + a_{24}b_{41}\right) = 0 + 0 + \frac{\partial}{\partial {\color{red} a_{23}}}\left({\color{red} a_{23}}b_{31}\right) + 0 = b_{31} \\
\frac{\partial c_{22}}{\partial {\color{red} a_{23}}} &= \frac{\partial}{\partial {\color{red} a_{23}}}\left(a_{21}b_{12} + a_{22}b_{22} + {\color{red} a_{23}}b_{32} + a_{24}b_{42}\right) = 0 + 0 + \frac{\partial}{\partial {\color{red} a_{23}}}\left({\color{red} a_{23}}b_{32}\right) + 0 = b_{32} \\
\frac{\partial c_{23}}{\partial {\color{red} a_{23}}} &= \frac{\partial}{\partial {\color{red} a_{23}}}\left(a_{21}b_{13} + a_{22}b_{23} + {\color{red} a_{23}}b_{33} + a_{24}b_{43}\right) = 0 + 0 + \frac{\partial}{\partial {\color{red} a_{23}}}\left({\color{red} a_{23}}b_{33}\right) + 0 = b_{33}
\end{aligned} \tag{9}
$$

Using the chain rule, we have the partial derivative of the loss $L$ w.r.t. ${\color{red} a_{23}}$:

$$
\begin{aligned}
\frac{\partial L}{\partial {\color{red} a_{23}}} &= \frac{\partial L}{\partial c_{11}}\frac{\partial c_{11}}{\partial {\color{red} a_{23}}} + \frac{\partial L}{\partial c_{12}}\frac{\partial c_{12}}{\partial {\color{red} a_{23}}} + \frac{\partial L}{\partial c_{13}}\frac{\partial c_{13}}{\partial {\color{red} a_{23}}} + \frac{\partial L}{\partial c_{21}}\frac{\partial c_{21}}{\partial {\color{red} a_{23}}} + \frac{\partial L}{\partial c_{22}}\frac{\partial c_{22}}{\partial {\color{red} a_{23}}} + \frac{\partial L}{\partial c_{23}}\frac{\partial c_{23}}{\partial {\color{red} a_{23}}} \\
&= 0 + 0 + 0 + \frac{\partial L}{\partial c_{21}} b_{31} + \frac{\partial L}{\partial c_{22}} b_{32} + \frac{\partial L}{\partial c_{23}} b_{33} \\
&= \frac{\partial L}{\partial c_{21}} b_{31} + \frac{\partial L}{\partial c_{22}} b_{32} + \frac{\partial L}{\partial c_{23}} b_{33}
\end{aligned} \tag{10}
$$

The second line of Equation (10) uses the results from Equation (9).
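Equation (10) can be spot-checked numerically. The sketch below assumes a hypothetical linear loss $L = \sum_{i,j} g_{ij} c_{ij}$ with arbitrary integer weights $g_{ij}$, chosen so that the upstream gradient $\frac{\partial L}{\partial c_{ij}} = g_{ij}$ is known exactly and, unlike for the mean loss, not all equal. Indices are 0-based, so ${\color{red} a_{23}}$ is `A[1][2]`.

```python
def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def loss(A, B, G):
    """Hypothetical linear loss L = sum_ij g_ij * c_ij, so dL/dc_ij = g_ij."""
    C = matmul(A, B)
    return sum(G[i][j] * C[i][j] for i in range(len(C)) for j in range(len(C[0])))


A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]                     # 2 x 4
B = [[1, 0, 2],
     [0, 1, 1],
     [3, 1, 0],
     [2, 2, 2]]                        # 4 x 3
G = [[1, -1, 2],
     [3, 0, -2]]                       # upstream gradient dL/dC, 2 x 3

# Equation (10) with 0-based indices: a_23 -> A[1][2], b_3j -> B[2][j], dL/dc_2j -> G[1][j]
eq10 = G[1][0] * B[2][0] + G[1][1] * B[2][1] + G[1][2] * B[2][2]

# L is linear in a_23, so a unit step gives the exact derivative
A_step = [row[:] for row in A]
A_step[1][2] += 1
numeric = loss(A_step, B, G) - loss(A, B, G)
print(eq10, numeric)   # 9 9
```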

In the same manner, we can derive the other elements of $\frac{\partial L}{\partial \boldsymbol{A}}$:

$$
\begin{aligned}
\frac{\partial L}{\partial \boldsymbol{A}} &= \begin{bmatrix}
\frac{\partial L}{\partial a_{11}} & \frac{\partial L}{\partial a_{12}} & \frac{\partial L}{\partial a_{13}} & \frac{\partial L}{\partial a_{14}} \\
\frac{\partial L}{\partial a_{21}} & \frac{\partial L}{\partial a_{22}} & \frac{\partial L}{\partial a_{23}} & \frac{\partial L}{\partial a_{24}}
\end{bmatrix} \\
&= \begin{bmatrix}
\frac{\partial L}{\partial c_{11}}b_{11} + \frac{\partial L}{\partial c_{12}}b_{12} + \frac{\partial L}{\partial c_{13}}b_{13} & \frac{\partial L}{\partial c_{11}}b_{21} + \frac{\partial L}{\partial c_{12}}b_{22} + \frac{\partial L}{\partial c_{13}}b_{23} & \frac{\partial L}{\partial c_{11}}b_{31} + \frac{\partial L}{\partial c_{12}}b_{32} + \frac{\partial L}{\partial c_{13}}b_{33} & \frac{\partial L}{\partial c_{11}}b_{41} + \frac{\partial L}{\partial c_{12}}b_{42} + \frac{\partial L}{\partial c_{13}}b_{43} \\
\frac{\partial L}{\partial c_{21}}b_{11} + \frac{\partial L}{\partial c_{22}}b_{12} + \frac{\partial L}{\partial c_{23}}b_{13} & \frac{\partial L}{\partial c_{21}}b_{21} + \frac{\partial L}{\partial c_{22}}b_{22} + \frac{\partial L}{\partial c_{23}}b_{23} & \frac{\partial L}{\partial c_{21}}b_{31} + \frac{\partial L}{\partial c_{22}}b_{32} + \frac{\partial L}{\partial c_{23}}b_{33} & \frac{\partial L}{\partial c_{21}}b_{41} + \frac{\partial L}{\partial c_{22}}b_{42} + \frac{\partial L}{\partial c_{23}}b_{43}
\end{bmatrix}
\end{aligned} \tag{11}
$$

Equation (11) can be equivalently rewritten as a matrix product.

$$
\frac{\partial L}{\partial \boldsymbol{A}} = \begin{bmatrix} \frac{\partial L}{\partial c_{11}} & \frac{\partial L}{\partial c_{12}} & \frac{\partial L}{\partial c_{13}} \\ \frac{\partial L}{\partial c_{21}} & \frac{\partial L}{\partial c_{22}} & \frac{\partial L}{\partial c_{23}} \end{bmatrix}
\begin{bmatrix} b_{11} & b_{21} & b_{31} & b_{41} \\ b_{12} & b_{22} & b_{32} & b_{42} \\ b_{13} & b_{23} & b_{33} & b_{43} \end{bmatrix} \tag{12}
$$

In fact, the first matrix is the upstream derivative $\frac{\partial L}{\partial \boldsymbol{C}}$ and the second matrix is the transpose of $\boldsymbol{B}$. Then we have

$$
\frac{\partial L}{\partial \boldsymbol{A}} = \frac{\partial L}{\partial \boldsymbol{C}} \boldsymbol{B}^T \tag{13}
$$

Equation (13) shows that, for a matrix multiplication $\boldsymbol{C} = \boldsymbol{A}\boldsymbol{B}$ in a neural network, the derivative of the loss $L$ w.r.t. matrix $\boldsymbol{A}$ equals the upstream derivative $\frac{\partial L}{\partial \boldsymbol{C}}$ times the transpose of matrix $\boldsymbol{B}$.

Let's check the dimensions. On the left-hand side of Equation (13), $\frac{\partial L}{\partial \boldsymbol{A}}$ has dimensions $m \times k$, the same as $\boldsymbol{A}$. On the right-hand side, $\frac{\partial L}{\partial \boldsymbol{C}}$ has dimensions $m \times n$ and $\boldsymbol{B}^T$ has dimensions $n \times k$; therefore, their matrix product is $m \times k$, matching $\frac{\partial L}{\partial \boldsymbol{A}}$.
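The same kind of check works for the whole matrix. This sketch assumes the hypothetical linear loss $L = \sum_{i,j} g_{ij} c_{ij}$ (so $\frac{\partial L}{\partial \boldsymbol{C}} = \boldsymbol{G}$) and compares Equation (13), computed as `G` times the transpose of `B`, with a finite-difference estimate of every element of $\frac{\partial L}{\partial \boldsymbol{A}}$. The integer matrices are arbitrary and chosen so that the comparison is exact.

```python
def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def transpose(M):
    return [list(col) for col in zip(*M)]


def loss(A, B, G):
    # hypothetical linear loss L = sum_ij g_ij c_ij, so dL/dC = G exactly
    C = matmul(A, B)
    return sum(g * c for g_row, c_row in zip(G, C) for g, c in zip(g_row, c_row))


def numerical_grad_A(A, B, G):
    base = loss(A, B, G)
    grad = [[0] * len(A[0]) for _ in A]
    for p in range(len(A)):
        for q in range(len(A[0])):
            A_step = [row[:] for row in A]
            A_step[p][q] += 1            # exact: L is linear in each a_pq
            grad[p][q] = loss(A_step, B, G) - base
    return grad


A = [[1, 2, 3, 4], [5, 6, 7, 8]]                    # m x k = 2 x 4
B = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [2, 2, 2]]    # k x n = 4 x 3
G = [[1, -1, 2], [3, 0, -2]]                        # m x n = 2 x 3

grad_A = matmul(G, transpose(B))     # Equation (13): (m x n)(n x k) -> m x k
print(grad_A)                                # [[5, 1, 2, 4], [-1, -2, 9, 2]]
print(grad_A == numerical_grad_A(A, B, G))   # True
```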

2.4. Derivation of the gradient $\frac{\partial L}{\partial \boldsymbol{B}}$

Similarly, for $\frac{\partial L}{\partial \boldsymbol{B}}$, consider an arbitrary element of $\boldsymbol{B}$, for example ${\color{blue} b_{12}}$. Based on Equation (8) above, the local partial derivatives of the elements of $\boldsymbol{C}$ w.r.t. ${\color{blue} b_{12}}$ are

$$
\begin{aligned}
\frac{\partial c_{11}}{\partial {\color{blue} b_{12}}} &= 0 \\
\frac{\partial c_{12}}{\partial {\color{blue} b_{12}}} &= \frac{\partial}{\partial {\color{blue} b_{12}}}\left(a_{11}{\color{blue} b_{12}} + a_{12}b_{22} + a_{13}b_{32} + a_{14}b_{42}\right) = a_{11} \\
\frac{\partial c_{13}}{\partial {\color{blue} b_{12}}} &= 0 \\
\frac{\partial c_{21}}{\partial {\color{blue} b_{12}}} &= 0 \\
\frac{\partial c_{22}}{\partial {\color{blue} b_{12}}} &= \frac{\partial}{\partial {\color{blue} b_{12}}}\left(a_{21}{\color{blue} b_{12}} + a_{22}b_{22} + a_{23}b_{32} + a_{24}b_{42}\right) = a_{21} \\
\frac{\partial c_{23}}{\partial {\color{blue} b_{12}}} &= 0
\end{aligned} \tag{14}
$$

Using the chain rule, we have the partial derivative of the loss $L$ w.r.t. ${\color{blue} b_{12}}$:

$$
\begin{aligned}
\frac{\partial L}{\partial {\color{blue} b_{12}}} &= \frac{\partial L}{\partial c_{11}}\frac{\partial c_{11}}{\partial {\color{blue} b_{12}}} + \frac{\partial L}{\partial c_{12}}\frac{\partial c_{12}}{\partial {\color{blue} b_{12}}} + \frac{\partial L}{\partial c_{13}}\frac{\partial c_{13}}{\partial {\color{blue} b_{12}}} + \frac{\partial L}{\partial c_{21}}\frac{\partial c_{21}}{\partial {\color{blue} b_{12}}} + \frac{\partial L}{\partial c_{22}}\frac{\partial c_{22}}{\partial {\color{blue} b_{12}}} + \frac{\partial L}{\partial c_{23}}\frac{\partial c_{23}}{\partial {\color{blue} b_{12}}} \\
&= 0 + \frac{\partial L}{\partial c_{12}} a_{11} + 0 + 0 + \frac{\partial L}{\partial c_{22}} a_{21} + 0 \\
&= a_{11}\frac{\partial L}{\partial c_{12}} + a_{21}\frac{\partial L}{\partial c_{22}}
\end{aligned} \tag{15}
$$

The second line of Equation (15) uses the results from Equation (14). In the same manner again, we can derive the other elements of $\frac{\partial L}{\partial \boldsymbol{B}}$:

$$
\begin{aligned}
\frac{\partial L}{\partial \boldsymbol{B}} &= \begin{bmatrix}
\frac{\partial L}{\partial b_{11}} & \frac{\partial L}{\partial b_{12}} & \frac{\partial L}{\partial b_{13}} \\
\frac{\partial L}{\partial b_{21}} & \frac{\partial L}{\partial b_{22}} & \frac{\partial L}{\partial b_{23}} \\
\frac{\partial L}{\partial b_{31}} & \frac{\partial L}{\partial b_{32}} & \frac{\partial L}{\partial b_{33}} \\
\frac{\partial L}{\partial b_{41}} & \frac{\partial L}{\partial b_{42}} & \frac{\partial L}{\partial b_{43}}
\end{bmatrix} \\
&= \begin{bmatrix}
a_{11}\frac{\partial L}{\partial c_{11}} + a_{21}\frac{\partial L}{\partial c_{21}} & a_{11}\frac{\partial L}{\partial c_{12}} + a_{21}\frac{\partial L}{\partial c_{22}} & a_{11}\frac{\partial L}{\partial c_{13}} + a_{21}\frac{\partial L}{\partial c_{23}} \\
a_{12}\frac{\partial L}{\partial c_{11}} + a_{22}\frac{\partial L}{\partial c_{21}} & a_{12}\frac{\partial L}{\partial c_{12}} + a_{22}\frac{\partial L}{\partial c_{22}} & a_{12}\frac{\partial L}{\partial c_{13}} + a_{22}\frac{\partial L}{\partial c_{23}} \\
a_{13}\frac{\partial L}{\partial c_{11}} + a_{23}\frac{\partial L}{\partial c_{21}} & a_{13}\frac{\partial L}{\partial c_{12}} + a_{23}\frac{\partial L}{\partial c_{22}} & a_{13}\frac{\partial L}{\partial c_{13}} + a_{23}\frac{\partial L}{\partial c_{23}} \\
a_{14}\frac{\partial L}{\partial c_{11}} + a_{24}\frac{\partial L}{\partial c_{21}} & a_{14}\frac{\partial L}{\partial c_{12}} + a_{24}\frac{\partial L}{\partial c_{22}} & a_{14}\frac{\partial L}{\partial c_{13}} + a_{24}\frac{\partial L}{\partial c_{23}}
\end{bmatrix}
\end{aligned} \tag{16}
$$

This can be rewritten as a matrix product.

$$
\frac{\partial L}{\partial \boldsymbol{B}} = \begin{bmatrix} a_{11} & a_{21} \\ a_{12} & a_{22} \\ a_{13} & a_{23} \\ a_{14} & a_{24} \end{bmatrix}
\begin{bmatrix} \frac{\partial L}{\partial c_{11}} & \frac{\partial L}{\partial c_{12}} & \frac{\partial L}{\partial c_{13}} \\ \frac{\partial L}{\partial c_{21}} & \frac{\partial L}{\partial c_{22}} & \frac{\partial L}{\partial c_{23}} \end{bmatrix} \tag{17}
$$

In fact, the first matrix is the transpose of $\boldsymbol{A}$ and the second matrix is the upstream derivative $\frac{\partial L}{\partial \boldsymbol{C}}$. Then we have

$$
\frac{\partial L}{\partial \boldsymbol{B}} = \boldsymbol{A}^T \frac{\partial L}{\partial \boldsymbol{C}} \tag{18}
$$

Equation (18) shows that, for a matrix multiplication $\boldsymbol{C} = \boldsymbol{A}\boldsymbol{B}$ in a neural network, the derivative of the loss $L$ w.r.t. matrix $\boldsymbol{B}$ equals the transpose of matrix $\boldsymbol{A}$ times the upstream derivative $\frac{\partial L}{\partial \boldsymbol{C}}$. Let's check the dimensions once more. On the left-hand side of Equation (18), $\frac{\partial L}{\partial \boldsymbol{B}}$ has dimensions $k \times n$, the same as $\boldsymbol{B}$. On the right-hand side, $\boldsymbol{A}^T$ has dimensions $k \times m$ and $\frac{\partial L}{\partial \boldsymbol{C}}$ has dimensions $m \times n$; therefore, their matrix product is $k \times n$, matching $\frac{\partial L}{\partial \boldsymbol{B}}$.
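A matching check for Equation (18), again assuming the hypothetical linear loss $L = \sum_{i,j} g_{ij} c_{ij}$ with arbitrary integer matrices. It also confirms the single-element result of Equation (15): with 0-based indices, ${\color{blue} b_{12}}$ is `B[0][1]`.

```python
def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def transpose(M):
    return [list(col) for col in zip(*M)]


def loss(A, B, G):
    # hypothetical linear loss L = sum_ij g_ij c_ij, so dL/dC = G exactly
    C = matmul(A, B)
    return sum(g * c for g_row, c_row in zip(G, C) for g, c in zip(g_row, c_row))


def numerical_grad_B(A, B, G):
    base = loss(A, B, G)
    grad = [[0] * len(B[0]) for _ in B]
    for p in range(len(B)):
        for q in range(len(B[0])):
            B_step = [row[:] for row in B]
            B_step[p][q] += 1            # exact: L is linear in each b_pq
            grad[p][q] = loss(A, B_step, G) - base
    return grad


A = [[1, 2, 3, 4], [5, 6, 7, 8]]                    # m x k = 2 x 4
B = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [2, 2, 2]]    # k x n = 4 x 3
G = [[1, -1, 2], [3, 0, -2]]                        # m x n = 2 x 3

grad_B = matmul(transpose(A), G)     # Equation (18): (k x m)(m x n) -> k x n
print(grad_B)   # [[16, -1, -8], [20, -2, -8], [24, -3, -8], [28, -4, -8]]
# Equation (15) with 0-based indices: dL/db_12 = a_11 * g_12 + a_21 * g_22
print(grad_B[0][1] == A[0][0] * G[0][1] + A[1][0] * G[1][1])   # True
print(grad_B == numerical_grad_B(A, B, G))                      # True
```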

Again, the derivations above generalize to any matrix multiplication. If you have time, try deriving the general case yourself; just make sure the subscript indices are correct.
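For reference, one way to write the general case in index notation is sketched below. It uses Equation (3) and the Kronecker delta $\delta_{ip}$ (1 if $i = p$, 0 otherwise), from which $\frac{\partial c_{ij}}{\partial a_{pq}} = \delta_{ip} b_{qj}$ and $\frac{\partial c_{ij}}{\partial b_{pq}} = a_{ip}\delta_{jq}$; in each line, $p, q$ range over the rows and columns of the matrix being differentiated.

$$
\begin{aligned}
\frac{\partial L}{\partial a_{pq}} &= \sum_{i=1}^{m}\sum_{j=1}^{n} \frac{\partial L}{\partial c_{ij}} \frac{\partial c_{ij}}{\partial a_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n} \frac{\partial L}{\partial c_{ij}} \delta_{ip} b_{qj} = \sum_{j=1}^{n} \frac{\partial L}{\partial c_{pj}} b_{qj} = \left(\frac{\partial L}{\partial \boldsymbol{C}} \boldsymbol{B}^T\right)_{pq} \\
\frac{\partial L}{\partial b_{pq}} &= \sum_{i=1}^{m}\sum_{j=1}^{n} \frac{\partial L}{\partial c_{ij}} \frac{\partial c_{ij}}{\partial b_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n} \frac{\partial L}{\partial c_{ij}} a_{ip}\delta_{jq} = \sum_{i=1}^{m} a_{ip} \frac{\partial L}{\partial c_{iq}} = \left(\boldsymbol{A}^T \frac{\partial L}{\partial \boldsymbol{C}}\right)_{pq}
\end{aligned}
$$

These are Equations (13) and (18) written element by element.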

3. Custom implementations and validation

With Equations (13) and (18) derived, implementing the backward pass of matrix multiplication is straightforward. See the example implementation on GitHub for a network that simply takes the mean of the matrix product $\boldsymbol{C} = \boldsymbol{A}\boldsymbol{B}$ as the loss. The core part is just three lines of code, as shown below.

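```python
# dL/dC for the mean in Equation (4): every element gets 1/(m*n)
grad_C_manual = (torch.ones(C.shape, dtype=torch.float64) / C.numel())

grad_A_manual = grad_C_manual.mm(B.t())   # Equation (13): dL/dA = dL/dC B^T
grad_B_manual = A.t().mm(grad_C_manual)   # Equation (18): dL/dB = A^T dL/dC
```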

The first line calculates the derivative of the loss w.r.t. $\boldsymbol{C}$ for the mean operation in Equation (4), which serves as the upstream gradient for both $\frac{\partial L}{\partial \boldsymbol{A}}$ and $\frac{\partial L}{\partial \boldsymbol{B}}$.

The second and third lines compute $\frac{\partial L}{\partial \boldsymbol{A}}$ and $\frac{\partial L}{\partial \boldsymbol{B}}$ with the chain rule, based on Equations (13) and (18), respectively. The method `t()` transposes a 2-D tensor, and `mm()` performs matrix multiplication.
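For the mean loss, the upstream gradient and both results can be written out explicitly, which also explains the structure of the gradients the script prints. Differentiating Equation (4) term by term and substituting into Equations (13) and (18) gives the following worked sketch, where $\boldsymbol{1}_{m \times n}$ denotes the all-ones matrix:

$$
\begin{aligned}
\frac{\partial L}{\partial c_{ij}} &= \frac{1}{m \times n} \quad \text{for all } i, j, \qquad \text{i.e.,} \quad \frac{\partial L}{\partial \boldsymbol{C}} = \frac{1}{m \times n}\boldsymbol{1}_{m \times n} \\
\frac{\partial L}{\partial a_{pq}} &= \frac{1}{m \times n}\sum_{j=1}^{n} b_{qj}, \qquad
\frac{\partial L}{\partial b_{pq}} = \frac{1}{m \times n}\sum_{i=1}^{m} a_{ip}
\end{aligned}
$$

The first line is exactly what `torch.ones(C.shape, dtype=torch.float64) / C.numel()` builds. Since $\frac{\partial L}{\partial a_{pq}}$ does not depend on $p$, every row of $\frac{\partial L}{\partial \boldsymbol{A}}$ is the same; since $\frac{\partial L}{\partial b_{pq}}$ does not depend on $q$, each row of $\frac{\partial L}{\partial \boldsymbol{B}}$ is constant.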

To validate our derivations and implementation, we compared these results with those computed by PyTorch's built-in autograd via `loss.backward()`, and they matched.
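Equations (13) and (18) can also be packaged as a custom PyTorch operator, so that autograd calls the hand-written backward pass. The sketch below uses the standard `torch.autograd.Function` interface (static `forward`/`backward` methods and `ctx.save_for_backward`) and checks it with `torch.autograd.gradcheck`, which compares the analytical gradients against finite differences in double precision. The class name `MatMulFunction` is our own and is not taken from the original repository.

```python
import torch


class MatMulFunction(torch.autograd.Function):
    """C = A B with a hand-written backward pass based on Equations (13) and (18)."""

    @staticmethod
    def forward(ctx, A, B):
        # The backward pass needs B for dL/dA and A for dL/dB
        ctx.save_for_backward(A, B)
        return A.mm(B)

    @staticmethod
    def backward(ctx, grad_C):
        A, B = ctx.saved_tensors
        grad_A = grad_C.mm(B.t())   # Equation (13): dL/dA = dL/dC B^T
        grad_B = A.t().mm(grad_C)   # Equation (18): dL/dB = A^T dL/dC
        return grad_A, grad_B       # one gradient per input of forward()


torch.manual_seed(0)
A = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
B = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

# Compare the analytical backward with finite differences
print(torch.autograd.gradcheck(MatMulFunction.apply, (A, B)))   # True

loss = MatMulFunction.apply(A, B).mean()
loss.backward()
print(A.grad.shape, B.grad.shape)   # torch.Size([2, 3]) torch.Size([3, 4])
```

`backward` must return one gradient per input of `forward`, in the same order, which is why it returns the pair `(grad_A, grad_B)`.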

Demo_MatrixMultiplication_backward.py
https://github.com/coolgpu/Demo_Matrix_Multiplication_backward/blob/master/Demo_MatrixMultiplication_backward.py

复制代码
#!/usr/bin/env python
# coding=utf-8

import matplotlib

import torch

print(matplotlib.__version__)

# A is a (MxP) matrix and B is a (PxN) matrix, so C=AxB is a (MxN) matrix

M, P, N = 2, 3, 4

# torch.randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
# Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive).
# requires_grad (bool, optional) - If autograd should record operations on the returned tensor. Default: False.
A = torch.randint(0, 100, (M, P), requires_grad=True, dtype=torch.float64)
B = torch.randint(0, 100, (P, N), requires_grad=True, dtype=torch.float64)

# https://pytorch.org/docs/stable/generated/torch.mm.html
# torch.mm(input, mat2, *, out=None) -> Tensor
# Performs a matrix multiplication of the matrices input and mat2.
# If `input` is a (n×m) tensor, `mat2` is a (m×p) tensor, `out` will be a (n×p) tensor.
# This function does not broadcast. For broadcasting matrix products, see torch.matmul().
C = A.mm(B)

# Tensor.retain_grad() -> None
# Enables this Tensor to have their grad populated during backward(). This is a no-op for leaf tensors.
C.retain_grad()

# calculate the loss simply as the mean of C
# torch.mean(input, *, dtype=None) -> Tensor
# Returns the mean value of all elements in the input tensor. Input must be floating point or complex.
loss = C.mean()
print(f"\nloss = {loss.item()}")

# perform built-in backpropagation
# Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)
# Computes the gradient of current tensor wrt graph leaves.
loss.backward()

print('\nA=\n', A)
print('B=\n', B)
print('C=\n', C)

print('\nbuilt-in dL/dC=\n', C.grad)
print('built-in dL/dA=\n', A.grad)
print('built-in dL/dB=\n', B.grad)
# Tensor.grad
# This attribute is None by default and becomes a Tensor the first time a call to backward() computes gradients for self.
# The attribute will then contain the gradients computed and future calls to backward() will accumulate (add) gradients into it.

# Now perform manual calculation of the gradients dL/dC, dL/dA and dL/dB
grad_C_manual = (torch.ones(C.shape, dtype=torch.float64) / C.numel())
grad_A_manual = grad_C_manual.mm(B.t())
grad_B_manual = A.t().mm(grad_C_manual)

print('\nmanual dL/dC=\n', grad_C_manual)
print('manual dL/dA=\n', grad_A_manual)
print('manual dL/dB=\n', grad_B_manual)

diff_grad_C = grad_C_manual - C.grad
diff_grad_A = grad_A_manual - A.grad
diff_grad_B = grad_B_manual - B.grad

print('\nDifference between custom implementation and Torch built-in:')
print('diff_grad_C max difference:', diff_grad_C.abs().max().detach().numpy())
print('diff_grad_A max difference:', diff_grad_A.abs().max().detach().numpy())
print('diff_grad_B max difference:', diff_grad_B.abs().max().detach().numpy())

print('\nDone!')

/home/yongqiang/miniconda3/bin/python /home/yongqiang/stable_diffusion_work/stable_diffusion_diffusers/yongqiang.py
3.7.1

loss = 6445.375

A=
 tensor([[ 2., 75., 68.],
        [ 1., 44.,  7.]], dtype=torch.float64, requires_grad=True)
B=
 tensor([[31., 37., 26., 41.],
        [72., 37., 76., 47.],
        [74., 76., 89., 75.]], dtype=torch.float64, requires_grad=True)
C=
 tensor([[10494.,  8017., 11804.,  8707.],
        [ 3717.,  2197.,  3993.,  2634.]], dtype=torch.float64,
       grad_fn=<MmBackward0>)

built-in dL/dC=
 tensor([[0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250]], dtype=torch.float64)
built-in dL/dA=
 tensor([[16.8750, 29.0000, 39.2500],
        [16.8750, 29.0000, 39.2500]], dtype=torch.float64)
built-in dL/dB=
 tensor([[ 0.3750,  0.3750,  0.3750,  0.3750],
        [14.8750, 14.8750, 14.8750, 14.8750],
        [ 9.3750,  9.3750,  9.3750,  9.3750]], dtype=torch.float64)

manual dL/dC=
 tensor([[0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250]], dtype=torch.float64)
manual dL/dA=
 tensor([[16.8750, 29.0000, 39.2500],
        [16.8750, 29.0000, 39.2500]], dtype=torch.float64,
       grad_fn=<MmBackward0>)
manual dL/dB=
 tensor([[ 0.3750,  0.3750,  0.3750,  0.3750],
        [14.8750, 14.8750, 14.8750, 14.8750],
        [ 9.3750,  9.3750,  9.3750,  9.3750]], dtype=torch.float64,
       grad_fn=<MmBackward0>)

Difference between custom implementation and Torch built-in:
diff_grad_C max difference: 0.0
diff_grad_A max difference: 0.0
diff_grad_B max difference: 0.0

Done!

Process finished with exit code 0
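Comparing against `loss.backward()` checks the derivation against another implementation of the same rules. A complementary check, not part of the original demo, is a central finite-difference approximation. Nudge each entry of $\boldsymbol{A}$ by $\pm h$, recompute the mean loss, and compare the slope with $\frac{\partial L}{\partial \boldsymbol{C}}\boldsymbol{B}^T$.

The sketch below is plain Python with hypothetical helper names. It uses small random inputs so that floating-point error stays small. Only $\frac{\partial L}{\partial \boldsymbol{A}}$ is checked; $\frac{\partial L}{\partial \boldsymbol{B}}$ works the same way.

```python
import random


def matmul(X, Y):
    """Return X @ Y for matrices stored as lists of rows."""
    return [[sum(X[i][p] * Y[p][j] for p in range(len(Y)))
             for j in range(len(Y[0]))]
            for i in range(len(X))]


def mean_loss(A, B):
    """L = mean of all entries of C = A @ B, as in the demo code."""
    C = matmul(A, B)
    return sum(map(sum, C)) / (len(C) * len(C[0]))


def numerical_grad_A(A, B, h=1e-6):
    """Central-difference estimate of dL/dA, one entry at a time."""
    grad = [[0.0] * len(A[0]) for _ in A]
    for i in range(len(A)):
        for p in range(len(A[0])):
            original = A[i][p]
            A[i][p] = original + h
            loss_plus = mean_loss(A, B)
            A[i][p] = original - h
            loss_minus = mean_loss(A, B)
            A[i][p] = original  # restore the entry before moving on
            grad[i][p] = (loss_plus - loss_minus) / (2 * h)
    return grad


def analytic_grad_A(A, B):
    """dL/dA = dL/dC @ B^T, where dL/dC = 1/(m*n) for the mean loss."""
    m, n = len(A), len(B[0])
    grad_C = [[1.0 / (m * n)] * n for _ in range(m)]
    B_T = [list(col) for col in zip(*B)]
    return matmul(grad_C, B_T)


random.seed(0)
M, P, N = 2, 3, 4
A = [[random.uniform(-1, 1) for _ in range(P)] for _ in range(M)]
B = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(P)]

numerical = numerical_grad_A(A, B)
analytic = analytic_grad_A(A, B)
max_diff = max(abs(x - y)
               for row_n, row_a in zip(numerical, analytic)
               for x, y in zip(row_n, row_a))
print(f"max |numerical - analytic| = {max_diff:.2e}")
```

Because the mean loss is linear in each entry of $\boldsymbol{A}$, the central difference is exact up to rounding. The reported difference should therefore be on the order of floating-point round-off.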

4. Summary

In this post, we demonstrated how to derive the gradients of matrix multiplication in neural networks. While the derivation steps seem complex, the final equations of the gradients are pretty simple and easy to implement:

$$\frac{\partial L}{\partial \boldsymbol{A}} = \frac{\partial L}{\partial \boldsymbol{C}}\boldsymbol{B}^T$$

$$\frac{\partial L}{\partial \boldsymbol{B}} = \boldsymbol{A}^T\frac{\partial L}{\partial \boldsymbol{C}}$$
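These two formulas are all a custom backward pass needs. As a sketch, you can wrap them in a `torch.autograd.Function` and check them with `torch.autograd.gradcheck`, which compares a hand-written backward against finite differences. The class name `MatMul` is our own choice and does not come from the original code.

```python
import torch


class MatMul(torch.autograd.Function):
    """C = A @ B with a hand-written backward pass (illustrative sketch)."""

    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)  # both inputs are needed in backward()
        return A.mm(B)

    @staticmethod
    def backward(ctx, grad_C):
        A, B = ctx.saved_tensors
        grad_A = grad_B = None
        if ctx.needs_input_grad[0]:
            grad_A = grad_C.mm(B.t())  # dL/dA = dL/dC B^T
        if ctx.needs_input_grad[1]:
            grad_B = A.t().mm(grad_C)  # dL/dB = A^T dL/dC
        return grad_A, grad_B


torch.manual_seed(0)
A = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
B = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

# gradcheck compares the hand-written backward with finite differences
# (it expects double-precision inputs); it returns True or raises an error.
ok = torch.autograd.gradcheck(MatMul.apply, (A, B))

# Compare with PyTorch's built-in gradients for loss = mean(A @ B).
custom = torch.autograd.grad(MatMul.apply(A, B).mean(), (A, B))
builtin = torch.autograd.grad(A.mm(B).mean(), (A, B))
print(ok, all(torch.allclose(c, b) for c, b in zip(custom, builtin)))  # True True
```

Checking `ctx.needs_input_grad` skips the unneeded product when only one input requires a gradient. This design choice matters once $\boldsymbol{A}$ or $\boldsymbol{B}$ is a frozen tensor.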

In real neural network applications, the matrices $\boldsymbol{A}$ and $\boldsymbol{B}$ typically come from the outputs of other layers. In those scenarios, the gradients $\frac{\partial L}{\partial \boldsymbol{A}}$ and $\frac{\partial L}{\partial \boldsymbol{B}}$ serve as the upstream gradients of those layers during backpropagation.
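The following pure-Python sketch illustrates how these gradients feed earlier layers. It chains two multiplications: $\boldsymbol{C} = \boldsymbol{A}\boldsymbol{B}$, followed by a hypothetical second layer $\boldsymbol{D} = \boldsymbol{C}\boldsymbol{E}$, with the loss taken as the mean of $\boldsymbol{D}$.

The matrices and helper names are made up for illustration. `Fraction` keeps the arithmetic exact, so the chained result can be compared with `==` against a one-step computation that treats $\boldsymbol{B}\boldsymbol{E}$ as a single matrix.

```python
from fractions import Fraction


def matmul(X, Y):
    """Return X @ Y for matrices stored as lists of rows."""
    return [[sum(X[i][p] * Y[p][j] for p in range(len(Y)))
             for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    """Return the transpose of X."""
    return [list(col) for col in zip(*X)]


def matmul_backward(X, Y, grad_out):
    """For out = X @ Y, return (dL/dX, dL/dY) given the upstream dL/d(out)."""
    return matmul(grad_out, transpose(Y)), matmul(transpose(X), grad_out)


# Hypothetical two-layer chain: C = A @ B, D = C @ E, loss = mean(D).
A = [[1, 2, 0], [3, -1, 2]]        # 2 x 3
B = [[2, 0], [1, 1], [0, 3]]       # 3 x 2
E = [[1, -2, 0, 1], [2, 1, 1, 0]]  # 2 x 4

C = matmul(A, B)
D = matmul(C, E)
m, q = len(D), len(D[0])
grad_D = [[Fraction(1, m * q)] * q for _ in range(m)]  # dL/dD for the mean

# Backward through the second layer first; its input gradient dL/dC ...
grad_C, grad_E = matmul_backward(C, E, grad_D)
# ... is the upstream gradient for the first layer.
grad_A, grad_B = matmul_backward(A, B, grad_C)

# Cross-check: D = A @ (B @ E), so dL/dA can also be obtained in one step.
grad_A_direct, _ = matmul_backward(A, matmul(B, E), grad_D)
print(grad_A == grad_A_direct)  # True
```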

References

[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

[2] Understanding Artificial Neural Networks with Hands-on Experience - Part 1. Matrix Multiplication, Its Gradients and Custom Implementations, https://coolgpu.github.io/coolgpu_blog/github/pages/2020/09/22/matrixmultiplication.html

[3] U-Net: Convolutional Networks for Biomedical Image Segmentation, https://arxiv.org/abs/1505.04597

[4] Deep Residual Learning for Image Recognition, https://arxiv.org/abs/1512.03385
