LeakyRelu链式法则

python 复制代码
# 分解版
class LeakyRelu:
    # 初始化
    def __init__(self, slope=0.1):
        # α是一个在训练时从一个均匀分布中随机选择的参数,用于控制负数区域的斜率
        self.slope = slope
        self.mask = None
    def forward(self, x):
        self.mask = (x <= 0)
        y = x.copy() # 复制输入数据,避免修改原数据
        y[self.mask] = self.slope * x[self.mask]
        return y
    def backward(self, dy):
        dx = dy.copy()
        # 将x<=0的值都赋为slope * x[self.mask]
        dx[self.mask] = self.slope * dx[self.mask]
        return dx
# np版
class LeakyRelu:
    # 初始化
    def __init__(self, slope=0.1):
        # α是一个在训练时从一个均匀分布中随机选择的参数,用于控制负数区域的斜率
        self.slope = slope
        self.mask = None
    
    def forward(self, x):
        self.mask = (x <= 0)
        # 使用np.where更高效,避免复制整个数组后再修改
        return np.where(self.mask, self.slope * x, x)
    
    def backward(self, dy):
        # 使用np.where更高效,LeakyReLU的导数是:正数区域为1,负数区域为slope
        return np.where(self.mask, self.slope * dy, dy)

链式法则

假设网络结构是:x → LeakyReLU → y → Loss

前向传播

复制代码
x = -2
y = LeakyReLU(-2) = slope * (-2) = -0.2
Loss = (y - 1)² = 1.44

反向传播要求什么?

我们要计算 ∂Loss/∂x,也就是"当 x 变化一点点时,Loss 会变化多少"。

链式法则的应用

复制代码
∂Loss/∂x = ∂Loss/∂y × ∂y/∂x
           ↑           ↑
        上游梯度    当前层的局部梯度

具体计算

  1. 上游传来的梯度(从 Loss 传来):

    复制代码
    ∂Loss/∂y = 2(y - 1) = 2(-0.2 - 1) = -2.4
  2. LeakyReLU 的局部梯度(导数):

    复制代码
    ∂y/∂x = slope = 0.1  (因为 x <= 0)
  3. 最终得到

    复制代码
    ∂Loss/∂x = (-2.4) × 0.1 = -0.24

代码对应关系

python 复制代码
def backward(self, dy):        # dy 就是上游传来的 ∂Loss/∂y = -2.4
    dx = dy.copy()             # 先复制上游梯度
    dx[self.mask] = self.slope * dx[self.mask]  # 乘以局部梯度
    return dx                  # 返回 ∂Loss/∂x

所以dx(返回值)才是真正的"导数"(∂Loss/∂x),dy(输入)只是上游传过来的梯度,两者不是同一个东西。

相关推荐
.道阻且长.5 小时前
C++ string 操作指南:接口解析
java·c语言·开发语言·c++
蚰蜒螟5 小时前
Java 对象的内存密语:从字段偏移量计算到 Unsafe 访问的完整链路
java·开发语言
星辰_mya6 小时前
CountDownLatch深度解析
java·开发语言·后端·架构
laplaya6 小时前
使用 vcpkg 管理 C++ 项目中的依赖
开发语言·c++
feixing_fx6 小时前
选择器的威力——深入理解优先级计算与层叠规则
开发语言·前端·css·前端框架·html
极光代码工作室6 小时前
基于深度学习的手写数字识别系统
人工智能·python·深度学习·神经网络·机器学习
6v6-博客6 小时前
C语言字符串中空格的表示方法
c语言·开发语言
geovindu6 小时前
python: speech to text offline
开发语言·python·语音识别
AI创界者6 小时前
告别云端限制!Sulphur 2 本地文生视频/图生视频整合包,本地部署,解压即用,保姆级部署与工作流实战
人工智能·python·aigc·音视频
于指尖飞舞6 小时前
java后端面试题(多线程极简)
java·开发语言