深度学习 反向传播与计算图实验

深度学习 反向传播与计算图实验笔记

这份笔记将按照实验模块,帮你梳理 计算图原理、链式法则、反向传播流程和实例解析,和之前的实验笔记保持统一风格,方便你对照复习。


一、实验工具与依赖导入

1. 依赖说明

表格

库 / 模块 作用
sympy 符号计算库,用于符号化求导、验证链式法则
numpy as np 数值计算库,实现数值求导、验证反向传播结果
matplotlib.pyplot 数据可视化库,绘制计算图、展示链式变化
matplotlib.widgets 交互式组件,用于实验中的输入与交互
lab_utils_backprop 课程自定义工具,提供计算图绘制函数

2. 完整导入代码

python

运行

复制代码
from sympy import *
import numpy as np
import re
%matplotlib widget
import matplotlib.pyplot as plt
from matplotlib.widgets import TextBox, Button
import ipywidgets as widgets
from lab_utils_backprop import *

二、计算图与链式法则基础

1. 核心概念

  • 计算图:将复杂的数学表达式拆分为多个简单节点(如加法、乘法、平方),通过节点间的数据流表示计算过程,简化反向传播的导数计算。
  • 链式法则:若 \(J = f(a)\) 且 \(a = g(w)\),则 \(\frac{\partial J}{\partial w} = \frac{\partial J}{\partial a} \times \frac{\partial a}{\partial w}\),即导数的 "连乘" 传递。

2. 示例 1:\(J=(2+3w)^2\) 的反向传播

(1)前向传播

将表达式拆分为两个节点:

  • 节点 1:\(a = 2 + 3w\)
  • 节点 2:\(J = a^2\)

当 \(w=3\) 时:

python

运行

复制代码
w = 3
a = 2 + 3 * w
J = a ** 2
print(f"a = {a}, J = {J}")

输出:a = 11, J = 121

(2)反向传播:分步求导
  • 步骤 1:求 \(\frac{\partial J}{\partial a}\)

    • 符号计算:\(J=a^2\),导数为 \(\frac{\partial J}{\partial a} = 2a\)

    • 数值验证: python

      运行

      复制代码
      a_epsilon = a + 0.001
      J_epsilon = a_epsilon ** 2
      k = (J_epsilon - J) / 0.001
      print(f"dJ/da ≈ {k}")  # 输出 ≈ 22,即 2×11
  • 步骤 2:求 \(\frac{\partial a}{\partial w}\)

    • 符号计算:\(a=2+3w\),导数为 \(\frac{\partial a}{\partial w} = 3\)

    • 数值验证: python

      运行

      复制代码
      w_epsilon = w + 0.001
      a_epsilon = 2 + 3 * w_epsilon
      k = (a_epsilon - a) / 0.001
      print(f"da/dw ≈ {k}")  # 输出 ≈ 3
  • 步骤 3:链式法则合并\(\frac{\partial J}{\partial w} = \frac{\partial J}{\partial a} \times \frac{\partial a}{\partial w} = 2a \times 3 = 6a\)当 \(a=11\) 时,\(\frac{\partial J}{\partial w} = 66\)。

(3)计算图填充结果
  • 蓝色方框(前向值):第一个方框填 11,第二个方框填 121
  • 绿色方框(导数):\(\frac{\partial a}{\partial w}=3\),\(\frac{\partial J}{\partial a}=22\),\(\frac{\partial J}{\partial w}=66\)

三、示例 2:简单神经网络的计算图与反向传播

1. 表达式与节点拆分

给定参数:\(w=-2, b=8, x=2, y=1\),目标函数:\(J = \frac{1}{2}(a - y)^2, \quad a = wx + b\)

将表达式拆分为 4 个节点:

  1. 节点 1:\(c = wx\)
  2. 节点 2:\(a = c + b\)
  3. 节点 3:\(d = a - y\)
  4. 节点 4:\(J = \frac{1}{2}d^2\)

2. 前向传播计算

python

运行

复制代码
w, b, x, y = -2, 8, 2, 1
c = w * x
a = c + b
d = a - y
J = 0.5 * (d ** 2)
print(f"c={c}, a={a}, d={d}, J={J}")

输出:c=-4, a=4, d=3, J=4.5

3. 反向传播:分步求导与链式法则

(1)从右往左求导
  • **节点 4:\(J = \frac{1}{2}d^2\)**导数:\(\frac{\partial J}{\partial d} = d\),代入 \(d=3\),得 \(\frac{\partial J}{\partial d}=3\)

  • **节点 3:\(d = a - y\)**导数:\(\frac{\partial d}{\partial a} = 1\),由链式法则得:\(\frac{\partial J}{\partial a} = \frac{\partial J}{\partial d} \times \frac{\partial d}{\partial a} = 3 \times 1 = 3\)

  • **节点 2:\(a = c + b\)**导数:\(\frac{\partial a}{\partial c} = 1\),\(\frac{\partial a}{\partial b} = 1\),由链式法则得:\(\frac{\partial J}{\partial c} = \frac{\partial J}{\partial a} \times \frac{\partial a}{\partial c} = 3 \times 1 = 3\)\(\frac{\partial J}{\partial b} = \frac{\partial J}{\partial a} \times \frac{\partial a}{\partial b} = 3 \times 1 = 3\)

  • **节点 1:\(c = wx\)**导数:\(\frac{\partial c}{\partial w} = x\),代入 \(x=2\),得 \(\frac{\partial c}{\partial w}=2\),由链式法则得:\(\frac{\partial J}{\partial w} = \frac{\partial J}{\partial c} \times \frac{\partial c}{\partial w} = 3 \times 2 = 6\)

