【pytorch】自动求导机制

基础概念汇总

Tensor是 torch.autograd中的数据类型,主要用于封装 Tensor,进行自动求导。

  • grad : data的梯度
  • grad_fn : 创建 Tensor的 Function,是自动求导的关键
  • requires_grad:指示是否需要梯度
  • is_leaf : 指示是否是叶子结点

PyTorch张量可以记住它们来自什么运算以及其起源的父张量,并且提供相对于输入的导数链。你无需手动对模型求导:不管如何嵌套,只要你给出前向传播表达式,PyTorch都会自动提供该表达式相对于其输入参数的梯度。

当设置.requires_grad = True之后,在其上进行的各种操作就会被记录下来,它将开始追踪在其上的所有操作,从而利用链式法则进行梯度传播。任何以tensor为祖先的张量都可以访问从tensor到该张量所调用的函数链。如果这些函数是可微的(大多数PyTorch张量运算都是可微的),则导数的值将自动存储在参数张量的grad属性中。

完成计算后,可以调用.backward()来完成所有梯度计算。沿着整个函数链(即计算图)计算损失的导数。此Tensor的梯度将累积到.grad属性中。调用backward会导致导数值在叶节点处累积。所以将其用于参数更新后,需要将梯度显式清零

c 复制代码
if params.grad is not None:
    params.grad.zero_()

但是,如果中间加载了不支持梯度的操作,就会发生梯度断流 。这在自己写模型时候时常发生,会导致模型无法求导。例如,在求loss时使用pil、cv2的库,导致无法反向传播。后面即使手动打开也没有用,梯度流不能被中断。或者自己写了transform函数,调用官方不支持grad_fn的函数,也会导致这样的问题。

如果不想要被继续追踪,可以调用.detach()将其从追踪记录中分离出来,可以防止将来的计算被追踪,这样梯度就传不过去了。此外,还可以用with torch.no_grad()将不想被追踪的操作代码块包裹起来,这种方法在评估模型的时候很常用,因为在评估模型时,我们并不需要计算可训练参数(requires_grad=True)的梯度。

深入

autograd 机制

Autograd 是一种反向自动微分系统。从概念上讲, autograd 记录了一个图表,记录了创建的所有操作 执行操作时的数据。提供有向无环图 其叶子是输入张量,根是输出张量。

在内部,autograd 将该图表示为 Function 对象(真正的表达式),可以是 apply() 编辑计算结果 评估图表。计算前向传播时,autograd 同时执行请求的计算并构建图表 表示计算梯度的函数(.grad_fn 每个 torch.Tensor 的属性都是该图的入口点)。 当前向传递完成后,我们在 向后传递以计算梯度。

需要注意的重要一点是,该图每次都会从头开始重新创建 迭代,这正是允许使用任意 Python 控件的原因 流语句,可以改变图形的整体形状和大小 每次迭代。您不必先对所有可能的路径进行编码 启动培训------你跑什么,你就与众不同。

拓展torch

https://pytorch.org/docs/stable/notes/extending.html

想在模型中执行计算,请实现自定义函数 不可微分或依赖于非 PyTorch 库(例如 NumPy)。如果想让操作能够与其他操作链接并使用 autograd 引擎,就得使用自定义函数。

自定义函数也可用于提高性能和 内存使用情况:如果您使用 C++ 扩展, 您可以将它们包装在 Function 中以与 autograd 交互 引擎。如果您想减少为向后传递保存的缓冲区数量, 自定义函数可用于将操作组合在一起。

第 1 步:子类化Function后,您需要定义 3 个方法

forward() 是执行该操作的代码。它可以需要 你想要多少个参数,其中一些是可选的,如果你 指定默认值。这里接受所有类型的 Python 对象。 Tensor 跟踪历史记录的参数(即, requires_grad=True)将被转换为不跟踪历史记录的内容 在调用之前,它们的使用将被记录在图表中。请注意,这 逻辑不会遍历列表/字典/任何其他数据结构,只会 考虑作为调用的直接参数的张量。你可以 返回单个 Tensor 输出,或 tuple 张量(如果有多个输出)。另外,请参阅 Function 的文档来查找有用方法的描述,这些方法可以 仅从 forward() 调用。

