【深度学习】—— 自动微分、非标量变量的反向传播、 分离计算、 Python控制流的梯度计算

【深度学习】------ 自动微分

自动微分

求导是⼏乎所有深度学习优化算法的关键步骤。虽然求导的计算很简单,只需要⼀些基本的微积分。但对于复杂的模型,⼿⼯进⾏更新是⼀件很痛苦的事情(⽽且经常容易出错)。深度学习框架通过⾃动计算导数,即⾃动微分(automatic differentiation)来加快求导。实际中,根据我们设计的模型,系统会构建⼀个计算图(computational graph),来跟踪计算是哪些数据通过哪些操作组合起来产⽣输出。⾃动微分使系统能够随后反向传播梯度。这⾥,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

一个简单的例子

作为一个演示例子,假设我们想对函数 y = 2 x ⊤ x y = 2\mathbf{x}^\top \mathbf{x} y=2x⊤x 关于列向量 x \mathbf{x} x 求导。首先,我们创建列向量 x \mathbf{x} x 并为其分配一个初始值。

python 复制代码
import torch
x = torch.arange(4.0).view(-1, 1)  # 将 x 转换为列向量
x

输出:

复制代码
tensor([0., 1., 2., 3.])

在我们计算 y y y 关于 x \mathbf{x} x 的梯度之前,我们需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。注意,一个标量函数关于向量 x \mathbf{x} x 的梯度是向量,并且与 x \mathbf{x} x 具有相同的形状。

python 复制代码
x.requires_grad_(True)  # 等价于 x = torch.arange(4.0).view(-1, 1).requires_grad_(True)
x.grad  # 默认值是 None

现在让我们计算 y y y。

python 复制代码
y = 2 * torch.dot(x, x)  
y

输出:

复制代码
tensor(28., grad_fn=<MulBackward0>)

x \mathbf{x} x 是一个 4×1 的列向量,计算 x ⊤ x \mathbf{x}^\top \mathbf{x} x⊤x 得到一个标量输出。接下来,我们通过调用反向传播函数来自动计算 y y y 关于 x \mathbf{x} x 每个分量的梯度,并打印这些梯度。

python 复制代码
y.backward()
x.grad

输出:

复制代码
tensor([ 0.,  4.,  8., 12.])

函数 y = 2 x ⊤ x y = 2\mathbf{x}^\top \mathbf{x} y=2x⊤x 关于 x \mathbf{x} x 的梯度应为 4 x 4\mathbf{x} 4x。让我们快速验证这个梯度是否计算正确。

python 复制代码
x.grad == 4 * x

输出:

复制代码
tensor([[True],
        [True],
        [True],
        [True]])

现在让我们计算 x \mathbf{x} x 的另一个函数。

python 复制代码
# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad

输出:

复制代码
tensor([[1.],
        [1.],
        [1.],
        [1.]])

非标量变量的反向传播

当 y y y 不是标量时,向量 y \mathbf{y} y 关于向量 x \mathbf{x} x 的导数的最自然解释是一个矩阵。对于高阶和高维的 y \mathbf{y} y 和 x \mathbf{x} x,求导的结果可以是一个高阶张量。

然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括深度学习中),但当我们调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。在这里,我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。

对非标量调用 backward 需要传入一个 gradient 参数,该参数指定微分函数关于 self 的梯度

在我们的例子中,我们只想求偏导数的和,所以传递一个 1 的梯度是合适的。

python 复制代码
x.grad.zero_()
y = x * x
# 等价于 y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad

分离计算

有时,我们希望将某些计算移动到记录的计算图之外。例如,假设 y y y 是作为 x x x 的函数计算的,而 z z z 则是作为 y y y 和 x x x 的函数计算的。想象一下,我们想计算 z z z 关于 x x x 的梯度,但由于某种原因,我们希望将 y y y 视为一个常数,并且只考虑到 x x x 在 y y y 被计算后发挥的作用。

在这里,我们可以分离 y y y 来返回一个新变量 u u u,该变量与 y y y 具有相同的值,但丢弃计算图中如何计算 y y y 的任何信息。换句话说,梯度不会向后流经 u u u 到 x x x。因此,下面的反向传播函数计算 z = u ⋅ x z = u \cdot x z=u⋅x 关于 x x x 的偏导数,同时将 u u u 作为常数处理,而不是 z = x ⋅ x ⋅ x z = x \cdot x \cdot x z=x⋅x⋅x 关于 x x x 的偏导数。

python 复制代码
x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u

输出:

复制代码
tensor([True, True, True, True])

由于记录了 y y y 的计算结果,我们可以随后在 y y y 上调用反向传播,得到 y = x ⋅ x y = x \cdot x y=x⋅x 关于 x x x 的导数,即 2 ⋅ x 2 \cdot x 2⋅x。

python 复制代码
x.grad.zero_()
y.sum().backward()
x.grad == 2 * x

输出:

复制代码
tensor([True, True, True, True])

以下是根据您提供的内容整理成的 Markdown 源码,使用 LaTeX $ 包裹公式和符号,并确保代码格式正确:

Python控制流的梯度计算

使用自动微分的一个好处是:即使构建函数的计算图需要通过 Python 控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。在下面的代码中,while 循环的迭代次数和 if 语句的结果都取决于输入 a a a 的值。

python 复制代码
def f(a):
    b = a * 2
    while b.norm() < 1000:
        b = b * 2
    if b.sum() > 0:
        c = b
    else:
        c = 100 * b
    return c

让我们计算梯度。

python 复制代码
a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

我们现在可以分析上面定义的 f f f 函数。请注意,它在其输入 a a a 中是分段线性的。换言之,对于任何 a a a,存在某个常量标量 k k k,使得 f ( a ) = k ⋅ a f(a) = k \cdot a f(a)=k⋅a,其中 k k k 的值取决于输入 a a a。因此,我们可以用 d / a d / a d/a 验证梯度是否正确。

python 复制代码
a.grad == d / a

输出:

复制代码
tensor(True)

• 深度学习框架可以⾃动计算导数:我们⾸先将梯度附加到想要对其计算偏导数的变量上。然后我们记录⽬标值的计算,执⾏它的反向传播函数,并访问得到的梯度。

相关推荐
风象南9 分钟前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
曲幽31 分钟前
FastAPI压力测试实战:Locust模拟真实用户并发及优化建议
python·fastapi·web·locust·asyncio·test·uvicorn·workers
Mintopia1 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮1 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬2 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia2 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区2 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两5 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
敏编程5 小时前
一天一个Python库:jsonschema - JSON 数据验证利器
python