(2)计算图填充结果
  • 蓝色方框(前向值):

    1. \(c=wx=-4\)
    2. \(a=c+b=4\)
    3. \(d=a-y=3\)
    4. \(J=\frac{1}{2}d^2=4.5\)
  • 绿色方框(导数):

    1. \(\frac{\partial c}{\partial w}=2\),\(\frac{\partial J}{\partial w}=6\)
    2. \(\frac{\partial a}{\partial c}=1\),\(\frac{\partial J}{\partial c}=3\);\(\frac{\partial a}{\partial b}=1\),\(\frac{\partial J}{\partial b}=3\)
    3. \(\frac{\partial d}{\partial a}=1\),\(\frac{\partial J}{\partial a}=3\)
    4. \(\frac{\partial J}{\partial d}=3\);d 增加 \(0.001\),J 增加约 \(3×0.001=0.003\)

四、核心知识点总结

1. 计算图的作用

  • 将复杂导数计算拆解为简单节点,每个节点的导数计算独立且易验证;
  • 清晰展示链式法则的传递过程,直观理解反向传播中梯度如何从损失函数传递到各参数。

2. 反向传播的关键步骤

  1. 前向传播:按计算图从左到右计算所有节点的输出值;
  2. 反向传播:从损失函数节点开始,按链式法则从右往左计算每个节点的梯度;
  3. 参数更新:利用参数的梯度更新参数值(梯度下降)。

3. 链式法则的本质

梯度在计算图中按 "路径" 传递,每经过一个节点,就乘以该节点的局部导数,最终得到参数的梯度。


需要我帮你整理一份反向传播的分步练习清单,包含更多不同结构的计算图题目,方便你巩固链式法则的应用吗?

深度学习 反向传播与链式法则实验笔记(续)

这份笔记将承接上一部分的计算图实验,完整梳理前向传播、反向传播分步推导、链式法则验证和代码实现,和之前的笔记保持统一风格,方便你对照复习。


一、前向传播:完整计算流程

1. 输入与参数定义

python

运行

复制代码
# 输入和参数
x = 2
w = -2
b = 8
y = 1

2. 分步前向计算

python

运行

复制代码
# 分步计算每个节点的值
c = w * x       # 节点1:乘法
a = c + b       # 节点2:加法
d = a - y       # 节点3:减法
J = d**2 / 2    # 节点4:平方+除以2

print(f"J={J}, d={d}, a={a}, c={c}")

输出结果:

  • c = -4
  • a = 4
  • d = 3
  • J = 4.5

二、反向传播:从损失函数开始,从右往左求导

反向传播的核心逻辑:从损失函数节点开始,按链式法则,从右往左依次计算每个节点的局部导数,再乘上右侧节点传来的梯度

步骤 1:节点 4:\(J = \frac{1}{2}d^2\),求 \(\frac{\partial J}{\partial d}\)

(1)符号求导

\(\frac{\partial J}{\partial d} = d\)

(2)数值验证(微小变化法)

python

运行

复制代码
d_epsilon = d + 0.001
J_epsilon = d_epsilon**2 / 2
k = (J_epsilon - J) / 0.001  # 差商近似导数
print(f"dJ/dd ≈ {k}")

输出:dJ/dd ≈ 3.0004999999997395,和符号求导结果 d=3 一致。


步骤 2:节点 3:\(d = a - y\),求 \(\frac{\partial J}{\partial a}\)

(1)先求局部导数 \(\frac{\partial d}{\partial a}\)

符号求导:\(\frac{\partial d}{\partial a} = 1\)

数值验证:

python

运行

复制代码
a_epsilon = a + 0.001
d_epsilon = a_epsilon - y
k = (d_epsilon - d) / 0.001
print(f"dd/da ≈ {k}")

输出:dd/da ≈ 1.000000000000334,和符号结果一致。

(2)链式法则合并

\(\frac{\partial J}{\partial a} = \frac{\partial J}{\partial d} \times \frac{\partial d}{\partial a} = d \times 1 = d\)

当 \(d=3\) 时,\(\frac{\partial J}{\partial a} = 3\)。

数值验证:

python

运行

复制代码
a_epsilon = a + 0.001
d_epsilon = a_epsilon - y
J_epsilon = d_epsilon**2 / 2
k = (J_epsilon - J) / 0.001
print(f"dJ/da ≈ {k}")

输出:dJ/da ≈ 3.00050000000006277,和链式法则结果一致。


步骤 3:节点 2:\(a = c + b\),求 \(\frac{\partial J}{\partial c}\) 和 \(\frac{\partial J}{\partial b}\)