setup_context()(可选)。人们可以写一个"组合"forward() 接受一个 ctx 对象或(从 PyTorch 2.0 开始)一个单独的 forward() 不接受 ctx 和发生 修改的 setup_context() 方法。 应该具有计算能力, 应该具有 只负责修改(并且不进行任何计算)。 一般来说,单独的 和 更接近于如何 PyTorch 本机操作可以工作,因此更适合与各种 PyTorch 子系统组合。 请参阅组合或单独的forward() 和setup_context()了解更多详情。ctxforward()setup_context()ctxforward()setup_context()

backward()(或vjp())定义渐变公式。 它将给出与输出一样多的 Tensor 参数,每个参数 其中代表梯度 w.r.t.那个输出。重要的是永远不要修改 这些就地。它应该返回尽可能多的张量 是输入,每个输入都包含梯度 w.r.t.它是 相应的输入。如果您的输入不需要梯度 (needs_input_grad 是一个布尔值元组,表示 每个输入是否需要梯度计算),或者是非Tensor 对象,您可以返回python:None。另外,如果您有可选的 forward() 的参数你可以返回比那里更多的梯度 都是输入,只要它们都是 None。

第 2 步:使用 ctx 中的功能 正确地确保新的 Function 能够正常工作 autograd 引擎。

save_for_backward() 必须是 用于保存向后传递中使用的任何张量。非张量应该 直接存储在ctx上。如果张量既不是输入也不是输出 保存为向后,您的 Function 可能不支持双向后 (参见步骤 3)。

mark_dirty()必须习惯于 标记由转发函数就地修改的任何输入。

mark_non_differentiable()必须 用于告诉引擎输出是否不可微。经过 默认所有可微分类型的输出张量都会被设置 要求梯度。不可微类型的张量(即整数类型) 从未被标记为需要渐变。

set_materialize_grads()可 用于告诉 autograd 引擎在以下情况下优化梯度计算 通过不具体化给予向后的梯度张量,输出不依赖于输入 功能。也就是说,如果设置为 False,则 Python 中的 None 对象或"未定义张量"(张量 x 为 C++ 中的 x.define() 为 False) 不会转换为先用零填充的张量 向后调用,因此您的代码将需要处理此类对象,就好像它们是 张量用零填充。此设置的默认值为 True。

Step 3:

If your Function does not support double backward you should explicitly declare this by decorating backward with the once_differentiable(). With this decorator, attempts to perform double backward through your function will produce an error. See our double backward tutorial for more information on double backward.

验证

使用torch.autograd.gradcheck() 检查你的后向函数是否正确计算了 通过使用后向函数计算雅可比矩阵来向前推进 将值按元素与使用数值计算的雅可比行列式进行比较 有限差分。

reference

xml 复制代码
https://github.com/ShusenTang/Deep-Learning-with-PyTorch-Chinese/blob/master/docs/chapter4/4.2.md

https://tianchi.aliyun.com/forum/post/336073
https://pytorch.org/docs/stable/notes/extending.html
https://pytorch.org/tutorials/advanced/cpp_extension.html
https://pytorch.org/docs/stable/notes/autograd.html
https://pytorch.org/docs/stable/generated/torch.autograd.Function.backward.html#torch.autograd.Function.backward
相关推荐
小馒头学python10 分钟前
深度学习中的卷积神经网络:原理、结构与应用
人工智能·深度学习·cnn
2zcode10 分钟前
基于YOLOv8深度学习的脑肿瘤智能检测系统设计与实现(PyQt5界面+数据集+训练代码)
人工智能·深度学习·yolo
龙虎榜小红牛系统19 分钟前
WordCloud参数的用法:
python·wordcloud
fhf23 分钟前
感觉根本等不到35岁AI就把我裁了
前端·人工智能·程序员
hummhumm23 分钟前
第 36 章 - Go语言 服务网格
java·运维·前端·后端·python·golang·java-ee
m0_7428488827 分钟前
PyTorch3
人工智能·深度学习
achaoyang36 分钟前
【Python中while循环】
开发语言·python
lindsayshuo38 分钟前
香橙派--安装RKMPP、x264、libdrm、FFmpeg(支持rkmpp)以及opencv(支持带rkmpp的ffmpeg)(适用于RK3588平台)
人工智能·opencv·ffmpeg
soso196842 分钟前
构建与优化数据仓库-实践指南
大数据·数据仓库·人工智能
linmoo19861 小时前
java脚手架系列16-AI大模型集成
java·人工智能·ai·大模型·通义千问·qwen·脚手架