神经网络之链式法则

一、什么是链式法则?

链式法则是微积分中用于处理复合函数求导的基本规则。

🧠 基本思想:

当一个变量依赖另一个变量,而另一个变量又依赖另一个变量时,总的变化率是每一步变化率的乘积


🎯 经典形式(单变量):

设:

z=f(y),y=g(x)⇒z=f(g(x)) z = f(y), \quad y = g(x) \Rightarrow z = f(g(x)) z=f(y),y=g(x)⇒z=f(g(x))

那么链式法则是:

dzdx=dzdy⋅dydx \frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx} dxdz=dydz⋅dxdy

即:

"总导数 = 每层局部导数的乘积"


二、链式法则中:何时是乘法?何时是加法?

链式法则的核心操作是乘法 ,但在计算图(如神经网络)中,有些情况却需要加法 。这两种形式其实不是矛盾,而是出现在不同的结构中。

我们来分别说明:


✅ 1. 使用乘法的场景:函数嵌套

📌 结构:

x→y=g(x)→z=f(y)⇒z=f(g(x)) x \rightarrow y = g(x) \rightarrow z = f(y) \Rightarrow z = f(g(x)) x→y=g(x)→z=f(y)⇒z=f(g(x))

🧮 梯度计算:

dzdx=dzdy⋅dydx \frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx} dxdz=dydz⋅dxdy

✅ 特点:
  • 每个函数的输出作为下一个函数的输入
  • 形成一条链
  • 每一步都"放大"或"缩小"输入的变化
  • 所以梯度乘起来
📎 举例:

z=sin⁡(x2)⇒dzdx=cos⁡(x2)⋅2x z = \sin(x^2) \Rightarrow \frac{dz}{dx} = \cos(x^2) \cdot 2x z=sin(x2)⇒dxdz=cos(x2)⋅2x


✅ 2. 使用加法的场景:分支结构(多路径传播)

📌 结构:

x→f(x)=a,x→g(x)=b,L=a+b⇒L=f(x)+g(x) x \rightarrow f(x) = a,\quad x \rightarrow g(x) = b,\quad L = a + b \Rightarrow L = f(x) + g(x) x→f(x)=a,x→g(x)=b,L=a+b⇒L=f(x)+g(x)

🧮 梯度计算:

dLdx=dadx+dbdx \frac{dL}{dx} = \frac{da}{dx} + \frac{db}{dx} dxdL=dxda+dxdb

✅ 特点:
  • 一个变量 xxx 同时用于多个地方(有多个"输出"分支)
  • 每条路径都有独立的影响
  • 梯度在反向传播时合流 ,进行加法
📎 举例:

L=x2+sin⁡(x)⇒dLdx=2x+cos⁡(x) L = x^2 + \sin(x) \Rightarrow \frac{dL}{dx} = 2x + \cos(x) L=x2+sin(x)⇒dxdL=2x+cos(x)


三、为什么是乘法?为什么是加法?

🔷 为什么是乘法(链式结构)?

源于导数的定义:

dfdx=lim⁡Δx→0ΔfΔx \frac{df}{dx} = \lim_{\Delta x \to 0} \frac{\Delta f}{\Delta x} dxdf=Δx→0limΔxΔf

如果:

x→y→z x \rightarrow y \rightarrow z x→y→z

则:

Δz=dzdy⋅Δy=dzdy⋅dydx⋅Δx \Delta z = \frac{dz}{dy} \cdot \Delta y = \frac{dz}{dy} \cdot \frac{dy}{dx} \cdot \Delta x Δz=dydz⋅Δy=dydz⋅dxdy⋅Δx

所以:

dzdx=dzdy⋅dydx \frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx} dxdz=dydz⋅dxdy

变化被层层放大/缩小 → 所以是乘法


🔶 为什么是加法(分支结构)?

当一个变量影响多个路径时,每条路径都会对最终输出产生一部分影响。

例如:

L=f(x)+g(x) L = f(x) + g(x) L=f(x)+g(x)

那么:

ΔL=Δf+Δg=f′(x)Δx+g′(x)Δx=(f′(x)+g′(x))Δx \Delta L = \Delta f + \Delta g = f'(x)\Delta x + g'(x)\Delta x = (f'(x) + g'(x)) \Delta x ΔL=Δf+Δg=f′(x)Δx+g′(x)Δx=(f′(x)+g′(x))Δx

所以:

dLdx=f′(x)+g′(x) \frac{dL}{dx} = f'(x) + g'(x) dxdL=f′(x)+g′(x)

多个路径独立贡献 → 所以是加法


✅ 总结表格:链式法则中的乘法 vs 加法

项目 使用场景 结构形式 导数公式 原因
🔗 乘法 函数嵌套 z=f(g(x))z = f(g(x))z=f(g(x)) dzdx=f′(g(x))⋅g′(x)\frac{dz}{dx} = f'(g(x)) \cdot g'(x)dxdz=f′(g(x))⋅g′(x) 变化层层放大
➕ 加法 分支结构 L=f(x)+g(x)L = f(x) + g(x)L=f(x)+g(x) dLdx=f′(x)+g′(x)\frac{dL}{dx} = f'(x) + g'(x)dxdL=f′(x)+g′(x) 多路径贡献叠加

✅ 一句话总结

链式法则的本质是"乘法"传播,但当一个变量影响多个路径时,每条路径的梯度会在反向传播时加和** ------ 所以在分支结构中使用"加法"。**

相关推荐
曼城的天空是蓝色的9 小时前
GroupNet:基于多尺度神经网络的交互推理轨迹预测
深度学习·计算机视觉
zl_vslam9 小时前
SLAM中的非线性优-3D图优化之轴角在Opencv-PNP中的应用(一)
前端·人工智能·算法·计算机视觉·slam se2 非线性优化
koo3649 小时前
李宏毅机器学习笔记43
人工智能·笔记·机器学习
lzjava20249 小时前
Spring AI使用知识库增强对话功能
人工智能·python·spring
B站_计算机毕业设计之家9 小时前
深度血虚:Django水果检测识别系统 CNN卷积神经网络算法 python语言 计算机 大数据✅
python·深度学习·计算机视觉·信息可视化·分类·cnn·django
Francek Chen10 小时前
【自然语言处理】预训练05:全局向量的词嵌入(GloVe)
人工智能·pytorch·深度学习·自然语言处理·glove
这张生成的图像能检测吗10 小时前
(论文速读)LyT-Net:基于YUV变压器的轻量级微光图像增强网络
图像处理·人工智能·计算机视觉·低照度
许泽宇的技术分享10 小时前
AI黑客来袭:Strix如何用大模型重新定义渗透测试游戏规则
人工智能
Oxo Security10 小时前
【AI安全】检索增强生成(RAG)
人工智能·安全·网络安全·ai
少林码僧10 小时前
2.3 Transformer 变体与扩展:BERT、GPT 与多模态模型
人工智能·gpt·ai·大模型·bert·transformer·1024程序员节