LeakyRelu链式法则

python 复制代码
# 分解版
class LeakyRelu:
    # 初始化
    def __init__(self, slope=0.1):
        # α是一个在训练时从一个均匀分布中随机选择的参数,用于控制负数区域的斜率
        self.slope = slope
        self.mask = None
    def forward(self, x):
        self.mask = (x <= 0)
        y = x.copy() # 复制输入数据,避免修改原数据
        y[self.mask] = self.slope * x[self.mask]
        return y
    def backward(self, dy):
        dx = dy.copy()
        # 将x<=0的值都赋为slope * x[self.mask]
        dx[self.mask] = self.slope * dx[self.mask]
        return dx
# np版
class LeakyRelu:
    # 初始化
    def __init__(self, slope=0.1):
        # α是一个在训练时从一个均匀分布中随机选择的参数,用于控制负数区域的斜率
        self.slope = slope
        self.mask = None
    
    def forward(self, x):
        self.mask = (x <= 0)
        # 使用np.where更高效,避免复制整个数组后再修改
        return np.where(self.mask, self.slope * x, x)
    
    def backward(self, dy):
        # 使用np.where更高效,LeakyReLU的导数是:正数区域为1,负数区域为slope
        return np.where(self.mask, self.slope * dy, dy)

链式法则

假设网络结构是:x → LeakyReLU → y → Loss

前向传播

复制代码
x = -2
y = LeakyReLU(-2) = slope * (-2) = -0.2
Loss = (y - 1)² = 1.44

反向传播要求什么?

我们要计算 ∂Loss/∂x,也就是"当 x 变化一点点时,Loss 会变化多少"。

链式法则的应用

复制代码
∂Loss/∂x = ∂Loss/∂y × ∂y/∂x
           ↑           ↑
        上游梯度    当前层的局部梯度

具体计算

  1. 上游传来的梯度(从 Loss 传来):

    复制代码
    ∂Loss/∂y = 2(y - 1) = 2(-0.2 - 1) = -2.4
  2. LeakyReLU 的局部梯度(导数):

    复制代码
    ∂y/∂x = slope = 0.1  (因为 x <= 0)
  3. 最终得到

    复制代码
    ∂Loss/∂x = (-2.4) × 0.1 = -0.24

代码对应关系

python 复制代码
def backward(self, dy):        # dy 就是上游传来的 ∂Loss/∂y = -2.4
    dx = dy.copy()             # 先复制上游梯度
    dx[self.mask] = self.slope * dx[self.mask]  # 乘以局部梯度
    return dx                  # 返回 ∂Loss/∂x

所以dx(返回值)才是真正的"导数"(∂Loss/∂x),dy(输入)只是上游传过来的梯度,两者不是同一个东西。

相关推荐
阿正的梦工坊18 分钟前
深入理解 PyTorch 中的 unsqueeze 操作
人工智能·pytorch·python
FreakStudio1 小时前
硬件版【Cursor】?aily blockly IDE尝鲜封神,实战硬伤尽显
python·单片机·嵌入式·大学生·面向对象·并行计算·电子diy·电子计算机
郝学胜-神的一滴3 小时前
Qt 入门 01-01:从零基础到商业级客户端实战
开发语言·c++·qt·程序人生·软件构建
测试员周周3 小时前
【Appium 系列】第06节-页面对象实现 — LoginPage 实战
开发语言·前端·人工智能·python·功能测试·appium·测试用例
摇滚侠3 小时前
@Autowired 和 @Resource 的区别
java·开发语言
2301_783848653 小时前
优化文本分类中堆叠模型的网格搜索性能:避免训练卡顿的实战指南
jvm·数据库·python
Wy_编程3 小时前
go语言中的结构体
开发语言·后端·golang
SeaTunnel3 小时前
(八)收官篇 | 数据平台最后一公里:数据集成开发设计与上线治理实战
java·大数据·开发语言·白鲸开源
tzc_fly4 小时前
AnisoAlign:各向异性模态对齐
人工智能·深度学习·机器学习