(1)局部导数计算
  • \(\frac{\partial a}{\partial c} = 1\)
  • \(\frac{\partial a}{\partial b} = 1\)
(2)链式法则合并

\(\frac{\partial J}{\partial c} = \frac{\partial J}{\partial a} \times \frac{\partial a}{\partial c} = 3 \times 1 = 3\)\(\frac{\partial J}{\partial b} = \frac{\partial J}{\partial a} \times \frac{\partial a}{\partial b} = 3 \times 1 = 3\)


步骤 4:节点 1:\(c = w \times x\),求 \(\frac{\partial J}{\partial w}\)

(1)局部导数计算

符号求导:\(\frac{\partial c}{\partial w} = x\)在本例子中,\(x=2\),所以 \(\frac{\partial c}{\partial w} = 2\)。

(2)链式法则合并

\(\frac{\partial J}{\partial w} = \frac{\partial J}{\partial c} \times \frac{\partial c}{\partial w} = 3 \times 2 = 6\)

数值验证:

python

运行

复制代码
w_epsilon = w + 0.001
J_epsilon = ((w_epsilon * x + b) - y)**2 / 2
k = (J_epsilon - J) / 0.001
print(f"dJ/dw ≈ {k}")

输出:dJ/dw ≈ 6.001999999999619,和链式法则结果一致。


三、反向传播的通用步骤总结

  1. 前向传播:按计算图从左到右计算所有节点的输出值,保存每个节点的中间结果;
  2. 反向传播初始化:从损失函数节点开始,初始化梯度为 1(或根据损失函数直接求导);
  3. 从右往左遍历节点
    • 计算当前节点对输入的局部导数
    • 用局部导数乘上右侧节点传来的梯度,得到当前节点输入的梯度;
  4. 参数更新:最终得到的参数梯度(如 \(\frac{\partial J}{\partial w}\)、\(\frac{\partial J}{\partial b}\))用于梯度下降更新参数。

四、关键知识点补充

1. 链式法则的本质

梯度在计算图中按 "路径" 传递,每经过一个节点,就乘以该节点的局部导数,最终得到参数的梯度。例如:\(\frac{\partial J}{\partial w} = \frac{\partial J}{\partial d} \times \frac{\partial d}{\partial a} \times \frac{\partial a}{\partial c} \times \frac{\partial c}{\partial w}\)

2. 数值验证的意义

用 "微小变化法" 验证导数,是为了直观理解导数的定义:导数是 "输入微小变化时,输出的变化率",也可以验证链式法则的正确性。

3. 为什么不反向传播到输入 x?

在这个实验中,x 是固定的输入数据,不是需要优化的参数,因此我们只需要计算损失函数对参数 w 和 b 的梯度,不需要计算对 x 的梯度。


五、完整代码汇总

python

运行

复制代码
from sympy import *
import numpy as np

# 1. 输入与参数
x = 2
w = -2
b = 8
y = 1

# 2. 前向传播
c = w * x
a = c + b
d = a - y
J = d**2 / 2
print(f"前向结果:J={J}, d={d}, a={a}, c={c}")

# 3. 反向传播分步求导
# 节点4:J = 0.5*d^2
def dJ_dd(d):
    return d
print(f"dJ/dd = {dJ_dd(d)}")

# 节点3:d = a - y
def dd_da():
    return 1
dJ_da = dJ_dd(d) * dd_da()
print(f"dJ/da = {dJ_da}")

# 节点2:a = c + b
def da_dc():
    return 1
def da_db():
    return 1
dJ_dc = dJ_da * da_dc()
dJ_db = dJ_da * da_db()
print(f"dJ/dc = {dJ_dc}, dJ/db = {dJ_db}")

# 节点1:c = w*x
def dc_dw(x):
    return x
dJ_dw = dJ_dc * dc_dw(x)
print(f"dJ/dw = {dJ_dw}")

要不

相关推荐
Tisfy1 小时前
LeetCode 1914.循环轮转矩阵:大模拟(数组原地轮转) —— 附O(1)空间版本
算法·leetcode·矩阵·大模拟
Hello.Reader1 小时前
算法基础(三)—— 插入排序从整理扑克牌到有序数组
java·算法·排序算法
50万马克的面包1 小时前
C语言:三大基础排序算法模板 冒泡 / 选择 / 插入)
c语言·笔记·算法·排序算法
罗超驿1 小时前
3.快乐数专题学习笔记——双指针法在LeetCode 202题中的应用
java·算法·leetcode·职场和发展
无限进步_1 小时前
【C++】深入底层:自己动手实现一个哈希表
开发语言·数据结构·c++·算法·链表·散列表·visual studio
_深海凉_1 小时前
LeetCode热题100-小于 n 的最大数(字节高频题)
算法·leetcode·职场和发展
小雅痞1 小时前
[Java][Leetcode middle] 36. 有效的数独
java·算法·leetcode
paeamecium1 小时前
【PAT甲级真题】- General Palindromic Number(20)
数据结构·c++·算法·pat考试·pat
北顾笙9802 小时前
day43-数据结构力扣
数据结构·算法·leetcode