神经网络后面的层被freeze住,会影响前面的层的梯度吗?

神经网络后面的层被freeze住,会影响前面的层的梯度吗?

答案是不会。

假设一个最简单的神经网络,它只有一个输入 x x x,一个隐藏层神经元 h h h,和一个输出层神经元 y y y,均方差损失 L L L,真实标签 t t t:

h = w 1 ⋅ x y = w 2 ⋅ h L = 1 2 ( y − t ) 2 \begin{gathered} h = w_1 \cdot x \\ y = w_2 \cdot h \\ L=\frac{1}{2}(y-t)^2 \end{gathered} h=w1⋅xy=w2⋅hL=21(y−t)2

以下分 w 2 w_2 w2是否被freeze住,即 w 2 w_2 w2.requires_grad是否为True来讨论。

情况1: w 2 w_2 w2.requires_grad为True

这种情况下, L L L对 w 1 w_1 w1的梯度为:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h

∂ L ∂ y = ∂ ∂ y ( 1 2 ( y − t ) 2 ) = y − t \frac{\partial L}{\partial y}=\frac{\partial}{\partial y}\left(\frac{1}{2}(y-t)^2\right)=y-t ∂y∂L=∂y∂(21(y−t)2)=y−t

∂ y ∂ h = ∂ ∂ h ( w 2 ⋅ h ) = w 2 \frac{\partial y}{\partial h}=\frac{\partial}{\partial h}\left(w_2 \cdot h\right)=w_2 ∂h∂y=∂h∂(w2⋅h)=w2

∂ h ∂ w 1 = ∂ ∂ w 1 ( w 1 ⋅ x ) = x \frac{\partial h}{\partial w_1}=\frac{\partial}{\partial w_1}\left(w_1 \cdot x\right)=x ∂w1∂h=∂w1∂(w1⋅x)=x

因此:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 = ( y − t ) ⋅ w 2 ⋅ x \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} = (y-t) \cdot w_2 \cdot x ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h=(y−t)⋅w2⋅x

情况2: w 2 w_2 w2.requires_grad为False

这种情况下, w 2 w_2 w2被视为一个常数,此时 L L L对 w 1 w_1 w1的梯度仍然为:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 = ( y − t ) ⋅ w 2 ⋅ x \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} = (y-t) \cdot w_2 \cdot x ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h=(y−t)⋅w2⋅x

因为无论 w 2 w_2 w2是否被freeze住, ∂ y ∂ h = ∂ ∂ h ( w 2 ⋅ h ) = w 2 \frac{\partial y}{\partial h}=\frac{\partial}{\partial h}\left(w_2 \cdot h\right)=w_2 ∂h∂y=∂h∂(w2⋅h)=w2这一点是不会变的。

在计算 w 1 w_1 w1的梯度时,我们并不需要 w 2 w_2 w2的梯度,而是只需要 w 2 w_2 w2这个参数值。

相关推荐
Generalzy4 分钟前
langchain deepagent框架
人工智能·python·langchain
人工智能培训11 分钟前
10分钟了解向量数据库(4)
人工智能·机器学习·数据挖掘·深度学习入门·深度学习证书·ai培训证书·ai工程师证书
无忧智库17 分钟前
从“数据孤岛”到“城市大脑”:深度拆解某智慧城市“十五五”数字底座建设蓝图
人工智能·智慧城市
Rui_Freely19 分钟前
Vins-Fusion之 SFM准备篇(十二)
人工智能·算法·计算机视觉
hugerat21 分钟前
在AI的帮助下,用C++构造微型http server
linux·c++·人工智能·http·嵌入式·嵌入式linux
绿洲-_-27 分钟前
MBHM_DATASET_GUIDE
深度学习·机器学习
AI街潜水的八角28 分钟前
深度学习洪水分割系统2:含训练测试代码和数据集
人工智能·深度学习
万行31 分钟前
机器学习&第二章线性回归
人工智能·python·机器学习·线性回归
小宇的天下1 小时前
HBM(高带宽内存)深度解析:先进封装视角的技术指南
网络·人工智能
rongcj1 小时前
2026,“硅基经济”的时代正在悄然来临
人工智能