神经网络后面的层被freeze住,会影响前面的层的梯度吗?

神经网络后面的层被freeze住,会影响前面的层的梯度吗?

答案是不会。

假设一个最简单的神经网络,它只有一个输入 x x x,一个隐藏层神经元 h h h,和一个输出层神经元 y y y,均方差损失 L L L,真实标签 t t t:

h = w 1 ⋅ x y = w 2 ⋅ h L = 1 2 ( y − t ) 2 \begin{gathered} h = w_1 \cdot x \\ y = w_2 \cdot h \\ L=\frac{1}{2}(y-t)^2 \end{gathered} h=w1⋅xy=w2⋅hL=21(y−t)2

以下分 w 2 w_2 w2是否被freeze住,即 w 2 w_2 w2.requires_grad是否为True来讨论。

情况1: w 2 w_2 w2.requires_grad为True

这种情况下, L L L对 w 1 w_1 w1的梯度为:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h

∂ L ∂ y = ∂ ∂ y ( 1 2 ( y − t ) 2 ) = y − t \frac{\partial L}{\partial y}=\frac{\partial}{\partial y}\left(\frac{1}{2}(y-t)^2\right)=y-t ∂y∂L=∂y∂(21(y−t)2)=y−t

∂ y ∂ h = ∂ ∂ h ( w 2 ⋅ h ) = w 2 \frac{\partial y}{\partial h}=\frac{\partial}{\partial h}\left(w_2 \cdot h\right)=w_2 ∂h∂y=∂h∂(w2⋅h)=w2

∂ h ∂ w 1 = ∂ ∂ w 1 ( w 1 ⋅ x ) = x \frac{\partial h}{\partial w_1}=\frac{\partial}{\partial w_1}\left(w_1 \cdot x\right)=x ∂w1∂h=∂w1∂(w1⋅x)=x

因此:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 = ( y − t ) ⋅ w 2 ⋅ x \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} = (y-t) \cdot w_2 \cdot x ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h=(y−t)⋅w2⋅x

情况2: w 2 w_2 w2.requires_grad为False

这种情况下, w 2 w_2 w2被视为一个常数,此时 L L L对 w 1 w_1 w1的梯度仍然为:
∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ w 1 = ( y − t ) ⋅ w 2 ⋅ x \frac{\partial L}{\partial w 1}=\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial w 1} = (y-t) \cdot w_2 \cdot x ∂w1∂L=∂y∂L⋅∂h∂y⋅∂w1∂h=(y−t)⋅w2⋅x

因为无论 w 2 w_2 w2是否被freeze住, ∂ y ∂ h = ∂ ∂ h ( w 2 ⋅ h ) = w 2 \frac{\partial y}{\partial h}=\frac{\partial}{\partial h}\left(w_2 \cdot h\right)=w_2 ∂h∂y=∂h∂(w2⋅h)=w2这一点是不会变的。

在计算 w 1 w_1 w1的梯度时,我们并不需要 w 2 w_2 w2的梯度,而是只需要 w 2 w_2 w2这个参数值。

相关推荐
Light602 小时前
破局而立:制造业软件企业的模式重构与AI赋能新路径
人工智能·云原生·工业软件·商业模式创新·ai赋能·人机协同·制造业软件
Quintus五等升3 小时前
深度学习①|线性回归的实现
人工智能·python·深度学习·学习·机器学习·回归·线性回归
natide3 小时前
text-generateion-webui模型加载器(Model Loaders)选项
人工智能·llama
野生的码农3 小时前
码农的妇产科实习记录
android·java·人工智能
TechubNews3 小时前
2026 年观察名单:基于 a16z「重大构想」,详解稳定币、RWA 及 AI Agent 等 8 大流行趋势
大数据·人工智能·区块链
脑极体3 小时前
机器人的罪与罚
人工智能·机器人
三不原则3 小时前
故障案例:容器启动失败排查(AI运维场景)——从日志分析到根因定位
运维·人工智能·kubernetes
点云SLAM4 小时前
凸优化(Convex Optimization)理论(1)
人工智能·算法·slam·数学原理·凸优化·数值优化理论·机器人应用
会周易的程序员4 小时前
多模态AI 基于工业级编译技术的PLC数据结构解析与映射工具
数据结构·c++·人工智能·单例模式·信息可视化·架构
BlockWay4 小时前
WEEX 成为 LALIGA 西甲联赛香港及台湾地区官方区域合作伙伴
大数据·人工智能·安全