神经网络后面的层被freeze住,会影响前面的层的梯度吗?

神经网络后面的层被freeze住,会影响前面的层的梯度吗?

答案是不会。

假设一个最简单的神经网络,它只有一个输入 x x x,一个隐藏层神经元 h h h,和一个输出层神经元 y y y,均方差损失 L L L,真实标签 t t t:

h = w 1 ⋅ x y = w 2 ⋅ h L = 1 2 ( y − t ) 2 \begin{gathered} h = w_1 \cdot x \\ y = w_2 \cdot h \\ L=\frac{1}{2}(y-t)^2 \end{gathered} h=w1⋅xy=w2⋅hL=21(y−t)2

以下分 w 2 w_2 w2是否被freeze住,即 w 2 w_2 w2.requires_grad是否为True来讨论。

情况1: w 2 w_2 w2.requires_grad为True

这种情况下, L L L对 w 1 w_1 w1的梯度为:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h

∂ L ∂ y = ∂ ∂ y ( 1 2 ( y − t ) 2 ) = y − t \frac{\partial L}{\partial y}=\frac{\partial}{\partial y}\left(\frac{1}{2}(y-t)^2\right)=y-t ∂y∂L=∂y∂(21(y−t)2)=y−t

∂ y ∂ h = ∂ ∂ h ( w 2 ⋅ h ) = w 2 \frac{\partial y}{\partial h}=\frac{\partial}{\partial h}\left(w_2 \cdot h\right)=w_2 ∂h∂y=∂h∂(w2⋅h)=w2

∂ h ∂ w 1 = ∂ ∂ w 1 ( w 1 ⋅ x ) = x \frac{\partial h}{\partial w_1}=\frac{\partial}{\partial w_1}\left(w_1 \cdot x\right)=x ∂w1∂h=∂w1∂(w1⋅x)=x

因此:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 = ( y − t ) ⋅ w 2 ⋅ x \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} = (y-t) \cdot w_2 \cdot x ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h=(y−t)⋅w2⋅x

情况2: w 2 w_2 w2.requires_grad为False

这种情况下, w 2 w_2 w2被视为一个常数,此时 L L L对 w 1 w_1 w1的梯度仍然为:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 = ( y − t ) ⋅ w 2 ⋅ x \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} = (y-t) \cdot w_2 \cdot x ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h=(y−t)⋅w2⋅x

因为无论 w 2 w_2 w2是否被freeze住, ∂ y ∂ h = ∂ ∂ h ( w 2 ⋅ h ) = w 2 \frac{\partial y}{\partial h}=\frac{\partial}{\partial h}\left(w_2 \cdot h\right)=w_2 ∂h∂y=∂h∂(w2⋅h)=w2这一点是不会变的。

在计算 w 1 w_1 w1的梯度时,我们并不需要 w 2 w_2 w2的梯度,而是只需要 w 2 w_2 w2这个参数值。

相关推荐
serve the people1 分钟前
神经网络中梯度计算求和公式求导问题
神经网络·算法·机器学习
云卓SKYDROID3 分钟前
无人机投屏技术解码过程详解!
人工智能·5g·音视频·无人机·科普·高科技·云卓科技
zy_destiny9 分钟前
【YOLOv12改进trick】三重注意力TripletAttention引入YOLOv12中,实现遮挡目标检测涨点,含创新点Python代码,方便发论文
网络·人工智能·python·深度学习·yolo·计算机视觉·三重注意力
自由的晚风11 分钟前
深度学习在SSVEP信号分类中的应用分析
人工智能·深度学习·分类
大数据追光猿11 分钟前
【大模型技术】LlamaFactory 的原理解析与应用
人工智能·python·机器学习·docker·语言模型·github·transformer
Start_Present24 分钟前
Pytorch 第七回:卷积神经网络——VGG模型
pytorch·python·神经网络·cnn·分类算法
玩电脑的辣条哥27 分钟前
大模型LoRA微调训练原理是什么?
人工智能·lora·微调
极客BIM工作室33 分钟前
DeepSeek V3 源码:从入门到放弃!
人工智能
神秘的土鸡1 小时前
如何在WPS中接入DeepSeek并使用OfficeAI助手(超细!成功版本)
人工智能·机器学习·自然语言处理·数据分析·llama·wps
fydw_7151 小时前
PreTrainedModel 类代码分析:_load_pretrained_model
人工智能·pytorch