PyTorch API 2

文章目录


自动微分包 - torch.autograd

torch.autograd 提供了实现任意标量值函数自动微分的类和函数。

只需对现有代码进行最小改动------您只需要通过requires_grad=True关键字声明需要计算梯度的Tensor即可。目前,我们仅支持浮点型Tensor(包括half、float、double和bfloat16)和复数型Tensor(cfloat、cdouble)的自动微分功能。

backward 计算给定张量相对于计算图叶节点的梯度之和
grad 计算并返回输出相对于输入的梯度之和

前向模式自动微分


警告:此API目前处于测试阶段。尽管函数签名不太可能更改,但在我们将其视为稳定版本之前,计划增加更多的运算符支持。

详细使用步骤请参阅前向模式AD教程

forward_ad.dual_level 前向AD的上下文管理器,所有前向AD计算都必须在dual_level上下文中进行。
forward_ad.make_dual 将张量值与其切线关联,创建用于前向AD梯度计算的"对偶张量"。
forward_ad.unpack_dual 解包"对偶张量",获取其张量值和前向AD梯度。
forward_ad.enter_dual_level 进入一个新的前向梯度级别。
forward_ad.exit_dual_level 退出前向梯度级别。
forward_ad.UnpackedDualTensor unpack_dual()返回的命名元组,包含对偶张量的原始分量和切线分量。

高阶函数式API


警告:此API目前处于测试阶段。尽管函数签名不太可能变更,但在我们将其视为稳定版本前,计划进行重大的性能改进。

本节包含基于上述基础API构建的高阶自动微分API,可用于计算雅可比矩阵、海森矩阵等。

该API仅适用于用户提供的函数,这些函数仅接受张量作为输入并仅返回张量。

若函数包含非张量参数或未设置requires_grad的张量参数,可通过lambda表达式捕获这些参数。

例如,对于接受三个输入的函数f:一个需要计算雅可比矩阵的张量、一个应视为常量的张量、以及布尔标志f(input, constant, flag=flag),可通过functional.jacobian(lambda x: f(x, constant, flag=flag), input)方式调用。

functional.jacobian 计算给定函数的雅可比矩阵
functional.hessian 计算标量函数的黑塞矩阵
functional.vjp 计算向量v与给定函数在输入点处雅可比矩阵的点积
functional.jvp 计算给定函数在输入点处雅可比矩阵与向量v的点积
functional.vhp 计算向量v与标量函数在指定点处黑塞矩阵的点积
functional.hvp 计算标量函数在指定点处黑塞矩阵与向量v的点积

局部禁用梯度计算

有关无梯度模式与推理模式之间的区别,以及其他可能与之混淆的相关机制,请参阅局部禁用梯度计算获取更多信息。另请参考局部禁用梯度计算查看可用于局部禁用梯度的函数列表。


默认梯度布局

当非稀疏参数 paramtorch.autograd.backward()torch.Tensor.backward() 过程中接收到非稀疏梯度时,param.grad 会按以下方式累积:

param.grad 初始为 None

1、如果 param 的内存是非重叠且密集的,.grad 会创建与 param 步幅相匹配的布局(即与 param 的布局一致)。

2、否则,.grad 会创建行主序连续(rowmajor-contiguous)的步幅。


param 已存在非稀疏的 .grad 属性:

3、当 create_graph=False 时,backward() 会就地累加到 .grad 中,保持其原有步幅不变。

4、当 create_graph=True 时,backward() 会将 .grad 替换为新张量 .grad + new grad,该操作会尝试(但不保证)匹配原有 .grad 的步幅。

推荐采用默认行为(在首次 backward() 前让 .grad 保持为 None,从而根据情况1或2创建布局,并通过情况3或4保持布局)以获得最佳性能。调用 model.zero_grad()optimizer.zero_grad() 不会影响 .grad 的布局。

实际上,在每个累积阶段前将所有 .grad 重置为 None,例如:

python 复制代码
for iterations...
    ...
    for param in model.parameters():
        param.grad = None
    loss.backward()

这样每次根据1或2重新创建它们,是替代model.zero_grad()optimizer.zero_grad()的有效方法,可能提升某些网络的性能。


手动梯度布局

如果需要手动控制 .grad 的步长(strides),在首次调用 backward() 前,将一个步长符合预期的归零张量赋值给 param.grad =,并且不要将其重置为 None

3 保证只要 create_graph=False,你的布局就会被保留。

4 表示即使 create_graph=True,你的布局也可能被保留。


张量的原地操作

在自动微分系统中支持原地操作是一个复杂的问题,我们建议在大多数情况下避免使用。自动微分系统通过积极的缓冲区释放和重用机制实现了高效运行,实际上只有极少数情况下原地操作能显著降低内存使用量。除非面临严重的内存压力,否则您可能永远不需要使用它们。


原地操作的正确性检查

所有Tensor都会跟踪应用于它们的原地操作。如果实现检测到某个张量在某个函数中被保存用于反向传播,但之后被原地修改,一旦开始反向传播就会引发错误。这确保当你使用原地操作函数且没有看到任何错误时,可以确信计算出的梯度是正确的。


变量(已弃用)


警告:Variable API 已被弃用:现在使用张量进行自动求导不再需要 Variables。Autograd 已自动支持将 requires_grad 设为 True 的张量。以下是主要变更的快速指南:

  • Variable(tensor)Variable(tensor, requires_grad) 仍可正常使用,但会返回张量而非 Variables。
  • var.data 现在等同于 tensor.data
  • 诸如 var.backward()var.detach()var.register_hook() 等方法现在可直接在张量上以同名方法调用。

此外,现在可以通过工厂方法如 torch.randn()torch.zeros()torch.ones() 等直接创建带 requires_grad=True 的张量,例如:

autograd_tensor = torch.randn((2, 3, 4), requires_grad=True)


张量自动求导函数

torch.Tensor.grad 该属性默认为None,当首次调用backward()self计算梯度时,会变成一个张量。
torch.Tensor.requires_grad 如果需要为该张量计算梯度则为True,否则为False
torch.Tensor.is_leaf 按照惯例,所有requires_gradFalse的张量都是叶子张量。
torch.Tensor.backward([gradient, ...]) 计算当前张量相对于图中叶子节点的梯度。
torch.Tensor.detach 返回一个与当前计算图分离的新张量。
torch.Tensor.detach_ 将张量从创建它的计算图中分离,使其成为叶子节点。
torch.Tensor.register_hook(hook) 注册一个反向传播钩子。
torch.Tensor.register_post_accumulate_grad_hook(hook) 注册一个在梯度累积后运行的反向传播钩子。
torch.Tensor.retain_grad() 使该张量在backward()过程中能够保留其grad值。

函数


python 复制代码
class torch.autograd.Function(*args, **kwargs)

用于创建自定义 autograd.Function 的基类。

要创建自定义 autograd.Function,请继承该类并实现 forward()backward() 静态方法。在前向传播中使用自定义操作时,需调用类方法 apply,切勿直接调用 forward()

为确保正确性和最佳性能,请确保正确调用 ctx 上的方法,并使用 torch.autograd.gradcheck() 验证反向传播函数。

关于如何使用该类的更多细节,请参阅 扩展 torch.autograd


示例:

python 复制代码
>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>> >
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>> >
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

Function.forward 定义自定义autograd Function的前向传播逻辑。
Function.backward 定义反向模式自动微分操作的求导公式。
Function.jvp 定义前向模式自动微分操作的求导公式。
Function.vmap 定义该autograd.Function在torch.vmap()下的行为。

上下文方法混入

在创建新的 Function 时,以下方法可通过 ctx 使用。

function.FunctionCtx.mark_dirty 将给定张量标记为原地操作中被修改的
function.FunctionCtx.mark_non_differentiable 将输出标记为不可微分的
function.FunctionCtx.save_for_backward 保存给定张量以供后续调用 backward() 时使用
function.FunctionCtx.set_materialize_grads 设置是否物化梯度张量

自定义函数工具集

用于反向方法的装饰器。

function.once_differentiable

基础自定义Function类,用于构建PyTorch工具

function.BackwardCFunction 该类用于autograd内部工作。
function.InplaceFunction 此类的存在仅出于向后兼容性考虑。
function.NestedIOFunction 此类的存在仅出于向后兼容性考虑。

数值梯度检验

gradcheck 通过微小有限差分计算的梯度与解析梯度进行对比验证,适用于inputs中浮点或复数类型且requires_grad=True的张量。
gradgradcheck 通过微小有限差分计算的二阶梯度与解析梯度进行对比验证,适用于inputsgrad_outputs中浮点或复数类型且requires_grad=True的张量。
GradcheckError gradcheck()gradgradcheck()抛出的错误类型。

性能分析器

Autograd 内置了一个性能分析器,可帮助开发者检查模型中不同运算符在 CPU 和 GPU 上的计算开销。目前实现了三种分析模式:

1、纯 CPU 模式 :使用 profile 进行分析

2、基于 nvprof :通过 emit_nvtx 同时记录 CPU 和 GPU 活动

3、基于 VTune 分析器 :使用 emit_itt 实现


python 复制代码
torch.autograd.profiler.profile(enabled=True, *, use_cuda=False, use_device=None, record_shapes=False, with_flops=False, profile_memory=False, with_stack=False, with_modules=False, use_kineto=False, use_cpu=True, experimental_config=None, acc_events=False, custom_trace_id_callback=None)[source]

管理自动梯度分析器状态并保存结果摘要的上下文管理器。

底层实现是通过记录C++中执行的函数事件,并将这些事件暴露给Python。你可以将任何代码包裹其中,它只会报告PyTorch函数的运行时间。

注意:分析器是线程局部的,会自动传播到异步任务中。

参数

  • enabled ([bool], 可选) -- 设为False时,该上下文管理器将不执行任何操作。
  • use_cuda ([bool], 可选) -- 启用CUDA事件计时功能,使用cudaEvent API(即将弃用)。
  • use_device (str, 可选) -- 启用设备事件计时功能。当使用CUDA时,每个张量操作会增加约4微秒的开销。有效设备选项包括'cuda'、'xpu'、'mtia'和'privateuseone'。
  • record_shapes ([bool], 可选) -- 如果启用形状记录,将收集输入维度信息。这允许查看底层使用的维度,并通过prof.key_averages(group_by_input_shape=True)进行分组。请注意,形状记录可能会影响分析数据准确性,建议分别运行带和不带形状记录的测试来验证时间。最底层事件的偏差可能可以忽略(在嵌套函数调用情况下),但对高层函数而言,由于形状收集,总的自CPU时间可能会人为增加。
  • with_flops ([bool], 可选) -- 启用时,分析器会根据算子输入形状估算FLOPs(浮点运算)值,用于评估硬件性能。目前仅支持矩阵乘法和2D卷积算子。
  • profile_memory ([bool], 可选) -- 跟踪张量内存分配/释放。
  • with_stack ([bool], 可选) -- 记录操作对应的源代码信息(文件和行号)。
  • with_modules ([bool]) -- 记录与操作调用栈对应的模块层级(含函数名)。例如:若模块A的forward调用模块B的forward(其中包含aten::add操作),则aten::add的模块层级为A.B。注意:当前仅支持TorchScript模型,不支持eager模式模型。
  • use_kineto ([bool], 可选) -- 实验性功能,启用Kineto分析器。
  • use_cpu ([bool], 可选) -- 分析CPU事件;设为False时需同时设置use_kineto=True,可降低纯GPU分析的开销。
  • experimental_config (_ExperimentalConfig) -- 供Kineto等分析器库使用的实验性选项集,不保证向后兼容性。
  • acc_events ([bool]) -- 启用跨多个分析周期的FunctionEvents累积功能。

示例


python 复制代码
x = torch.randn((1, 1), requires_grad=True)
with torch.autograd.profiler.profile() as prof:
    for _ in range(100):  # any normal python code, really!
        y = x ** 2
        y.backward()
# NOTE: some columns were removed for brevity
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

shell 复制代码
-----------------------------------  ---------------  ---------------  ---------------
Name                                 Self CPU total   CPU time avg     Number of Calls
-----------------------------------  ---------------  ---------------  ---------------
mul                                  32.048ms         32.048ms         200
pow                                  27.041ms         27.041ms         200
PowBackward0                         9.727ms          55.483ms         100
torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
torch::autograd::GraphRoot           691.816us        691.816us        100
-----------------------------------  ---------------  ---------------  ---------------

profiler.profile.export_chrome_trace 将事件列表导出为Chrome追踪工具文件格式
profiler.profile.key_averages 对所有函数事件按关键字段进行平均统计
profiler.profile.self_cpu_time_total 返回CPU总耗时
profiler.profile.total_average 对所有事件进行平均统计
profiler.parse_nvprof_trace
profiler.EnforceUnique 当检测到重复键时抛出错误
profiler.KinetoStepTracker 提供全局步进计数的抽象接口
profiler.record_function 上下文管理器/函数装饰器,在运行autograd分析器时为代码块/函数添加标签
profiler_util.Interval
profiler_util.Kernel
profiler_util.MemRecordsAcc 用于快速访问区间内内存记录的加速结构
profiler_util.StringTable

python 复制代码
class torch.autograd.profiler.emit_nvtx(enabled=True, record_shapes=False)

上下文管理器,使每个自动梯度操作都生成一个NVTX范围。

在程序运行于nvprof下时非常有用。


python 复制代码
nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

遗憾的是,我们无法强制 nvprof 将收集的数据刷新到磁盘,因此在 CUDA 性能分析中,必须使用此上下文管理器来标注 nvprof 跟踪信息,并等待进程退出后才能检查这些数据。

随后,可以使用 NVIDIA Visual Profiler (nvvp) 来可视化时间线,或者通过 torch.autograd.profiler.load_nvprof() 加载结果进行检查,例如在 Python REPL 环境中。

参数

  • enabled ([bool], 可选) -- 设置为 enabled=False 时,此上下文管理器将不执行任何操作。默认值:True
  • record_shapes ([bool], 可选) -- 如果 record_shapes=True,包装每个自动求导操作的 nvtx 范围将追加该操作接收的张量参数的大小信息,格式如下:[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]。非张量参数将表示为 []。参数将按照后端操作接收的顺序列出。请注意,此顺序可能与 Python 端传递参数的顺序不一致。另外需注意,记录形状可能会增加 nvtx 范围创建的开销。默认值:False

示例

python 复制代码
with torch.cuda.profiler.profile():
    model(x)  # Warmup CUDA memory allocator and profiler
    with torch.autograd.profiler.emit_nvtx():
        model(x)

前向-反向关联性

当在Nvidia Visual Profiler中查看使用emit_nvtx生成的性能分析文件时,将每个反向传播操作与对应的前向传播操作关联起来可能较为困难。

为简化这一过程,emit_nvtx会为其生成的区间附加序列号信息。

在前向传播过程中,每个函数区间会被标记seq=<N>。其中seq是一个运行计数器,每当创建新的反向Function对象并暂存用于反向传播时,该计数器就会递增。

因此,与前向函数区间关联的seq=<N>标注表明:如果该前向函数创建了反向Function对象,那么这个反向对象将获得序列号N。

在反向传播过程中,封装每个C++反向Function的apply()调用的顶层区间会被标记stashed seq=<M>。这里的M是该反向对象创建时的序列号。通过比较反向传播中的stashed seq与前向传播中的seq,可以追踪出是哪个前向操作创建了特定的反向Function。

反向传播期间执行的函数也会被标记seq=<N>。在默认反向传播(create_graph=False)时,这些信息无关紧要,实际上此类函数的N值可能均为0。只有与反向Function对象的apply()方法关联的顶层区间才有实际意义,它们可用于将这些Function对象与之前的前向传播建立关联。

双重反向传播

另一方面,如果正在进行create_graph=True的反向传播(即准备进行双重反向传播),那么反向传播期间每个函数的执行都会被赋予非零且有效的seq=<N>。这些函数本身可能会像前向传播中的原始函数那样,创建将在后续双重反向传播中执行的Function对象。

反向传播与双重反向传播之间的关系,在概念上与前向-反向关系相同:函数仍会发出带有当前序列号标记的区间,它们创建的Function对象仍会暂存这些序列号。在最终的双重反向传播期间,Function对象的apply()区间仍会带有stashed seq标记,这些标记可与反向传播阶段的序列号进行对比。


python 复制代码
class torch.autograd.profiler.emit_itt(enabled=True, record_shapes=False)

上下文管理器,使每个自动梯度(autograd)操作都生成一个ITT范围标记。

在通过Intel® VTune Profiler运行程序时,该功能非常实用。


python 复制代码
vtune <--vtune-flags<regular command here>

仪器化和追踪技术(ITT)API 能让你的应用程序在执行过程中生成并控制跨不同英特尔工具的追踪数据收集。

这个上下文管理器用于标注 Intel® VTune Profiling 追踪。借助该上下文管理器,你将能在 Intel® VTune Profiler 图形界面中看到标记的范围。

参数

  • enabled ([bool], 可选) -- 设置为 enabled=False 会使该上下文管理器不执行任何操作。默认值:True
  • record_shapes ([bool], 可选) -- 如果 record_shapes=True,包裹每个自动梯度操作的 itt 范围会附加该操作接收的张量参数的大小信息,格式如下:[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]。非张量参数将表示为 []。参数将按照后端操作接收的顺序列出。请注意,此顺序可能与 Python 端传递这些参数的顺序不一致。另外请注意,记录形状可能会增加 itt 范围创建的开销。默认值:False

示例

python 复制代码
>>> with torch.autograd.profiler.emit_itt():
...     model(x)

profiler.load_nvprof 打开一个nvprof跟踪文件并解析自动梯度注释。

调试与异常检测


python 复制代码
class torch.autograd.detect_anomaly(check_nan=True)

上下文管理器,用于为自动求导引擎启用异常检测功能。

该功能主要实现以下两个作用:

  • 在启用检测的情况下运行前向传播时,反向传播过程会打印出导致失败的反向函数所对应的前向操作调用栈。
  • check_nan设为True,任何产生"nan"值的反向计算都会触发报错(默认为True)。

警告:该模式仅应用于调试场景,因其各项检测会显著降低程序执行速度。

使用示例


python 复制代码
import torch
from torch import autograd
class MyFunc(autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()
    @staticmethod
    def backward(ctx, gO):
        # Error during the backward pass
        raise RuntimeError("Some error in backward")
        return gO.clone()
def run_fn(a):
    out = MyFunc.apply(a)
    return out.sum()
inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()
'''

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
        allow_unreachable=True)  # allow_unreachable flag
      File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
        return self._forward_cls.backward(self, *args)
      File "<stdin>", line 8, in backward
    RuntimeError: Some error in backward
'''

with autograd.detect_anomaly():
    inp = torch.rand(10, 10, requires_grad=True)
    out = run_fn(inp)
    out.backward()

  '''
  Traceback of forward call that caused the error:
    File "tmp.py", line 53, in <module>
      out = run_fn(inp)
    File "tmp.py", line 44, in run_fn
      out = MyFunc.apply(a)
  Traceback (most recent call last):
    File "<stdin>", line 4, in <module>
    File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
      torch.autograd.backward(self, gradient, retain_graph, create_graph)
    File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
      allow_unreachable=True)  # allow_unreachable flag
    File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
      return self._forward_cls.backward(self, *args)
    File "<stdin>", line 8, in backward
  RuntimeError: Some error in backward
  '''

python 复制代码
class torch.autograd.set_detect_anomaly(mode, check_nan=True)

用于开启或关闭自动梯度引擎异常检测的上下文管理器。

set_detect_anomaly 将根据参数 mode 启用或禁用自动梯度异常检测功能。

该功能既可作为上下文管理器使用,也可作为普通函数调用。关于异常检测行为的具体说明,请参阅上文 detect_anomaly 的文档。

参数说明

  • mode ([bool]) - 控制是否启用异常检测的标志位:True 表示启用,False 表示禁用
  • check_nan ([bool]) - 控制当反向传播产生"nan"值时是否触发错误的标志位
grad_mode.set_multithreading_enabled 用于开启或关闭多线程反向传播的上下文管理器

自动求导图

Autograd 提供了一系列方法,允许开发者检查计算图并在反向传播过程中插入自定义行为。

torch.Tensorgrad_fn 属性会保存一个 torch.autograd.graph.Node 对象(当该张量是由被 autograd 记录的操作产生时,即 grad_mode 已启用且至少有一个输入需要梯度),否则该属性值为 None

graph.Node.name 返回节点名称
graph.Node.metadata 返回元数据
graph.Node.next_functions
graph.Node.register_hook 注册反向传播钩子
graph.Node.register_prehook 注册反向传播前置钩子
graph.increment_version 更新 autograd 元数据以跟踪指定张量是否被原地修改

某些操作需要在正向传播过程中保存中间结果,以便执行反向传播。这些中间结果会被保存在 grad_fn 的属性中,并可供访问。


例如:

python 复制代码
>>> a = torch.tensor([0., 0., 0.], requires_grad=True)
>>> b = a.exp()
>>> print(isinstance(b.grad_fn, torch.autograd.graph.Node))
True
>>> print(dir(b.grad_fn))
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_raw_saved_result', '_register_hook_dict', '_saved_result', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad']
>>> print(torch.allclose(b.grad_fn._saved_result, b))
True

您还可以通过钩子(hooks)定义这些保存的张量应如何打包/解包。

一个典型应用是通过将中间结果保存到磁盘或CPU(而非保留在GPU上),以计算资源换取内存空间。如果您发现模型在评估阶段能放入GPU内存但训练时不行,这一方法尤其有用。

另请参阅保存张量的钩子


python 复制代码
class torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)

用于设置张量保存/解包钩子对的上下文管理器。

该上下文管理器用于定义操作产生的中间结果在保存前应如何打包,以及在检索时如何解包。

在此上下文中,每当操作保存一个张量用于反向传播时(包括通过save_for_backward()保存的中间结果,以及PyTorch内置操作记录的张量),都会调用pack_hook函数。随后,计算图中存储的是pack_hook的输出结果而非原始张量。

当需要访问已保存的张量时(即执行torch.Tensor.backward()torch.autograd.grad()时),会调用unpack_hook函数。该函数以pack_hook返回的打包 对象作为参数,并应返回与原始张量(即对应pack_hook的输入张量)内容完全一致的张量。

钩子函数应遵循以下签名格式:

pack_hook(tensor: Tensor) -> Any

unpack_hook(Any) -> Tensor

其中pack_hook的返回值必须能作为unpack_hook的有效输入。

通常要求unpack_hook(pack_hook(t))在数值、大小、数据类型和设备类型方面与原始张量t完全一致。


示例:

python 复制代码
>>> def pack_hook(x):
...     print("Packing", x)
...     return x
>>> >
>>> def unpack_hook(x):
...     print("Unpacking", x)
...     return x
>>> >
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
...     y = a * b
Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

警告:对钩子函数的输入执行原地操作可能导致未定义行为。

警告:同一时间只允许存在一对钩子函数。当递归嵌套此上下文管理器时,仅最内层的一对钩子函数会生效。


python 复制代码
class torch.autograd.graph.save_on_cpu(pin_memory=False, device_type='cuda')

在该上下文管理器下,前向传播保存的张量将存储在CPU上,然后在反向传播时取回。

在此上下文管理器内执行操作时,前向传播期间保存在计算图中的中间结果将被移至CPU,当反向传播需要时再复制回原始设备。如果计算图已在CPU上,则不会执行张量复制。

使用此上下文管理器可以在计算性能和GPU内存使用之间进行权衡(例如当训练时模型无法完全放入GPU内存时)。

  • 参数

pin_memory (bool) - 如果设为True,张量在打包时会保存到CPU的固定内存中,并在解包时异步复制到GPU。默认为False。另请参阅使用固定内存缓冲区


示例:

python 复制代码
>>> a = torch.randn(5, requires_grad=True, device="cuda")
>>> b = torch.randn(5, requires_grad=True, device="cuda")
>>> c = torch.randn(5, requires_grad=True, device="cuda")
>>> >
>>> def f(a, b, c):
...     prod_1 = a * b           # a and b are saved on GPU
...     with torch.autograd.graph.save_on_cpu():
...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
...     y = prod_2 * a           # prod_2 and a are saved on GPU
...     return y
>>> >
>>> y = f(a, b, c)
>>> del a, b, c  # for illustration only
>>> # the content of a, b, and prod_2 are still alive on GPU
>>> # the content of prod_1 and c only live on CPU
>>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
>>> # all intermediary tensors are released (deleted) after the call to backward

python 复制代码
class torch.autograd.graph.disable_saved_tensors_hooks(error_message)

上下文管理器,用于禁用保存张量的默认钩子功能。

当您开发的功能与保存张量的默认钩子不兼容时,此功能非常有用。


参数

error_message (str) -- 当保存张量的默认钩子被禁用后仍被使用时,将抛出带有此错误信息的 RuntimeError。

返回类型 : 生成器


示例

python 复制代码
>>> message = "saved tensors default hooks are disabled"
>>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
...     # Raises RuntimeError: saved tensors default hooks are disabled
...     with torch.autograd.graph.save_on_cpu():
...         pass

python 复制代码
class torch.autograd.graph.register_multi_grad_hook(tensors, fn, *, mode='all')

注册一个多梯度反向传播钩子。

支持两种模式:"all""any"

"all" 模式下,该钩子将在计算完 tensors 中每个张量的梯度后被调用。如果一个张量在 tensors 中但不是计算图的一部分,或者如果当前 .backward().grad() 调用中指定的任何 inputs 不需要该张量来计算梯度,则该张量将被忽略,钩子不会等待其梯度被计算。

在所有未被忽略的张量的梯度计算完成后,fn 将被调用并传入这些梯度。对于未计算梯度的张量,将传入 None

"any" 模式下,该钩子将在计算完 tensors 中任意一个张量的第一个梯度后被调用。钩子将被调用,并将该梯度作为参数传入。

钩子不应修改其参数。

此函数返回一个带有 handle.remove() 方法的句柄,该方法用于移除钩子。

注意:有关此钩子的执行时机以及其执行顺序相对于其他钩子的更多信息,请参阅反向传播钩子执行


示例:

python 复制代码
>>> import torch
>>> >
>>> a = torch.rand(2, 3, requires_grad=True)
>>> b = torch.rand(2, 3, requires_grad=True)
>>> c = a * b
>>> d = a * b
>>> >
>>> def fn(grads):
...     print([g is not None for g in grads])
...
>>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
>>> >
>>> c.sum().backward(retain_graph=True)
[True, True, True, False]
>>> c.sum().backward(inputs=(a,), retain_graph=True)
[True, False, True, False]
>>> >

返回类型:RemovableHandle


python 复制代码
class torch.autograd.graph.allow_mutation_on_saved_tensors

允许在上下文管理器中对保存用于反向传播的张量进行修改。

在此上下文管理器下,保存用于反向传播的张量在修改时会被克隆,因此原始版本仍可在反向传播期间使用。通常情况下,修改保存用于反向传播的张量会导致在反向传播使用时引发错误。

为确保正确行为,前向传播和反向传播都应在同一上下文管理器下运行。

返回值:一个存储该上下文管理器所管理状态的 _AllowMutationOnSavedContext 对象。该对象可用于调试目的。上下文管理器管理的状态会在退出时自动清除。

返回类型:Generator[_AllowMutationOnSavedContext, None, None]


示例:

python 复制代码
>>> import torch
>>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
...     # forward
...     a = torch.ones(2, 3, requires_grad=True)
...     b = a.clone()
...     out = (b**2).sum()
...     b.sin_()
...     # backward
...     out.sum().backward()
...
tensor([[0.8415, 0.8415, 0.8415], [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)

python 复制代码
class torch.autograd.graph.GradientEdge(node, output_nr)

表示自动微分图中特定梯度边的对象。

要获取计算给定张量梯度所在的梯度边,可以执行 edge = autograd.graph.get_gradient_edge(tensor)


python 复制代码
torch.autograd.graph.get_gradient_edge(tensor)[source]

获取用于计算给定张量梯度的梯度边。

具体而言,其等效于调用:

g = autograd.grad(loss, input)g = autograd.grad(loss, get_gradient_edge(input))

返回类型:GradientEdge



torch.library

torch.library 是一个用于扩展 PyTorch 核心算子库的 API 集合。它包含以下功能工具:

  • 测试自定义算子
  • 创建新的自定义算子
  • 扩展通过 PyTorch C++ 算子注册 API(例如 aten 算子)定义的算子

如需详细了解如何高效使用这些 API,请参阅 PyTorch 自定义算子指南页获取更多信息。


(注:根据翻译原则,我进行了以下处理:

1、保留所有代码术语如"torch.library"、"API"、"PyTorch"、"aten"等

2、将长句拆分为更易读的列表形式

3、被动语态转为主动语态

4、完整保留原始链接格式

5、删除多余空行,保持文档紧凑性)


测试自定义算子

使用 torch.library.opcheck() 来检测自定义算子是否正确使用了 Python torch.library 和/或 C++ TORCH_LIBRARY API。此外,如果算子支持训练过程,应使用 torch.autograd.gradcheck() 来验证梯度计算的数学正确性。


python 复制代码
torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True, atol=None, rtol=None)

给定一个运算符和一些示例参数,测试该运算符是否正确注册。

也就是说,当你使用 torch.library/TORCH_LIBRARY API 创建自定义操作时,你指定了关于该自定义操作的元数据(例如可变性信息),这些 API 要求你传递的函数满足某些属性(例如在 fake/meta/abstract 内核中不访问数据指针)。

opcheck 测试这些元数据和属性。

具体来说,我们测试以下内容:

  • test_schema:检查模式是否与运算符的实现匹配。例如:如果模式指定了一个张量是可变的,那么我们检查实现是否修改了该张量。如果模式指定我们返回一个新的张量,那么我们检查实现是否返回了一个新的张量(而不是现有的张量或现有张量的视图)。
  • test_autograd_registration:如果运算符支持训练(自动梯度):我们检查其自动梯度公式是否通过 torch.library.register_autograd 或手动注册到一个或多个 DispatchKey::Autograd 键。任何其他基于 DispatchKey 的注册可能导致未定义行为。
  • test_faketensor:检查运算符是否有 FakeTensor 内核(以及是否正确)。FakeTensor 内核是运算符与 PyTorch 编译 API(torch.compile/export/FX)协同工作的必要条件(但不是充分条件)。我们检查是否为运算符注册了 FakeTensor 内核(有时也称为 meta 内核)以及它是否正确。该测试比较在真实张量上运行运算符的结果和在 FakeTensors 上运行运算符的结果,检查它们是否具有相同的张量元数据(大小/步长/数据类型/设备等)。
  • test_aot_dispatch_dynamic:检查运算符在 PyTorch 编译 API(torch.compile/export/FX)中是否有正确的行为。

这检查了在 eager-mode PyTorch 和 torch.compile 下输出(以及梯度,如果适用)是否相同。

该测试是 test_faketensor 的超集,是一个端到端测试;

它还测试运算符是否支持功能化,以及反向传播(如果存在)是否也支持 FakeTensor 和功能化。

为了获得最佳结果,请使用一组具有代表性的输入多次调用 opcheck。如果你的运算符支持自动梯度,请使用 requires_grad = True 的输入调用 opcheck;如果你的运算符支持多个设备(例如 CPU 和 CUDA),请在所有支持的设备上使用 opcheck 进行测试。

参数

  • op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) -- 运算符。必须是用 torch.library.custom_op() 装饰的函数或在 torch.ops.* 中找到的 OpOverload/OpOverloadPacket(例如 torch.ops.aten.sin, torch.ops.mylib.foo)。
  • args (tuple[Any, ...]) -- 运算符的参数。
  • kwargs (Optional[dict[str, Any]]) -- 运算符的关键字参数。
  • test_utils (Union[str, Sequence[str]]) -- 应运行的测试。默认:全部。例如:("test_schema", "test_faketensor")。
  • raise_exception (bool) -- 是否在第一个错误时引发异常。如果为 False,将返回一个字典,其中包含每个测试是否通过的信息。
  • rtol (Optional[float]) -- 浮点比较的相对容差。如果指定,还必须指定 atol。如果省略,则根据数据类型选择默认值(参见 torch.testing.assert_close() 中的表格)。
  • atol (Optional[float]) -- 浮点比较的绝对容差。如果指定,还必须指定 rtol。如果省略,则根据数据类型选择默认值(参见 torch.testing.assert_close() 中的表格)。

返回类型:dict[str, str]

警告:opcheck 和 torch.autograd.gradcheck() 测试不同的内容;

opcheck 测试你对 torch.library API 的使用是否正确,而 torch.autograd.gradcheck() 测试你的自动梯度公式在数学上是否正确。对于支持梯度计算的自定义操作,请同时使用两者进行测试。


示例

python 复制代码
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: float) -Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np * y
>>>     return torch.from_numpy(z_np).to(x.device)
>>> >
>>> @numpy_mul.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>> >
>>> def setup_context(ctx, inputs, output):
>>>     y, = inputs
>>>     ctx.y = y
>>> >
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>> >
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>> >
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14), >>    (torch.randn(2, 3, device='cuda'), 2.718), >>    (torch.randn(1, 10, requires_grad=True), 1.234), >>    (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), >>]
>>> >
>>> for args in sample_inputs:
>>>     torch.library.opcheck(numpy_mul, args)

在 Python 中创建新的自定义算子

使用 torch.library.custom_op() 来创建新的自定义算子。


python 复制代码
torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None)

将函数封装为自定义运算符。

创建自定义运算符的常见原因包括:

  • 封装第三方库或自定义内核以兼容 PyTorch 子系统(如 Autograd)
  • 防止 torch.compile/export/FX 追踪机制探查函数内部实现

本 API 通过装饰器方式使用(参见示例)。被封装函数必须包含类型提示,这些类型提示用于与 PyTorch 各子系统交互。

参数说明

  • name (str) - 运算符命名格式为 "{命名空间}::{名称}"(例如 "mylib::my_linear")。该名称将作为 PyTorch 子系统(如 torch.export、FX 图)中的稳定标识符。为避免命名冲突,建议使用项目名称作为命名空间(例如 pytorch/fbgemm 中的所有自定义运算符均采用 "fbgemm" 作为命名空间)。
  • mutates_args (Iterable[str] 或 "unknown") - 函数会修改的参数名称列表。此信息必须准确,否则将导致未定义行为。若设为 "unknown",系统会保守假设运算符的所有输入参数都可能被修改。
  • device_types (None | str | Sequence[str]) - 函数适用的设备类型。未指定时,该函数将作为所有设备类型的默认实现。示例值:"cpu", "cuda"。注意:当为不接受张量的运算符注册设备特定实现时,要求该运算符必须包含 "device: torch.device" 参数。
  • schema (None | str) - 运算符模式字符串。推荐设为 None(默认值),系统会根据类型注解自动推导模式。除非有特殊需求,否则建议使用自动推导模式。示例格式:"(Tensor x, int y) -> (Tensor, Tensor)"。

返回类型 : Union[Callable[[Callable[[...], object]], CustomOpDef], CustomOpDef]

重要提示:建议不要手动传入 schema 参数,而是通过类型注解自动推导。手动编写模式字符串容易出错,仅当系统对类型注解的解释不符合预期时才需要自定义。关于模式字符串的编写规范,详见:
官方文档

使用示例::


python 复制代码
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>> >
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>> >
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>> >
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>> >
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>> >
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>> >
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>> >
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -Tensor:
>>>     return torch.ones(3)
>>> >
>>> bar("cpu")

python 复制代码
torch.library.triton_op(name, fn=None, /, *, mutates_args, schema=None)

创建一个由1个或多个triton内核支持的定制算子实现。

这是将triton内核与PyTorch结合使用的更结构化方式。

优先使用不带torch.library自定义算子包装器(如torch.library.custom_op()torch.library.triton_op())的triton内核,因为这样更简单;只有在需要创建行为类似PyTorch内置算子的情况下,才使用torch.library.custom_op()/torch.library.triton_op()

例如,当传递张量子类或在TorchDispatchMode下时,可以使用torch.library包装API来定义triton内核的行为。

当实现包含1个或多个triton内核时,使用torch.library.triton_op()而非torch.library.custom_op()torch.library.custom_op()将自定义算子视为不透明(torch.compile()torch.export.export()永远不会追踪它们),但triton_op使这些子系统能够看到实现,从而允许它们优化triton内核。

注意,fn必须仅包含对PyTorch理解的算子和triton内核的调用。在fn中调用的任何triton内核必须通过torch.library.wrap_triton()调用进行包装。

参数

  • name (str) - 自定义算子的名称,格式为"{命名空间}::{名称}",例如"mylib::my_linear"。该名称用作PyTorch子系统(如torch.export、FX图)中算子的稳定标识符。为避免名称冲突,请使用项目名称作为命名空间;例如,pytorch/fbgemm中的所有自定义算子都使用"fbgemm"作为命名空间。
  • mutates_args (Iterable[str] 或 "unknown") - 函数修改的参数名称。这必须准确,否则行为未定义。如果为"unknown",则悲观地假设算子的所有输入都被修改。
  • schema (None | str) - 算子的模式字符串。如果为None(推荐),我们将根据其类型注释推断算子的模式。除非有特定原因,否则建议让我们推断模式。示例:"(Tensor x, int y) -> (Tensor, Tensor)"。

返回类型:Callable


示例:

python 复制代码
>>> import torch
>>> from torch.library import triton_op, wrap_triton
>>> >
>>> import triton
>>> from triton import language as tl
>>> >
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0, >>    in_ptr1, >>    out_ptr, >>    n_elements, >>    BLOCK_SIZE: "tl.constexpr", >>):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>> >
>>> @triton_op("mylib::add", mutates_args={})
>>> def add(x: torch.Tensor, y: torch.Tensor) -torch.Tensor:
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>> >
>>>     def grid(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>> >
>>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
>>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
>>>     return output
>>> >
>>> @torch.compile
>>> def f(x, y):
>>>     return add(x, y)
>>> >
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> >
>>> z = f(x, y)
>>> assert torch.allclose(z, x + y)

python 复制代码
torch.library.wrap_triton(triton_kernel, /)

允许通过 make_fx 或非严格模式的 torch.export 将 triton 内核捕获到计算图中。

这些技术基于调度器进行追踪(通过 __torch_dispatch__),无法直接捕获原始 triton 内核的调用。

wrap_triton API 能够将 triton 内核封装为可调用对象,从而真正实现计算图的可追踪性。

建议将此 API 与 torch.library.triton_op() 配合使用。


示例

python 复制代码
>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.library import wrap_triton
>>> >
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0, >>    in_ptr1, >>    out_ptr, >>    n_elements, >>    BLOCK_SIZE: "tl.constexpr", >>):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>> >
>>> def add(x, y):
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>> >
>>>     def grid_fn(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>> >
>>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>>     return output
>>> >
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> #         kernel_idx = 0, constant_args_idx = 0, >>#         grid = [(1, 1, 1)], kwargs = {
>>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like, >>#             'n_elements': 3, 'BLOCK_SIZE': 16
>>> #         })
>>> #     return empty_like

返回类型:任意


扩展自定义算子(由Python或C++创建)

使用 torch.library.register_kernel()torch.library.register_fake() 等 register.* 方法,可以为任何算子添加实现(这些算子可能是通过 torch.library.custom_op() 创建的,或是通过 PyTorch 的 C++ 算子注册 API 创建的)。


python 复制代码
torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)

为该操作符注册一个设备类型的实现。

一些有效的设备类型包括:"cpu"、"cuda"、"xla"、"mps"、"ipu"、"xpu"。

此API可用作装饰器。

参数

  • op (str | OpOverload) - 要注册实现的操作符。
  • device_types (None | str | Sequence[str]) - 要注册实现的设备类型。如果为None,将注册到所有设备类型------请仅在实现确实与设备类型无关时使用此选项。
  • func (Callable) - 注册为给定设备类型实现的函数。
  • lib (Optional[Library]) - 如果提供,此注册的生命周期

示例::


python 复制代码
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>> >
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>> >
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>> >
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

python 复制代码
torch.library.register_autocast(op, device_type, cast_inputs, /, *, lib=None)

为自定义操作注册自动类型转换调度规则。

有效设备类型包括:"cpu"和"cuda"。

参数

  • op (str | OpOverload) -- 要注册自动类型转换调度规则的操作符。
  • device_type (str) -- 使用的设备类型。'cuda'或'cpu'。该类型与torch.device的类型属性相同。因此,您可以通过Tensor.device.type获取张量的设备类型。
  • cast_inputs (torch.dtype) -- 当自定义操作在启用自动类型转换的区域运行时,将传入的浮点张量转换为目标数据类型(非浮点张量不受影响),然后在禁用自动类型转换的情况下执行自定义操作。
  • lib (Optional[Library]) -- 如果提供,此注册的生命周期

示例::


python 复制代码
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> >
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -Tensor:
>>>     return torch.sin(x)
>>> >
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>> >
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>>     y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16

python 复制代码
torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)

为自定义算子注册反向传播公式。

要使算子能够与自动微分(autograd)协同工作,您需要注册一个反向传播公式:

1、必须通过提供"backward"函数来告知我们如何在反向传播过程中计算梯度。

2、如果需要使用前向传播中的值来计算梯度,可以通过setup_context保存这些值供反向传播使用。

backward函数在反向传播阶段执行,它接受(ctx, grads)参数:

  • grads是一个或多个梯度值,其数量与算子的输出数量相匹配。

ctx对象与torch.autograd.Function使用的上下文对象相同。backward_fn的语义与torch.autograd.Function.backward()完全一致。

setup_context(ctx, inputs, output)在前向传播阶段执行。请通过以下方式将反向传播所需的数据保存到ctx对象中:

如果自定义算子包含仅关键字参数,我们期望setup_context的签名为:setup_context(ctx, inputs, keyword_only_inputs, output)

setup_context_fnbackward_fn都必须是可追踪的。这意味着:

如果需要不可追踪的反向传播,可以将其实现为单独的自定义算子,在backward_fn内部调用。

如果需要在不同设备上实现不同的自动微分行为,建议为每种需要不同行为的设备创建单独的自定义算子,并在运行时进行切换。


示例

python 复制代码
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> >
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>> >
>>> def setup_context(ctx, inputs, output) -Tensor:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>> >
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>> >
>>> torch.library.register_autograd(
...     "mylib::numpy_sin", backward, setup_context=setup_context
... )
>>> >
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>> >
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>> >
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -Tensor:
>>>     ctx.val = keyword_only_inputs["val"]
>>> >
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>> >
>>> torch.library.register_autograd(
...     "mylib::numpy_mul", backward, setup_context=setup_context
... )
>>> >
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

python 复制代码
torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)

为该运算符注册一个 FakeTensor 实现(称为"fake impl")。

有时也被称为"元内核"或"抽象实现"。

"FakeTensor 实现"定义了该运算符在不携带数据的张量("FakeTensor")上的行为。给定具有某些属性(尺寸/步长/存储偏移量/设备)的输入张量,它规定了输出张量的属性。

FakeTensor 实现与运算符具有相同的签名。

该实现会同时作用于 FakeTensor 和元张量。编写 FakeTensor 实现时,需假设运算符的所有张量输入都是常规的 CPU/CUDA/元张量,但它们没有存储空间,而你需要返回常规的 CPU/CUDA/元张量作为输出。

FakeTensor 实现必须仅包含 PyTorch 操作(且不得直接访问任何输入或中间张量的存储或数据)。

此 API 可用作装饰器(参见示例)。

有关自定义运算符的详细指南,请参阅 https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html


示例

python 复制代码
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> >
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>> >
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>> >
>>>     return (x @ weight.t()) + bias
>>> >
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>> >
>>> assert y.shape == (2, 3)
>>> >
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>> >
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an fake impl, >># we use the ctx object to construct a new symint that >># represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>> >
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> >
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>> >
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

python 复制代码
torch.library.register_vmap(op, func=None, /, *, lib=None)

注册一个 vmap 实现以支持该自定义操作使用 torch.vmap()

此 API 可用作装饰器(参见示例)。

为了让运算符能与 torch.vmap() 协同工作,您可能需要按以下签名注册 vmap 实现:

vmap_func(info, in_dims: Tuple[Optional[int]], args, *kwargs)

其中 *args**kwargsop 的参数和关键字参数。

我们不支持仅含关键字参数的 Tensor 参数。

该实现需指定如何计算带额外维度的输入(由 in_dims 指定)下 op 的批处理版本。

对于 args 中的每个参数,in_dims 都有一个对应的 Optional[int]。如果参数不是 Tensor 或未被 vmap 处理,则为 None;否则是一个整数,表示 Tensor 中被 vmap 处理的维度索引。

info 是可能有用的额外元数据集合:

  • info.batch_size 指定被 vmap 处理的维度大小
  • info.randomness 是传递给 torch.vmap()randomness 选项

函数 func 应返回 (output, out_dims) 元组。与 in_dims 类似,out_dims 需与 output 结构相同,并为每个输出包含一个 out_dim,指明输出是否具有 vmap 处理的维度及其索引位置。


示例

python 复制代码
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>> >
>>> def to_numpy(tensor):
>>>     return tensor.cpu().numpy()
>>> >
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -Tuple[Tensor, Tensor]:
>>>     x_np = to_numpy(x)
>>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>>     return torch.tensor(x_np ** 3, device=x.device), dx
>>> >
>>> def numpy_cube_vmap(info, in_dims, x):
>>>     result = numpy_cube(x)
>>>     return result, (in_dims[0], in_dims[0])
>>> >
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>> >
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>> >
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -Tensor:
>>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>> >
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>>     x_bdim, y_bdim = in_dims
>>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>>     result = x * y
>>>     result = result.movedim(-1, 0)
>>>     return result, 0
>>> >
>>> >
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)

注意:vmap函数的设计应确保保留整个自定义运算符的语义。

也就是说,grad(vmap(op))应当能够被grad(map(op))替代。

如果您的自定义运算符在反向传播过程中有任何特殊行为,请牢记这一点。


python 复制代码
torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)

该API在PyTorch 2.4中已更名为torch.library.register_fake(),请使用新名称。


python 复制代码
torch.library.get_ctx()

get_ctx() 返回当前的 AbstractImplCtx 对象。

调用 get_ctx() 仅在 fake 实现内部有效(更多使用细节请参阅 torch.library.register_fake())。

返回类型:FakeImplCtx


python 复制代码
torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)

为给定运算符和 torch_dispatch_class 注册一个 torch_dispatch 规则。

这种方式支持通过开放注册来指定运算符与 torch_dispatch_class 之间的交互行为,而无需直接修改 torch_dispatch_class 或运算符本身。

torch_dispatch_class 可以是以下两种类型之一:

  • 带有 __torch_dispatch__ 方法的 Tensor 子类
  • TorchDispatchMode 对象

如果属于 Tensor 子类,要求 func 具有以下签名格式:
(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

如果属于 TorchDispatchMode,则要求 func 具有以下签名格式:
(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

其中 argskwargs 的标准化处理方式与 __torch_dispatch__ 中的规范一致(参见 torch_dispatch 调用约定)。


示例

python 复制代码
>>> import torch
>>> >
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
>>> def foo(x: torch.Tensor) -torch.Tensor:
>>>     return x.clone()
>>> >
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
>>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
>>>         return func(args, *kwargs)
>>> >
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
>>> def _(mode, func, types, args, kwargs):
>>>     x, = args
>>>     return x + 1
>>> >
>>> x = torch.randn(3)
>>> y = foo(x)
>>> assert torch.allclose(y, x)
>>> >
>>> with MyMode():
>>>     y = foo(x)
>>> assert torch.allclose(y, x + 1)

python 复制代码
torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)

解析带有类型提示的给定函数的模式。该模式从函数的类型提示中推断得出,可用于定义新的运算符。

我们做出以下假设:

  • 所有输出都不会与任何输入或其他输出产生别名冲突。
  • 未指定库的字符串类型注解"device, dtype, Tensor, types"默认视为torch.。同样,未指定库的字符串类型注解"Optional, List, Sequence, Union"默认视为typing.
  • 只有mutates_args中列出的参数会被修改。如果mutates_args为"unknown",则假定运算符的所有输入都会被修改。

调用方(例如自定义操作API)需负责验证这些假设。

参数

  • prototype_function ( Callable ) - 用于从其类型注解推断模式的函数。
  • op_name (Optional[str]) - 模式中运算符的名称。如果name为None,则推断的模式中不包含名称。注意torch.library.Library.define的输入模式需要运算符名称。
  • mutates_args ("unknown" | Iterable[str]) - 函数中被修改的参数。

返回值:推断出的模式。

返回类型 : str


示例

python 复制代码
>>> def foo_impl(x: torch.Tensor) -torch.Tensor:
>>>     return x.sin()
>>> >
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -Tensor
>>> >
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -Tensor

python 复制代码
class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn)

CustomOpDef 是一个函数包装器,用于将其转换为自定义操作符。

它提供了多种方法,用于为此自定义操作符注册附加行为。

您不应直接实例化 CustomOpDef,而应使用 torch.library.custom_op() API。


python 复制代码
set_kernel_enabled(device_type, enabled=True)

禁用或重新启用已为此自定义算子注册的内核。

如果内核已处于禁用/启用状态,则此操作无效。

注意:如果内核先被禁用然后重新注册,它将保持禁用状态直到再次启用。

参数

  • device_type (str) -- 要禁用/启用内核的设备类型。
  • disable ([bool]) -- 是否禁用内核(True表示禁用,False表示启用)。

示例

python 复制代码
>>> inp = torch.randn(1)
>>> >
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -Tensor:
>>>     return torch.zeros(1)
>>> >
>>> print(f(inp))  # tensor([0.]), default kernel
>>> >
>>> @f.register_kernel("cpu")
>>> def _(x):
>>>     return torch.ones(1)
>>> >
>>> print(f(inp))  # tensor([1.]), CPU kernel
>>> >
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled = False):
>>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled

底层 API

以下 API 是直接绑定到 PyTorch C++ 底层算子注册接口的封装。


警告:底层算子注册 API 和 PyTorch 调度器是 PyTorch 中较为复杂的概念。我们建议您尽可能使用上文提到的高级 API(无需 torch.library.Library 对象)。这篇博文 http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/ 是了解 PyTorch 调度器的入门好材料。

关于如何使用该 API 的具体示例教程,可访问 Google Colab 查看。


python 复制代码
class torch.library.Library(ns, kind, dispatch_key='')

一个用于创建库的类,这些库可用于从Python注册新运算符或覆盖现有库中的运算符。

用户可以选择性地传入一个调度键名,如果只想注册对应特定调度键的内核。

要创建覆盖现有库(名称为ns)中运算符的库,请将kind设置为"IMPL"。

要创建新库(名称为ns)以注册新运算符,请将kind设置为"DEF"。

要创建可能现有库的片段以注册运算符(并绕过给定命名空间只能有一个库的限制),请将kind设置为"FRAGMENT"。

参数

  • ns -- 库名称
  • kind -- "DEF"、"IMPL"(默认:"IMPL")、"FRAGMENT"
  • dispatch_key -- PyTorch调度键(默认:"")

python 复制代码
define(schema, alias_analysis='', *, tags=())

在ns命名空间中定义一个新运算符及其语义。

参数

  • schema -- 用于定义新运算符的函数模式。
  • alias_analysis (可选) -- 指示运算符参数的别名属性是否可以从模式中推断(默认行为)或不能("CONSERVATIVE")。
  • tags ( Tag | *Sequence [Tag ]) -- 一个或多个要应用于此运算符的torch.Tag。标记运算符会改变运算符在各种PyTorch子系统中的行为;应用前请仔细阅读torch.Tag的文档。

返回值:从模式中推断出的运算符名称。


示例:

python 复制代码
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -Tensor")

python 复制代码
fallback(fn, dispatch_key='', *, with_keyset=False)

将函数实现注册为给定键的回退处理程序。

此函数仅适用于具有全局命名空间("_")的库。

参数

  • fn -- 用作给定调度键回退的函数,或使用 fallthrough_kernel() 注册穿透回退。
  • dispatch_key -- 输入函数应注册到的调度键。默认使用创建库时指定的调度键。
  • with_keyset -- 控制调用时是否将当前调度器调用的键集作为第一个参数传递给 fn 的标志。这应用于为重新调度调用创建适当的键集。

示例

python 复制代码
>>> my_lib = Library("_", "IMPL")
>>> def fallback_kernel(op, args, *kwargs):
>>>     # Handle all autocast ops generically
>>>     # ...
>>> my_lib.fallback(fallback_kernel, "Autocast")

python 复制代码
impl(op_name, fn, dispatch_key='', *, with_keyset=False)

注册库中定义的操作符的函数实现。

参数

  • op_name -- 操作符名称(包括重载)或 OpOverload 对象。
  • fn -- 作为输入调度键的操作符实现函数,或使用 fallthrough_kernel() 注册回退函数。
  • dispatch_key -- 输入函数应注册到的调度键。默认使用创建库时所用的调度键。
  • with_keyset -- 控制标志,决定调用时是否将当前调度器调用的键集作为第一个参数传递给 fn。这应用于为重新调度调用创建适当的键集。

示例:

python 复制代码
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")

python 复制代码
torch.library.fallthrough_kernel()

一个传递给 Library.impl 的虚拟函数,用于注册回退机制。


python 复制代码
torch.library.define(qualname, schema, *, lib=None, tags=())

python 复制代码
torch.library.define(lib, schema, alias_analysis='')

定义一个新运算符。

在PyTorch中,定义一个运算符(operator的简称)需要两个步骤:

  • 首先需要定义运算符(通过提供运算符名称和模式)
  • 然后需要实现该运算符与PyTorch各子系统的交互行为,例如CPU/CUDA张量、自动微分等

此入口点用于定义自定义运算符(即第一步)

接着必须通过调用各种impl_* API来执行第二步,例如torch.library.impl()torch.library.register_fake()

参数说明

  • qualname (str) - 运算符的限定名称。应为类似"namespace::name"的字符串,例如"aten::sin"。

PyTorch中的运算符需要命名空间来避免名称冲突;每个运算符只能创建一次。

如果您正在编写Python库,我们建议使用顶层模块名称作为命名空间。

  • schema (str) - 运算符的模式。例如"(Tensor x) -> Tensor"表示接受一个张量并返回一个张量的运算符。该模式不包含运算符名称(名称在qualname中传递)。
  • lib (Optional[[Library](https://pytorch.org/docs/stable/data.html#torch.library.Library "torch.library.Library")]) - 如果提供,此运算符的生命周期将与Library对象的生命周期绑定。
  • tags (Tag | Sequence[Tag]) - 一个或多个torch.Tag标签,用于标记此运算符。标记运算符会改变其在各种PyTorch子系统中的行为;应用前请仔细阅读torch.Tag的文档。

示例:

python 复制代码
>>> import torch
>>> import numpy as np
>>> >
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -Tensor")
>>> >
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>> >
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())

python 复制代码
torch.library.impl(lib, name, dispatch_key='')

python 复制代码
torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Literal[None] = None, *, lib: Optional[Library] = None) → Callable[[Callable[..., object]], None]

python 复制代码
torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Callable[..., object], *, lib: Optional[Library] = None) → None

python 复制代码
torch.library.impl(lib: Library, name: str, dispatch_key: str = '') → Callable[[Callable[_P, _T]], Callable[_P, _T]]

为该操作符注册一个设备类型的实现。

您可以为types参数传入"default",将此实现注册为所有设备类型的默认实现。但请仅在实现确实支持所有设备类型时使用此选项;例如,如果它是PyTorch内置操作符的组合,则适用此情况。

此API可用作装饰器。您可以将嵌套装饰器与此API一起使用,前提是它们返回一个函数并且放置在此API内部(参见示例2)。

一些有效的设备类型包括:"cpu"、"cuda"、"xla"、"mps"、"ipu"、"xpu"。

参数

  • qualname (str) - 应为类似"namespace::operator_name"格式的字符串。
  • types (str | Sequence[str]) - 要注册实现的设备类型。
  • lib (Optional[Library]) - 如果提供,此注册的生命周期将与Library对象的生命周期绑定。

示例

python 复制代码
>>> import torch
>>> import numpy as np
>>> # Example 1: Register function.
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -Tensor")
>>> >
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>> >
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())
>>> >
>>> # Example 2: Register function with decorator.
>>> def custom_decorator(func):
>>>     def wrapper(args, *kwargs):
>>>         return func(args, *kwargs) + 1
>>>     return wrapper
>>> >
>>> # Define the operator
>>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -Tensor")
>>> >
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin_plus_one", "cpu")
>>> @custom_decorator
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>> >
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> >
>>> y1 = torch.ops.mylib.sin_plus_one(x)
>>> y2 = torch.sin(x) + 1
>>> assert torch.allclose(y1, y2)


torch.accelerator

该包提供了对Python中当前加速器的支持。

device_count 返回当前可用加速器的数量
is_available 检查当前加速器在运行时是否可用:已构建完成、所有必需驱动可用且至少有一个设备可见
current_accelerator 返回编译时可用的加速器设备
set_device_index 将当前设备索引设置为指定设备
set_device_idx 将当前设备索引设置为指定设备
current_device_index 返回当前加速器所选设备的索引
current_device_idx 返回当前加速器所选设备的索引
set_stream 将当前流设置为指定流
current_stream 返回指定设备的当前选定流
synchronize 等待给定设备上所有流中的所有内核完成执行


torch.cpu

该包实现了torch.cuda中的抽象功能,便于编写设备无关的代码。

current_device 返回当前CPU设备
current_stream 返回指定设备当前选中的Stream
is_available 返回布尔值表示CPU当前是否可用
synchronize 等待CPU设备上所有流中的所有内核完成
stream StreamContext上下文管理器的包装器,用于选择指定流
set_device 设置当前设备(对CPU无实际操作)
device_count 返回CPU设备数量(非核心数)
StreamContext 用于选择指定流的上下文管理器

流与事件

Stream 注意


torch.cuda

该包为CUDA张量类型提供了支持。

它实现了与CPU张量相同的功能,但使用GPU进行计算。

该模块采用延迟初始化机制,因此你可以随时导入它,并通过is_available()检查当前系统是否支持CUDA。

CUDA语义文档提供了关于CUDA操作的更多细节。

StreamContext 用于选择指定流的上下文管理器
can_device_access_peer 检查两个设备间是否支持点对点访问
current_blas_handle 返回当前cuBLAS句柄的cublasHandle_t指针
current_device 返回当前选定设备的索引
current_stream 返回指定设备当前选定的Stream
cudart 获取CUDA运行时API模块
default_stream 返回指定设备的默认Stream
device 用于切换当前设备的上下文管理器
device_count 返回可用GPU的数量
device_memory_used 返回nvidia-smi或amd-smi报告的已用全局(设备)内存(字节)
device_of 将当前设备切换为给定对象所在设备的上下文管理器
get_arch_list 返回本库编译时支持的CUDA架构列表
get_device_capability 获取设备的CUDA计算能力
get_device_name 获取设备名称
get_device_properties 获取设备属性
get_gencode_flags 返回本库编译时使用的NVCC gencode标志
get_stream_from_external 从外部分配的CUDA流返回Stream
get_sync_debug_mode 返回当前CUDA同步操作的调试模式值
init 初始化PyTorch的CUDA状态
ipc_collect 强制回收CUDA IPC释放后的GPU内存
is_available 返回布尔值表示CUDA当前是否可用
is_initialized 返回PyTorch的CUDA状态是否已初始化
is_tf32_supported 返回布尔值表示当前CUDA/ROCm设备是否支持tf32数据类型
memory_usage 返回nvidia-smi报告的过去采样周期内全局(设备)内存读写时间的百分比
set_device 设置当前设备
set_stream 设置当前流。这是设置流的封装API
set_sync_debug_mode 设置CUDA同步操作的调试模式
stream 封装选择指定流的StreamContext上下文管理器
synchronize 等待CUDA设备上所有流中的所有内核完成
utilization 返回nvidia-smi报告的过去采样周期内GPU执行一个或多个内核的时间百分比
temperature 返回GPU传感器的平均温度(摄氏度)
power_draw 返回GPU传感器的平均功耗(毫瓦)
clock_rate 返回nvidia-smi报告的过去采样周期内GPU SM时钟速度(兆赫)
OutOfMemoryError 设备内存不足时抛出的异常

随机数生成器

get_rng_state 以ByteTensor形式返回指定GPU的随机数生成器状态。
get_rng_state_all 返回表示所有设备随机数状态的ByteTensor列表。
set_rng_state 设置指定GPU的随机数生成器状态。
set_rng_state_all 设置所有设备的随机数生成器状态。
manual_seed 为当前GPU设置生成随机数的种子。
manual_seed_all 为所有GPU设置生成随机数的种子。
seed 将当前GPU的随机数生成种子设置为一个随机数。
seed_all 将所有GPU的随机数生成种子设置为随机数。
initial_seed 返回当前GPU的随机种子值。

通信集合操作

comm.broadcast 将张量广播到指定的GPU设备
comm.broadcast_coalesced 将张量序列广播到指定的GPU设备
comm.reduce_add 对多个GPU上的张量进行求和操作
comm.scatter 将张量分散到多个GPU设备
comm.gather 从多个GPU设备收集张量

流与事件

Stream CUDA流的封装类
ExternalStream 外部分配的CUDA流的封装类
Event CUDA事件的封装类

图表功能(测试版)

is_current_stream_capturing 如果当前CUDA流正在进行CUDA图捕获则返回True,否则返回False。
graph_pool_handle 返回表示图内存池ID的不透明令牌。
CUDAGraph CUDA图的封装类。
graph 上下文管理器,用于将CUDA工作捕获到torch.cuda.CUDAGraph对象中供后续重放。
make_graphed_callables 接收可调用对象(函数或nn.Module)并返回其图表化版本。

内存管理

empty_cache 释放缓存分配器当前持有的所有未占用缓存内存,以便这些内存可用于其他GPU应用程序并在nvidia-smi中可见。
get_per_process_memory_fraction 获取进程的内存分配比例。
list_gpu_processes 返回指定设备上运行进程及其GPU内存使用情况的可读打印输出。
mem_get_info 使用cudaMemGetInfo返回指定设备的全局空闲和总GPU内存。
memory_stats 返回指定设备的CUDA内存分配器统计信息字典。
host_memory_stats 返回指定设备的CUDA内存分配器统计信息字典。
memory_summary 返回指定设备当前内存分配器统计信息的可读打印输出。
memory_snapshot 返回所有设备上CUDA内存分配器状态的快照。
memory_allocated 返回指定设备上张量当前占用的GPU内存(字节)。
max_memory_allocated 返回指定设备上张量占用的最大GPU内存(字节)。
reset_max_memory_allocated 重置指定设备上张量占用最大GPU内存的跟踪起始点。
memory_reserved 返回指定设备上缓存分配器当前管理的GPU内存(字节)。
max_memory_reserved 返回指定设备上缓存分配器管理的最大GPU内存(字节)。
set_per_process_memory_fraction 设置进程的内存分配比例。
memory_cached 已弃用;参见memory_reserved()
max_memory_cached 已弃用;参见max_memory_reserved()
reset_max_memory_cached 重置指定设备上缓存分配器管理的最大GPU内存的跟踪起始点。
reset_peak_memory_stats 重置CUDA内存分配器跟踪的"峰值"统计信息。
reset_peak_host_memory_stats 重置主机内存分配器跟踪的"峰值"统计信息。
caching_allocator_alloc 使用CUDA内存分配器执行内存分配。
caching_allocator_delete 删除使用CUDA内存分配器分配的内存。
get_allocator_backend 返回描述当前活动分配器后端的字符串(由PYTORCH_CUDA_ALLOC_CONF设置)。
CUDAPluggableAllocator 从so文件加载的CUDA内存分配器。
change_current_allocator 将当前使用的内存分配器更改为提供的分配器。
MemPool MemPool表示缓存分配器中的内存池。
MemPoolContext MemPoolContext保存当前活动的内存池并暂存前一个内存池。
caching_allocator_enable 启用或禁用CUDA内存分配器。
--- ---

python 复制代码
class torch.cuda.use_mem_pool(pool, device=None)

一个将内存分配路由到指定内存池的上下文管理器。

参数

  • pool (torch.cuda.MemPool) - 将被激活的MemPool对象,使内存分配路由到该池。
  • device (torch.device 或 int, 可选) - 选择的设备。如果deviceNone(默认值),则使用由current_device()给出的当前设备上的MemPool。

NVIDIA 工具扩展 (NVTX)

nvtx.mark 描述某个时间点发生的瞬时事件。
nvtx.range_push 将一个范围压入嵌套范围跨度堆栈。
nvtx.range_pop 从嵌套范围跨度堆栈中弹出一个范围。
nvtx.range 上下文管理器/装饰器,在其作用域开始时推送一个NVTX范围,并在结束时弹出。

Jiterator (测试版)

jiterator._create_jit_fn 为逐元素运算创建由jiterator生成的CUDA内核。
jiterator._create_multi_output_jit_fn 为支持返回一个或多个输出的逐元素运算创建由jiterator生成的CUDA内核。

TunableOp(可调优操作符)

某些运算可以通过多种库或多种技术实现。例如,GEMM(通用矩阵乘法)运算既可以使用CUDA平台的cublas/cublasLt库实现,也可以使用ROCm平台的hipblas/hipblasLt库实现。那么如何确定哪种实现方式速度最快且应被选用呢?这正是TunableOp所提供的功能。部分运算符已通过多种策略实现为可调优操作符,在运行时会对所有策略进行性能分析,并选择最快的策略用于后续所有运算。

具体使用方法请参阅文档


(翻译说明:

1、保留所有代码术语如GEMM/CUDA/ROCm/cublas等

2、"TunableOp"译为"可调优操作符"并保留括号标注原词

3、将被动语态"could be implemented"转为主动语态"可以通过...实现"

4、拆分长句"Certain operators...operations"为两个中文短句

5、严格保留文档链接格式)


流消毒器(原型)

CUDA消毒器是一个用于检测PyTorch中流间同步错误的原型工具。有关如何使用它的信息,请参阅文档


GPUDirect Storage (原型)

torch.cuda.gds中的API提供了对部分cuFile API的轻量级封装,支持在GPU内存与存储设备之间直接进行内存访问传输,避免了通过CPU的跳转缓冲区。更多详情请参阅cufile API文档

这些API可在CUDA 12.6及以上版本中使用。使用时需确保系统已按照GPUDirect Storage文档正确配置。

具体使用示例可参考GdsFile文档

gds_register_buffer 将CUDA设备上的存储注册为cufile缓冲区
gds_deregister_buffer 注销先前在CUDA设备上注册的cufile缓冲区存储
GdsFile cuFile的封装类


理解 CUDA 内存使用

为了调试 CUDA 内存使用情况,PyTorch 提供了一种生成内存快照的方法,这些快照会记录特定时间点已分配的 CUDA 内存状态,并可选地记录导致该快照的内存分配事件历史。

生成的快照可以拖放到 pytorch.org/memory_viz 托管的交互式查看器中,用于探索快照内容。


生成快照

记录快照的常见模式是:先启用内存历史记录,运行需要观察的代码,然后将快照保存为pickle格式的文件:

python 复制代码
# enable memory history, which will
# add tracebacks and event history to snapshots
torch.cuda.memory._record_memory_history()

run_your_code()
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")

使用可视化工具

打开 pytorch.org/memory_viz 并将序列化的快照文件拖放至可视化工具中。该可视化工具是一个在您本地计算机上运行的 JavaScript 应用程序,不会上传任何快照数据。


活动内存时间线

活动内存时间线展示了特定 GPU 上快照中随时间变化的所有活跃张量。通过平移/缩放图表可以查看较小的内存分配。将鼠标悬停在已分配的内存块上,可以查看该内存块分配时的堆栈跟踪信息,以及地址等详细信息。当数据量较大时,可以调整细节滑块以减少渲染的分配数量,从而提高性能。



分配器状态历史

分配器状态历史在左侧时间轴中显示各个分配器事件。选择时间轴中的事件可查看该事件发生时分配器状态的可视化摘要。该摘要显示从cudaMalloc返回的每个独立内存段,以及这些段如何被分割成单个分配块或空闲空间。将鼠标悬停在内存段和块上可查看内存分配时的堆栈跟踪。悬停在事件上可查看事件发生时的堆栈跟踪,例如张量释放时的信息。内存不足错误会显示为OOM事件。通过观察OOM时的内存状态,可能有助于理解为何在仍有保留内存的情况下分配会失败。



堆栈跟踪信息还会报告内存分配发生的地址。

地址b7f064c000000_0表示位于7f064c000000的(b)lock块,这是该地址第"_0"次被分配。

这个唯一字符串可以在活动内存时间轴中查找,并在活动状态历史中搜索,用于检查张量分配或释放时的内存状态。


快照 API 参考文档


python 复制代码
torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all', max_entries=9223372036854775807, device=None)

启用记录与内存分配相关的堆栈跟踪功能,这样您就能通过 torch.cuda.memory._snapshot() 查看每块内存的分配来源。

除了记录当前每次内存分配和释放的堆栈跟踪外,该功能还会记录完整的分配/释放事件历史。

使用 torch.cuda.memory._snapshot() 获取这些信息,并通过 _memory_viz.py 中的工具可视化快照数据。

Python 的跟踪收集速度很快(每次跟踪仅需 2 微秒),因此如果您预计可能需要调试内存问题,可以在生产任务中启用此功能。

C++ 跟踪收集同样高效(约 50 纳秒/帧),对于大多数典型程序而言,每次跟踪耗时约 2 微秒,实际时间会随堆栈深度变化。

参数说明

  • enabled (Literal[None, "state", "all"], 可选) - None:禁用内存历史记录;"state":仅记录当前已分配内存的信息;"all":额外记录所有分配/释放调用历史。默认为 "all"。
  • context (Literal[None, "state", "alloc", "all"], 可选) - None:不记录任何堆栈跟踪;"state":记录当前已分配内存的堆栈跟踪;"alloc":额外记录分配调用的堆栈跟踪;"all":额外记录释放调用的堆栈跟踪。默认为 "all"。
  • stacks (Literal["python", "all"], 可选) - "python":在堆栈跟踪中包含 Python、TorchScript 和 inductor 帧;"all":额外包含 C++ 帧。默认为 "all"。
  • max_entries (int, 可选) - 限制记录历史中保存的最大分配/释放事件数量。

python 复制代码
torch.cuda.memory._snapshot(device=None)

保存调用时的 CUDA 内存状态快照。

该状态以字典形式表示,结构如下。


python 复制代码
class Snapshot(TypedDict):
    segments : List[Segment]
    device_traces: List[List[TraceEntry]]

class Segment(TypedDict):
    # Segments are memory returned from a cudaMalloc call.
    # The size of reserved memory is the sum of all Segments.
    # Segments are cached and reused for future allocations.
    # If the reuse is smaller than the segment, the segment
    # is split into more then one Block.
    # empty_cache() frees Segments that are entirely inactive.
    address: int
    total_size: int #  cudaMalloc'd size of segment
    stream: int
    segment_type: Literal['small', 'large'] # 'large' (>1MB)
    allocated_size: int # size of memory in use     active_size: int # size of memory in use or in active_awaiting_free state
    blocks : List[Block]

class Block(TypedDict):
    # A piece of memory returned from the allocator, or      # current cached but inactive.
    size: int
    requested_size: int # size requested during malloc, may be smaller than
                        # size due to rounding
    address: int
    state: Literal['active_allocated', # used by a tensor
                'active_awaiting_free', # waiting for another stream to finish using
                                        # this, then it will become free
                'inactive',] # free for reuse
    frames: List[Frame] # stack trace from where the allocation occurred

class Frame(TypedDict):
        filename: str
        line: int
        name: str

class TraceEntry(TypedDict):
    # When `torch.cuda.memory._record_memory_history()` is enabled, # the snapshot will contain TraceEntry objects that record each
    # action the allocator took.
    action: Literal[
    'alloc'  # memory allocated
    'free_requested', # the allocated received a call to free memory
    'free_completed', # the memory that was requested to be freed is now
                    # able to be used in future allocation calls
    'segment_alloc', # the caching allocator ask cudaMalloc for more memory
                    # and added it as a segment in its cache     'segment_free', # the caching allocator called cudaFree to return memory
                    # to cuda possibly trying free up memory to                     # allocate more segments or because empty_caches was called
    'oom',   # the allocator threw an OOM exception. 'size' is                     # the requested number of bytes that did not succeed
    'snapshot'      # the allocator generated a memory snapshot
                    # useful to coorelate a previously taken
                    # snapshot with this trace
    ]
    addr: int # not present for OOM
    frames: List[Frame]
    size: int
    stream: int
    device_free: int # only present for OOM, the amount of                     # memory cuda still reports to be free

返回:Snapshot 字典对象


python 复制代码
torch.cuda.memory._dump_snapshot(filename='dump_snapshot.pickle')

将 torch.memory._snapshot() 字典的 pickle 版本保存到文件中。

该文件可通过 pytorch.org/memory_viz 的交互式快照查看器打开。

参数

  • filename (str, optional) - 要创建的文件名。默认为 "dump_snapshot.pickle"。


torch.mps

该包提供了在Python中访问MPS(Metal Performance Shaders)后端的接口。Metal是苹果公司用于编程金属GPU(图形处理器)的API。使用MPS意味着可以通过在金属GPU上运行工作来实现更高的性能。详情请参阅https://developer.apple.com/documentation/metalperformanceshaders

device_count 返回可用的MPS设备数量
synchronize 等待MPS设备上所有流中的所有内核完成
get_rng_state 以ByteTensor形式返回随机数生成器状态
set_rng_state 设置随机数生成器状态
manual_seed 设置生成随机数的种子
seed 将生成随机数的种子设置为随机数
empty_cache 释放缓存分配器当前持有的所有未占用缓存内存,以便其他GPU应用程序使用
set_per_process_memory_fraction 设置MPS设备上限制进程内存分配的内存比例
current_allocated_memory 返回当前张量占用的GPU内存(字节)
driver_allocated_memory 返回Metal驱动为该进程分配的总GPU内存(字节)
recommended_max_memory 返回GPU内存工作集大小的推荐最大值(字节)
compile_shader 从源代码编译计算着色器,并允许从Python运行时轻松调用其中定义的内核示例

MPS 性能分析器

profiler.start 启动 MPS 后端的 OS Signpost 追踪功能
profiler.stop 停止 MPS 后端的 OS Signpost 追踪功能
profiler.profile 上下文管理器,用于启用 MPS 后端的 OS Signpost 追踪
profiler.is_capturing_metal 检查 Metal 捕获是否正在进行
profiler.is_metal_capture_enabled 检查 metal_capture 上下文管理器是否可用。要启用 Metal 捕获,需设置 MTL_CAPTURE_ENABLED 环境变量
profiler.metal_capture 上下文管理器,用于将 Metal 调用捕获到 gputrace 中

MPS 事件

event.Event MPS 事件的封装类


torch.xpu

该包提供了对XPU后端的支持,专门针对英特尔GPU进行了优化。

该包采用延迟初始化机制,因此您可以随时导入它,并通过is_available()方法检测当前系统是否支持XPU。

StreamContext 用于选择指定流的上下文管理器
current_device 返回当前选定设备的索引
current_stream 返回指定设备当前选定的Stream
device 用于变更选定设备的上下文管理器
device_count 返回可用的XPU设备数量
device_of 将当前设备切换为给定对象所在设备的上下文管理器
get_arch_list 返回本库编译时支持的XPU架构列表
get_device_capability 获取设备的XPU计算能力
get_device_name 获取设备名称
get_device_properties 获取设备属性
get_gencode_flags 返回本库编译时使用的XPU预编译构建标志
get_stream_from_external 从外部SYCL队列返回Stream
init 初始化PyTorch的XPU状态
is_available 返回布尔值表示XPU当前是否可用
is_initialized 返回PyTorch的XPU状态是否已初始化
set_device 设置当前设备
set_stream 设置当前流(该API是设置流的封装接口)
stream 封装了用于选择指定流的StreamContext上下文管理器
synchronize 等待XPU设备上所有流中的所有内核执行完成

随机数生成器

get_rng_state 以ByteTensor形式返回指定GPU的随机数生成器状态。
get_rng_state_all 返回一个ByteTensor列表,表示所有设备的随机数状态。
initial_seed 返回当前GPU的随机种子值。
manual_seed 为当前GPU设置随机数生成种子。
manual_seed_all 为所有GPU设置随机数生成种子。
seed 将当前GPU的随机数生成种子设置为一个随机数。
seed_all 将所有GPU的随机数生成种子设置为随机数。
set_rng_state 设置指定GPU的随机数生成器状态。
set_rng_state_all 设置所有设备的随机数生成器状态。

流与事件

Event XPU事件的封装类
Stream XPU流的封装类

内存管理

empty_cache 释放缓存分配器当前持有的所有未占用缓存内存,以便其他XPU应用程序可以使用这些内存。
max_memory_allocated 返回给定设备上张量占用的最大GPU内存(以字节为单位)。
max_memory_reserved 返回给定设备上缓存分配器管理的最大GPU内存(以字节为单位)。
mem_get_info 返回给定设备上全局可用的和总的GPU内存。
memory_allocated 返回给定设备上张量当前占用的GPU内存(以字节为单位)。
memory_reserved 返回给定设备上缓存分配器当前管理的GPU内存(以字节为单位)。
memory_stats 返回给定设备上XPU内存分配器统计信息的字典。
memory_stats_as_nested_dict 以嵌套字典的形式返回memory_stats()的结果。
reset_accumulated_memory_stats 重置XPU内存分配器跟踪的"累计"(历史)统计信息。
reset_peak_memory_stats 重置XPU内存分配器跟踪的"峰值"统计信息。


torch.mtia

MTIA后端实现位于外部代码库,此处仅定义接口。

该包提供了在Python中访问MTIA后端的接口。

StreamContext 用于选择指定流的上下文管理器
current_device 返回当前选定设备的索引
current_stream 返回指定设备当前选定的Stream
default_stream 返回指定设备的默认Stream
device_count 返回可用的MTIA设备数量
init
is_available 如果MTIA设备可用则返回true
is_initialized 返回PyTorch的MTIA状态是否已初始化
memory_stats 返回指定设备的MTIA内存分配器统计字典
get_device_capability 以(主版本号,次版本号)元组形式返回指定设备的计算能力
empty_cache 清空MTIA设备缓存
record_memory_history 启用/禁用MTIA分配器的内存分析器
snapshot 返回MTIA内存分配器历史记录字典
set_device 设置当前设备
set_stream 设置当前流。这是设置流的封装API
stream 封装了选择指定流的StreamContext上下文管理器
synchronize 等待MTIA设备上所有流中的所有任务完成
device 用于更改选定设备的上下文管理器
set_rng_state 设置随机数生成器状态
get_rng_state 以ByteTensor形式返回随机数生成器状态
DeferredMtiaCallError

流与事件

Event 查询和记录流状态,用于识别或控制跨流依赖关系以及测量时间。
Stream 一个按先进先出(FIFO)顺序异步执行相应任务的顺序队列。


torch.mtia.memory

MTIA后端实现位于外部代码库中,此处仅定义接口。

该包提供了MTIA实现的设备内存管理支持。

memory_stats 返回指定设备的MTIA内存分配器统计信息字典。


元设备

"元"设备是一种抽象设备,它表示仅记录元数据而不存储实际数据的张量。元张量主要有两个使用场景:

  • 模型可以加载到元设备上,这样您可以在不将实际参数加载到内存的情况下获取模型的表示形式。如果您需要在加载真实数据之前对模型进行转换,这会非常有用。
  • 大多数操作都可以在元张量上执行,生成新的元张量来描述如果在真实张量上执行该操作会得到什么结果。您可以用这种方式进行抽象分析,而无需花费计算时间或存储空间来表示实际张量。由于元张量没有真实数据,因此无法执行数据依赖的操作,如torch.nonzero()item()。在某些情况下,并非所有设备类型(例如CPU和CUDA)对同一操作都能产生完全相同的输出元数据;我们通常倾向于在这种情况下准确表示CUDA的行为。

警告:虽然原则上元张量计算应该总是比等效的CPU/CUDA计算更快,但许多元张量实现是用Python编写的,尚未移植到C++以提升速度,因此您可能会发现使用小型CPU张量时框架的绝对延迟更低。


元张量操作惯用法

可以通过指定 map_location='meta' 将对象加载到元设备上,使用 torch.load() 方法实现。


python 复制代码
>>> torch.save(torch.randn(2), 'foo.pt')
>>> torch.load('foo.pt', map_location='meta')
tensor(..., device='meta', size=(2,))

如果你有一段任意代码,它在没有明确指定设备的情况下执行张量构建操作,你可以通过使用 torch.device() 上下文管理器来覆盖该行为,改为在元设备(meta device)上进行构建:

python 复制代码
>>> with torch.device('meta'):
...     print(torch.randn(30, 30))
...
tensor(..., device='meta', size=(30, 30))

这在神经网络模块构建中特别有用,因为通常无法显式传入设备进行初始化。


python 复制代码
>>> from torch.nn.modules import Linear
>>> with torch.device('meta'):
...     print(Linear(20, 30))
...
Linear(in_features=20, out_features=30, bias=True)

无法直接将元张量转换为CPU/CUDA张量,因为元张量不存储数据,我们无法确定新张量的正确数据值。


python 复制代码
>>> torch.ones(5, device='meta').to("cpu")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Cannot copy out of meta tensor; no data!

使用类似 torch.empty_like() 的工厂函数来明确指定缺失数据的填充方式。

神经网络模块提供了一个便捷方法 torch.nn.Module.to_empty(),允许将模块转移到其他设备并保持所有参数未初始化状态。开发者需要手动显式地重新初始化这些参数。


python 复制代码
>>> from torch.nn.modules import Linear
>>> with torch.device('meta'):
...     m = Linear(20, 30)
>>> m.to_empty(device="cpu")
Linear(in_features=20, out_features=30, bias=True)

torch._subclasses.meta_utils 包含一系列未公开的工具函数,能够以高保真度将任意 Tensor 转换为等价的元数据 Tensor。这些 API 目前处于实验阶段,可能会随时做出不兼容的破坏性变更。



torch.backends

torch.backends 用于控制 PyTorch 所支持的各种后端的行为。

这些后端包括:

  • torch.backends.cpu
  • torch.backends.cuda
  • torch.backends.cudnn
  • torch.backends.cusparselt
  • torch.backends.mha
  • torch.backends.mps
  • torch.backends.mkl
  • torch.backends.mkldnn
  • torch.backends.nnpack
  • torch.backends.openmp
  • torch.backends.opt_einsum
  • torch.backends.xeon

torch.backends.cpu


python 复制代码
torch.backends.cpu.get_cpu_capability()

返回CPU能力作为字符串值。

可能的值包括:

  • "DEFAULT"
  • "VSX"
  • "Z VECTOR"
  • "NO AVX"
  • "AVX2"
  • "AVX512"
  • "SVE256"

返回类型:str


torch.backends.cuda


python 复制代码
torch.backends.cuda.is_built()

返回PyTorch是否构建了CUDA支持。

请注意,这并不一定意味着CUDA可用;仅表示如果这个PyTorch二进制文件运行在具有正常工作的CUDA驱动程序和设备的机器上,我们将能够使用它。


python 复制代码
torch.backends.cuda.matmul.allow_tf32 

一个控制是否允许在Ampere或更新款GPU上使用TensorFloat-32张量核心进行矩阵乘法的bool值。详情参阅Ampere(及后续)设备上的TensorFloat-32 (TF32)


python 复制代码
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction 

一个控制是否允许在fp16 GEMM运算中使用降低精度归约(例如采用fp16累加类型)的bool值。


python 复制代码
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction 

A boolthat controls whether reduced precision reductions are allowed with bf16 GEMMs.


python 复制代码
torch.backends.cuda.cufft_plan_cache 

cufft_plan_cache contains the cuFFT plan caches for each CUDA device.

Query a specific device i's cache via torch.backends.cuda.cufft_plan_cache[i].


python 复制代码
torch.backends.cuda.cufft_plan_cache.size 

A readonly int that shows the number of plans currently in a cuFFT plan cache.


python 复制代码
torch.backends.cuda.cufft_plan_cache.max_size

一个控制 cuFFT 计划缓存容量的 int 类型参数。


python 复制代码
torch.backends.cuda.cufft_plan_cache.clear() 

清除 cuFFT 计划缓存。


python 复制代码
torch.backends.cuda.preferred_blas_library(backend=None)

覆盖 PyTorch 用于 BLAS 运算的库。可选 cuBLAS、cuBLASLt 和 CK [仅限 ROCm]。

警告:此标志为实验性功能,后续可能变更。

当 PyTorch 执行 CUDA BLAS 运算时,即使 cuBLAS 和 cuBLASLt 都可用,默认仍会使用 cuBLAS。

针对 ROCm 构建的 PyTorch 中,hipBLAS、hipBLASLt 和 CK 可能提供不同的性能表现。

此标志(str 类型)允许覆盖要使用的 BLAS 库:

  • 设为 "cublas" 时,将尽可能使用 cuBLAS
  • 设为 "cublaslt" 时,将尽可能使用 cuBLASLt
  • 设为 "ck" 时,将尽可能使用 CK
  • 设为 "default"(默认值)时,将通过启发式方法在其他选项间选择
  • 无输入时,此函数返回当前首选库

用户可通过环境变量 TORCH_BLAS_PREFER_CUBLASLT=1 全局设置首选库为 cuBLASLt。

注意:

1、此标志仅设置首选库的初始值,后续仍可能被脚本中的函数调用覆盖

2、当某个库被设为首选时,若该库未实现所调用的运算,仍可能使用其他库

3、若 PyTorch 的库选择不适合您的应用输入,此标志可能获得更好的性能

返回类型:_BlasBackend


python 复制代码
torch.backends.cuda.preferred_rocm_fa_library(backend=None)

仅限ROCm环境

覆盖PyTorch在ROCm环境下用于Flash Attention的后端实现。可选AOTriton或CK作为后端。


警告:此标志为实验性功能,后续可能变更。

当启用Flash Attention时,PyTorch默认使用AOTriton作为后端。

该标志(类型为str)允许用户将后端覆盖为composable_kernel:

  • 设为"default"时,将尽可能使用默认后端(当前为AOTriton)
  • 设为"aotriton"时,将尽可能使用AOTriton
  • 设为"ck"时,将尽可能使用CK
  • 无输入参数时,函数返回当前首选库
  • 用户可通过环境变量TORCH_ROCM_FA_PREFER_CK=1全局设置首选库为CK

注意:当指定首选库时,若该库未实现相关操作,仍可能使用其他库。

若PyTorch的库选择机制对您的应用输入不适用,此标志可能获得更好的性能。

返回类型:_ROCmFABackend


python 复制代码
torch.backends.cuda.preferred_linalg_library(backend=None)

覆盖 PyTorch 在 CUDA 线性代数运算中选择 cuSOLVER 或 MAGMA 的启发式策略。

警告:此标志为实验性功能,后续可能变更。

当 PyTorch 执行 CUDA 线性代数运算时,通常会使用 cuSOLVER 或 MAGMA 库。若两者均可用,系统会通过启发式规则自动选择。

该标志(类型为 str)允许覆盖默认的启发式选择逻辑:

  • 设为 "cusolver" 时,将尽可能使用 cuSOLVER
  • 设为 "magma" 时,将尽可能使用 MAGMA
  • 设为 "default"(默认值)时,若两个库均可用则通过启发式规则选择
  • 无输入参数时,函数返回当前优先使用的库
  • 用户可通过环境变量 TORCH_LINALG_PREFER_CUSOLVER=1 全局设置优先使用 cuSOLVER

注意:

1、此标志仅设置初始优先库,后续仍可通过脚本中的函数调用覆盖

2、即使设置了优先库,若该库未实现特定运算,仍可能使用其他库

3、当 PyTorch 的自动选择不适合您的应用场景时,手动指定可能获得更好性能

当前支持的线性代数运算符:

返回类型:_LinalgBackend


python 复制代码
class torch.backends.cuda.SDPAParams 

python 复制代码
torch.backends.cuda.flash_sdp_enabled()

警告:此标志处于测试阶段,可能会发生变化。

返回是否启用了Flash缩放点积注意力机制。


python 复制代码
torch.backends.cuda.enable_mem_efficient_sdp(enabled)

警告:此标志处于测试阶段,可能会发生变化。

启用或禁用内存高效的缩放点积注意力机制。


python 复制代码
torch.backends.cuda.mem_efficient_sdp_enabled()

警告:此标志处于测试阶段,可能会发生变化。

返回是否启用了内存高效的缩放点积注意力机制。


python 复制代码
torch.backends.cuda.enable_flash_sdp(enabled)

警告:此标志为测试版,可能会发生变化。

启用或禁用缩放点积注意力机制。


python 复制代码
torch.backends.cuda.math_sdp_enabled()

警告:此标志为测试版,可能会发生变化。

返回是否启用了数学缩放点积注意力机制。


python 复制代码
torch.backends.cuda.enable_math_sdp(enabled)

警告:此标志处于测试阶段,可能会发生变化。

启用或禁用数学缩放点积注意力机制。


python 复制代码
torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed()

警告:此标志为测试版,可能会发生变化。

返回是否启用了数学缩放点积注意力中的 fp16/bf16 缩减功能。


python 复制代码
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(enabled)

警告:此标志处于测试阶段,可能会发生变化。

启用或禁用数学缩放点积注意力中的fp16/bf16缩减计算。


python 复制代码
torch.backends.cuda.cudnn_sdp_enabled()

警告:此标志为测试版,可能会发生变化。

返回是否启用了 cuDNN 缩放点积注意力功能。


python 复制代码
torch.backends.cuda.enable_cudnn_sdp(enabled)

警告:此标志为测试版功能,可能会发生变化。

启用或禁用 cuDNN 的缩放点积注意力机制。


python 复制代码
torch.backends.cuda.is_flash_attention_available()

检查 PyTorch 是否构建了 FlashAttention 以支持 scaled_dot_product_attention

返回值:如果 FlashAttention 已构建且可用,则返回 True;否则返回 False

返回类型:bool

注意:此功能依赖于支持 CUDA 的 PyTorch 版本。在非 CUDA 环境下将始终返回 False


python 复制代码
torch.backends.cuda.can_use_flash_attention(params, debug=False)

检查是否可以在 scaled_dot_product_attention 中使用 FlashAttention。


参数

  • params (_SDPAParams) -- 包含查询(query)、键(key)、值(value)张量的 SDPAParams 实例,可选注意力掩码(attention mask)、dropout率,以及指示注意力是否为因果(causal)的标志。
  • debug (bool) -- 是否通过logging.warn输出无法运行FlashAttention的调试信息。默认为False。

返回值:如果给定的参数可以使用FlashAttention则返回True;否则返回False。

返回类型 : bool


注意:此函数依赖于启用CUDA的PyTorch版本。在非CUDA环境下将返回False。


python 复制代码
torch.backends.cuda.can_use_efficient_attention(params, debug=False)

检查是否可以在scaled_dot_product_attention中使用efficient_attention

参数

  • params (_SDPAParams) - 包含query、key、value张量、可选attention掩码、dropout率和表示注意力是否因果的标志的SDPAParams实例。
  • debug (bool) - 是否通过logging.warn记录无法运行efficient_attention的原因。默认为False。

返回

如果给定参数可以使用efficient_attention则返回True;否则返回False。

返回类型:bool

注意:此功能依赖于支持CUDA的PyTorch版本。在非CUDA环境下将返回False。


python 复制代码
torch.backends.cuda.can_use_cudnn_attention(params, debug=False)

检查是否可以在 scaled_dot_product_attention 中使用 cudnn_attention

参数

  • params (_SDPAParams) - 一个包含查询(query)、键(key)、值(value)张量的 SDPAParams 实例,可选注意力掩码(attention mask)、dropout率,以及一个指示注意力是否为因果(causal)的标志。
  • debug (bool) - 是否通过 logging.warn 记录无法运行 cuDNN 注意力的原因信息。默认为 False。

返回

如果可以使用 cuDNN 处理给定参数则返回 True;否则返回 False。

返回类型:bool

注意:此函数依赖于支持 CUDA 的 PyTorch 版本。在非 CUDA 环境下将返回 False。


python 复制代码
torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True, enable_cudnn=True)

警告:此标志为测试版,可能会发生变化。

该上下文管理器可用于临时启用或禁用三种缩放点积注意力后端中的任意一种。

退出上下文管理器时,将恢复标志的先前状态。


torch.backends.cudnn


python 复制代码
torch.backends.cudnn.version()

返回 cuDNN 的版本号。


python 复制代码
torch.backends.cudnn.is_available()

返回一个布尔值,表示当前是否可用 CUDNN。


python 复制代码
torch.backends.cudnn.enabled 

一个控制是否启用 cuDNN 的 bool 值。


python 复制代码
torch.backends.cudnn.allow_tf32 

一个控制是否在Ampere或更新款GPU的cuDNN卷积中使用TensorFloat-32张量核心的bool值。详见Ampere(及后续)设备上的TensorFloat-32 (TF32)

(说明:严格遵循核心翻译原则,保留代码块bool和链接格式,技术术语"TensorFloat-32/TF32/Ampere/cuDNN"不翻译,被动语态转为主动语态"是否使用",并合并了原文因换行中断的连贯内容)


python 复制代码
torch.backends.cudnn.deterministic 

一个bool值,若设为True,将强制cuDNN仅使用确定性卷积算法。

另请参阅 torch.are_deterministic_algorithms_enabled()torch.use_deterministic_algorithms()


python 复制代码
torch.backends.cudnn.benchmark 

一个bool值,如果为True,将让cuDNN对多种卷积算法进行基准测试并选择最快的算法。


python 复制代码
torch.backends.cudnn.benchmark_limit 

一个 int 类型参数,用于指定当 torch.backends.cudnn.benchmark 为 True 时,尝试的 cuDNN 卷积算法的最大数量。将 benchmark_limit 设为零会尝试所有可用算法。请注意此设置仅影响通过 cuDNN v8 API 调度的卷积操作。


torch.backends.cusparselt


python 复制代码
torch.backends.cusparselt.version()

返回 cuSPARSELt 的版本号

返回类型:Optional[int]


python 复制代码
torch.backends.cusparselt.is_available()

返回一个布尔值,表示当前是否可用 cuSPARSELt。

返回类型:bool


torch.backends.mha


python 复制代码
torch.backends.mha.get_fastpath_enabled()

返回是否启用了TransformerEncoder和MultiHeadAttention的快速路径,如果处于jit脚本模式则返回True

注意:即使get_fastpath_enabled返回True,也可能不会执行快速路径,除非满足所有输入条件。

返回类型:bool


python 复制代码
torch.backends.mha.set_fastpath_enabled(value)

设置是否启用快速路径


torch.backends.mps


python 复制代码
torch.backends.mps.is_available()

返回一个布尔值,表示当前是否可用MPS。

返回类型:bool


python 复制代码
torch.backends.mps.is_built()

返回当前 PyTorch 是否支持 MPS 构建。

需要注意的是,这并不代表 MPS 一定可用;仅表示如果该 PyTorch 二进制文件运行在具有正常 MPS 驱动和设备的机器上时,我们将能够使用该功能。

返回类型:bool


torch.backends.mkl


python 复制代码
torch.backends.mkl.is_available()

返回 PyTorch 是否启用了 MKL 支持。


python 复制代码
class torch.backends.mkl.verbose(enable)

按需启用的oneMKL详细输出功能。

为便于调试性能问题,oneMKL可以输出包含内核执行信息(如执行时长)的详细消息。该功能可通过名为MKL_VERBOSE的环境变量触发。但此方法会输出所有步骤的消息,导致产生大量冗余信息。

实际上在调查性能问题时,通常只需获取单次迭代的详细消息即可。这种按需启用的功能可以精确控制详细消息的输出范围。在以下示例中,详细消息将仅针对第二次推理过程进行输出。


python 复制代码
import torch
model(data) with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
    model(data)

参数

  • level -- 详细级别
  • VERBOSE_OFF: 禁用详细输出
  • VERBOSE_ON: 启用详细输出

torch.backends.mkldnn


python 复制代码
torch.backends.mkldnn.is_available()

返回 PyTorch 是否构建了 MKL-DNN 支持。


python 复制代码
class torch.backends.mkldnn.verbose(level)

按需启用 oneDNN(原 MKL-DNN)的详细日志功能

为便于调试性能问题,oneDNN 可在执行内核时输出包含内核大小、输入数据大小和执行时长等信息的详细日志。该功能可通过环境变量 DNNL_VERBOSE 触发,但此方法会在所有步骤中输出日志,产生大量冗余信息。实际上在调查性能问题时,通常只需获取单次迭代的日志即可。

这项按需日志功能实现了对日志输出范围的控制。在以下示例中,系统将仅针对第二次推理过程输出详细日志。


python 复制代码
import torch
model(data) with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
    model(data)

参数

  • level -- 详细级别
  • VERBOSE_OFF: 禁用详细输出
  • VERBOSE_ON: 启用详细输出
  • VERBOSE_ON_CREATION: 启用详细输出,包括 oneDNN 内核创建

torch.backends.nnpack


python 复制代码
torch.backends.nnpack.is_available()

返回 PyTorch 是否启用了 NNPACK 支持。


python 复制代码
torch.backends.nnpack.flags(enabled=False)

用于全局设置是否启用nnpack的上下文管理器


python 复制代码
torch.backends.nnpack.set_flags(_enabled)

全局设置是否启用nnpack


torch.backends.openmp


python 复制代码
torch.backends.openmp.is_available()

返回 PyTorch 是否启用了 OpenMP 支持。


torch.backends.opt_einsum


python 复制代码
torch.backends.opt_einsum.is_available()

返回一个布尔值,表示当前是否可用opt_einsum。

必须安装opt-einsum才能让torch自动优化einsum运算。要使opt-einsum可用,你可以通过以下方式安装:

  • 与torch一起安装:pip install torch[opt-einsum]
  • 单独安装:pip install opt-einsum

如果该包已安装,torch会自动导入并相应使用它。此函数用于检查opt-einsum是否已安装且被torch正确导入。

返回类型:bool


python 复制代码
torch.backends.opt_einsum.get_opt_einsum()

如果当前可用 opt_einsum 包则返回该包,否则返回 None。

返回类型:Any


python 复制代码
torch.backends.opt_einsum.enabled 

一个控制是否启用 opt_einsum 的 bool 值(默认为 True)。如果启用,torch.einsum 将使用 opt_einsum(https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)来计算最优的收缩路径以获得更快的性能(前提是 opt_einsum 可用)。

如果 opt_einsum 不可用,torch.einsum 将回退到默认的从左到右收缩路径。


python 复制代码
torch.backends.opt_einsum.strategy 

一个str字符串,用于指定当torch.backends.opt_einsum.enabledTrue时要尝试的优化策略。默认情况下,torch.einsum会尝试"auto"策略,但也支持"greedy"和"optimal"策略。需要注意的是,"optimal"策略会随着输入数量的增加呈阶乘级复杂度,因为它会尝试所有可能的计算路径。更多细节请参阅opt_einsum的文档(https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。


torch.backends.xeon



torch.export


警告:此功能为正在积极开发中的原型,未来将会有破坏性变更。


概述

torch.export.export() 接收一个 torch.nn.Module 并生成一个跟踪图,该图以提前编译(AOT)的方式仅表示函数的张量计算过程。生成的跟踪图随后可以用不同的输出执行或进行序列化。


python 复制代码
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)
print(exported_program)

python 复制代码
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
            # code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)

            # code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)

            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos)
            return (add,)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,   arg=TensorArgument(name='x'),   target=None,   persistent=None
                ), InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,   arg=TensorArgument(name='y'),   target=None,   persistent=None
                )
            ], output_specs=[
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,   arg=TensorArgument(name='add'),   target=None
                )
            ]
        )
    Range constraints: {}

torch.export 生成一个符合以下不变量的简洁中间表示(IR)。关于该 IR 的更多规范可查阅此处

  • 正确性:保证是原始程序的准确表示,并保持原始程序的调用约定。
  • 规范化:图中不包含 Python 语义。原始程序中的子模块会被内联,形成一个完全扁平化的计算图。
  • 图属性:该图是纯函数式的,意味着不包含具有副作用(如突变或别名)的操作。它不会改变任何中间值、参数或缓冲区。
  • 元数据:图中包含在跟踪期间捕获的元数据,例如来自用户代码的堆栈跟踪。

在底层,torch.export 利用了以下最新技术:

  • TorchDynamo (torch._dynamo) 是一个内部 API,它使用 CPython 的 Frame Evaluation API 功能来安全地跟踪 PyTorch 图。这极大地改善了图捕获体验,减少了完全跟踪 PyTorch 代码所需的改写工作。
  • AOT Autograd 提供了一个功能化的 PyTorch 图,并确保图被分解/降级为 ATen 操作符集。
  • Torch FX (torch.fx) 是图的底层表示,支持基于 Python 的灵活转换。

现有框架

torch.compile() 同样使用了与 torch.export 相同的 PT2 技术栈,但存在以下差异:

  • JIT 与 AOTtorch.compile() 是一个即时(JIT)编译器,其设计目的并非用于生成可部署的编译产物。
  • 部分图与完整图捕获 :当 torch.compile() 遇到模型不可追踪部分时,会触发"图中断"并回退到 Python 即时运行模式。相比之下,torch.export 旨在获取 PyTorch 模型的完整计算图表示,因此遇到不可追踪内容时会直接报错。由于 torch.export 生成的完整图与 Python 特性及运行时完全解耦,该计算图可被保存、加载并在不同环境和语言中运行。
  • 可用性权衡torch.compile() 在遇到不可追踪内容时可回退至 Python 运行时,因此灵活性更高。而 torch.export 需要用户提供更多信息或重写代码以确保可追踪性。

torch.fx.symbolic_trace() 相比,torch.export 通过 TorchDynamo 在 Python 字节码层面进行追踪,使其能够处理不受 Python 运算符重载限制的任意 Python 结构。此外,torch.export 会精细追踪张量元数据,因此基于张量形状等条件的操作不会导致追踪失败。总体而言,torch.export 适用于更多用户程序,并能生成更底层的计算图(基于 torch.ops.aten 算子级别)。用户仍可将 torch.fx.symbolic_trace() 作为 torch.export 的预处理步骤。

torch.jit.script() 相比,torch.export 不捕获 Python 控制流或数据结构,但支持比 TorchScript 更多的 Python 语言特性(因其对 Python 字节码的覆盖更全面)。生成的计算图更简洁,仅包含直线控制流(显式控制流算子除外)。

torch.jit.trace() 相比,torch.export 具备可靠性:它能追踪对张量尺寸进行整数运算的代码,并记录所有必要的边界条件,以证明特定追踪结果对其他输入的有效性。


导出 PyTorch 模型


示例

主入口是通过 torch.export.export(),它接收一个可调用对象(torch.nn.Module、函数或方法)和示例输入,并将计算图捕获到 torch.export.ExportedProgram 中。示例如下:

python 复制代码
import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)

python 复制代码
ExportedProgram:
    class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
            # code: a = self.conv(x)
            conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])

            # code: a.add_(constant)
            add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)

            # code: return self.maxpool(self.relu(a))
            relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
            max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
            return (max_pool2d,)

Graph signature:
    ExportGraphSignature(
        input_specs=[
            InputSpec(
                kind=<InputKind.PARAMETER: 2>,
                arg=TensorArgument(name='p_conv_weight'),
                target='conv.weight',
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.PARAMETER: 2>,
                arg=TensorArgument(name='p_conv_bias'),
                target='conv.bias',
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='x'),
                target=None,
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='constant'),
                target=None,
                persistent=None
            )
        ],
        output_specs=[
            OutputSpec(
                kind=<OutputKind.USER_OUTPUT: 1>,
                arg=TensorArgument(name='max_pool2d'),
                target=None
            )
        ]
    )
Range constraints: {}

在检查ExportedProgram时,我们可以注意到以下几点:

  • torch.fx.Graph包含了原始程序的计算图,同时保留了原始代码记录以便于调试。
  • 图中仅包含此处列出的torch.ops.aten运算符和自定义运算符,且完全可运行,不包含任何原地操作符(如torch.add_)。
  • 参数(如卷积层的权重和偏置)被提升为图的输入节点,因此图中不再存在torch.fx.symbolic_trace()结果中曾出现的get_attr节点。
  • torch.export.ExportGraphSignature对输入输出签名进行建模,并明确指定哪些输入是参数。
  • 图中每个节点输出的张量形状和数据类型都有标注。例如,convolution节点将产生一个数据类型为torch.float32、形状为(1, 16, 256, 256)的张量。

非严格导出模式

在 PyTorch 2.3 版本中,我们引入了一种新的追踪模式------非严格模式。该功能目前仍处于完善阶段,如果您遇到任何问题,请在 GitHub 上提交 issue 并标记 "oncall: export" 标签。

非严格模式下,我们通过 Python 解释器执行程序追踪。您的代码会像在即时执行模式(eager mode)中一样运行,唯一的区别是所有 Tensor 对象都会被替换为 ProxyTensor,这些代理张量会将其所有操作记录到计算图中。

当前默认使用的是严格模式,该模式下我们首先使用 TorchDynamo(一个字节码分析引擎)进行程序追踪。TorchDynamo 实际上不会执行您的 Python 代码,而是对其进行符号化分析,并根据分析结果构建计算图。这种分析方式使得 torch.export 能够提供更强的安全性保证,但并非所有 Python 代码都受支持。

当您遇到 TorchDynamo 不支持的特性且难以解决时,如果确定相关 Python 代码并非计算所必需,就可以考虑使用非严格模式。例如:

python 复制代码
import contextlib
import torch

class ContextManager():
    def __init__(self):
        self.count = 0
    def __enter__(self):
        self.count += 1
    def __exit__(self, exc_type, exc_value, traceback):
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x):
        with ContextManager():
            return x.sin() + x.cos()

export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

在这个示例中,首次调用使用非严格模式(通过strict=False标志)能成功追踪,而第二次采用默认严格模式的调用则失败了,因为TorchDynamo无法支持上下文管理器。一种解决方案是重写代码(参见torch.export的限制),但考虑到上下文管理器不会影响模型中的张量计算,我们可以直接采用非严格模式的结果。


训练与推理导出功能

在 PyTorch 2.5 中,我们推出了名为 export_for_training() 的新 API。该功能目前仍在强化阶段,如果您遇到任何问题,请在 Github 上提交问题并标记 "oncall: export" 标签。

此 API 会生成包含所有 ATen 算子(包括功能性和非功能性)的最通用中间表示(IR),可用于 PyTorch Autograd 的即时训练模式。该 API 主要面向即时训练场景,例如 PT2 量化,并将很快成为 torch.export.export 的默认 IR。要深入了解这一变更背后的动机,请参阅 https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206

当此 API 与 run_decompositions() 结合使用时,您应该能够获得具有任何所需分解行为的推理 IR。

以下是一些示例:

python 复制代码
class ConvBatchnorm(torch.nn.Module):
    def __init__(self) -None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

mod = ConvBatchnorm()
inp = torch.randn(1, 1, 3, 3)

ep_for_training = torch.export.export_for_training(mod, (inp,))
print(ep_for_training)

python 复制代码
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
            add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
            batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
            return (batch_norm,)

从上述输出可以看出,export_for_training()生成的ExportedProgram与export()几乎相同,除了图中的运算符不同。可以看到我们以最通用的形式捕获了batch_norm操作。该操作是非功能性的,在运行推理时会被降级为不同的操作。

你还可以通过run_decompositions()从这个中间表示(IR)转换到推理IR,并进行任意自定义。


python 复制代码
# Lower to core aten inference IR, but keep conv2d
decomp_table = torch.export.default_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]
ep_for_inference = ep_for_training.run_decompositions(decomp_table)

print(ep_for_inference)

python 复制代码
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
            add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
            _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
            getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
            getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
            getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
            return (getitem_3, getitem_4, add, getitem)

可以看到,我们在保持 IR 中 conv2d 算子不变的同时,分解了其余部分。现在该 IR 已成为一个功能型中间表示,仅保留 conv2d 以外的核心 aten 算子。

通过直接注册您选择的分解行为,您可以实现更深入的定制化。您还可以通过直接注册自定义分解行为来获得更灵活的定制能力。


python 复制代码
# Lower to core aten inference IR, but customize conv2d
decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function
ep_for_inference = ep_for_training.run_decompositions(decomp_table)

print(ep_for_inference)

python 复制代码
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
            mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
            add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
            _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
            getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
            getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
            getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];
            return (getitem_3, getitem_4, add, getitem)

表达动态性

默认情况下,torch.export 会假设所有输入形状都是静态的 来追踪程序,并将导出的程序特化到这些维度。然而,某些维度(例如批次维度)可以是动态的,每次运行都可能变化。必须通过使用 torch.export.Dim() API 创建这些维度,并通过 dynamic_shapes 参数将它们传递给 torch.export.export() 来指定这些维度。示例如下:

python 复制代码
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)

python 复制代码
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):

         # code: out1 = self.branch1(x1)
        linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
        relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)

         # code: out2 = self.branch2(x2)
        linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
        relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)

         # code: return (out1 + self.buffer, out2)
        add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
        return (add, relu_1)

Range constraints: {s0: VR[0, int_oo]}

需要注意的几点补充事项:

  • 通过 torch.export.Dim() API 和 dynamic_shapes 参数,我们指定了每个输入的第一个维度为动态维度。观察输入 x1x2,它们的符号形状分别为 (s0, 64) 和 (s0, 128),而非我们传入的示例输入中 (32, 64) 和 (32, 128) 形状的张量。

s0 是一个符号,表示该维度可以接受一定范围内的数值。

  • exported_program.range_constraints 描述了图中每个符号的取值范围。在本例中,我们看到 s0 的范围是 [0, int_oo]。由于某些技术原因(此处难以详述),系统假定这些取值不为 0 或 1。这并非程序错误,也不意味着导出的程序一定无法处理维度为 0 或 1 的情况。关于此话题的深入讨论,请参阅文档 0/1 特化问题

我们还可以指定输入形状之间更具表现力的关系,例如:两个形状可能相差 1,某个形状可能是另一个的两倍,或者某个形状是偶数。示例如下:

python 复制代码
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1

exported_program = torch.export.export(
    M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}), )
print(exported_program)

python 复制代码
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
        # code: return x + y[1:]
        slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
        add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
        return (add,)

Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}

需要注意以下几点:

  • 当为第一个输入指定 {0: dimx} 时,可以看到第一个输入的结果形状变为动态的 [s0]。接着为第二个输入指定 {0: dimy} 时,第二个输入的结果形状也变为动态。但由于我们定义了 dimy = dimx + 1y 的形状并未引入新符号,而是沿用 x 中的符号 s0 来表示。可以看到 dimy = dimx + 1 的关系通过 s0 + 1 体现。
  • 观察范围约束条件,s0 的初始范围是 [3, 6],而 s0 + 1 的解算范围是 [4, 7]。

序列化

要保存 ExportedProgram,用户可以使用 torch.export.save()torch.export.load() API。通常建议使用 .pt2 文件扩展名来保存 ExportedProgram


示例:

python 复制代码
import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), torch.randn(5))

torch.export.save(exported_program, 'exported_program.pt2')

saved_exported_program = torch.export.load('exported_program.pt2')

特化机制

理解torch.export行为的一个核心概念在于区分静态值动态值

动态值 指每次运行可能发生变化的值。这类值的行为类似于Python函数的常规参数------你可以为同一参数传递不同值,并期望函数能正确执行。张量的数据就被视为动态值。

静态值则是在导出时固定,且在导出程序多次执行间保持不变的值。当跟踪过程中遇到静态值时,导出器会将其视为常量并硬编码到计算图中。

当执行某个操作(例如x + y)且所有输入均为静态值时,该操作的输出会直接硬编码到计算图中,该操作将不会显式出现(即被常量折叠优化)。

当一个值被硬编码到计算图中时,我们称该计算图已针对该值进行了特化

以下类型的值属于静态值:

输入张量形状

默认情况下,torch.export 会根据输入张量的具体形状进行程序追踪,除非通过 dynamic_shapes 参数将某个维度指定为动态。这意味着如果存在依赖形状的控制流,torch.export 将根据给定示例输入所触发的分支进行特化处理。例如:

python 复制代码
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)

python 复制代码
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[10, 2]"):
        # code: return x + 1
        add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
        return (add,)

条件判断 (x.shape[0] > 5) 不会出现在 ExportedProgram 中,因为示例输入的静态形状是 (10, 2)。由于 torch.export 会针对输入的静态形状进行特化处理,else分支 (x - 1) 将永远不会被执行。若要在追踪图中保留基于张量形状的动态分支行为,需要使用 torch.export.Dim() 来指定输入张量的维度 (x.shape[0]) 为动态维度,同时需要重写源代码

请注意:作为模块状态一部分的张量(如参数和缓冲区)始终具有静态形状。


Python 基本类型

torch.export 同样支持对 Python 基本类型的特化处理,例如 intfloatboolstr。不过它们也有对应的动态变体,如 SymIntSymFloatSymBool


例如:

python 复制代码
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, const: int, times: int):
        for i in range(times):
            x = x + const
        return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)

python 复制代码
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[2, 2]", const, times):
            # code: x = x + const
            add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
            add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
            return (add_2,)

由于整数是特化的,torch.ops.aten.add.Tensor 操作都会使用硬编码的常量 1 进行计算,而非变量 const。如果用户在运行时传入与导出时不同的 const 值(例如 2 而非 1),就会导致错误。

此外,for 循环中使用的 times 迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor 调用被"内联"到计算图中,而输入参数 times 实际上从未被使用。


Python 容器

Python 容器(ListDictNamedTuple 等)被认为具有静态结构。


torch.export 的局限性


图中断问题

由于torch.export是一个从PyTorch程序中捕获计算图的一次性过程,它最终可能会遇到程序无法追踪的部分,因为几乎不可能支持追踪所有PyTorch和Python特性。在torch.compile的情况下,遇到不支持的操作会导致"图中断",该操作将通过默认的Python解释器执行。相比之下,torch.export会要求用户提供额外信息或重写部分代码使其可追踪。由于追踪基于TorchDynamo(在Python字节码级别进行评估),与之前的追踪框架相比,所需的代码重写将显著减少。

当遇到图中断时,ExportDB是了解支持/不支持程序类型以及如何重写程序使其可追踪的绝佳资源。

解决图中断问题的一个选项是使用非严格导出模式


数据/形状相关的控制流

当形状未被特化时,在数据依赖的控制流(如if x.shape[0] > 2)中也可能遇到图中断问题。这是因为追踪编译器无法在不生成组合爆炸路径数量的代码的情况下处理这种情况。此时,用户需要使用特定的控制流运算符重写代码。目前,我们支持使用torch.cond来表达类似if-else的控制流(更多功能即将推出!)。


算子缺少 Fake/Meta/Abstract 内核实现

在进行追踪时,所有算子都必须具备 FakeTensor 内核(也称为元内核或抽象实现)。该内核用于推导该算子的输入/输出形状。

更多详情请参阅 torch.library.register_fake()

若您的模型使用了尚未实现 FakeTensor 内核的 ATen 算子,请提交问题报告。


扩展阅读

面向导出用户的附加链接

PyTorch开发者深度指南


API 参考


python 复制代码
torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())

export() 接收任意 nn.Module 及示例输入,并以预先编译(AOT)的方式生成一个仅表示函数张量计算过程的追踪图。该追踪图具有以下特性:(1) 生成符合功能化 ATen 算子集的标准化算子(以及用户指定的任何自定义算子);(2) 消除了所有 Python 控制流和数据结构(特定例外情况除外);(3) 记录了所需的形状约束集合,以证明这种标准化和控制流消除对未来输入是可靠的。

可靠性保证

在追踪过程中,export() 会记录用户程序及底层 PyTorch 算子内核对形状相关的假设。只有当这些假设成立时,输出的 ExportedProgram 才被视为有效。

追踪过程会对输入张量的形状(而非数值)作出假设。这些假设必须在图捕获阶段完成验证,export() 才能成功执行。具体而言:

  • 对输入张量静态形状的假设会自动验证,无需额外操作。
  • 对输入张量动态形状的假设需要通过 Dim() API 显式声明动态维度,并通过 dynamic_shapes 参数将其与示例输入关联。

若任何假设无法验证,将触发致命错误。此时,错误信息会包含验证假设所需的规范修改建议。例如,export() 可能针对输入 x 关联形状中出现的动态维度 dim0_x(假设先前定义为 Dim("dim0_x"))给出如下修正建议:

python 复制代码
dim = Dim("dim0_x", max=5)

这个示例意味着生成的代码要求输入x的第0维必须小于或等于5才有效。您可以检查针对动态维度定义的建议修正方案,然后原封不动地复制到代码中,而无需修改export()调用中的dynamic_shapes参数。

参数说明

  • mod (Module) -- 我们将追踪该模块的forward方法。

  • args (tuple[Any, ...]) -- 示例位置输入。

  • kwargs (Optional[dict[str, Any]]) -- 可选的示例关键字输入。

  • dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any], list[Any]]]) -- 可选参数,其类型应为以下之一:

    1. f的参数名到其动态形状规范的字典;
    2. 按原始顺序为每个输入指定动态形状规范的元组。

    若要对关键字参数指定动态性,需按照原始函数签名中定义的顺序传递它们。

    张量参数的动态形状可通过以下方式指定:

    (1) 从动态维度索引到Dim()类型的字典(静态维度索引无需包含在此字典中,若包含则应映射为None);

    (2) Dim()类型或None组成的元组/列表,其中Dim()类型对应动态维度,静态维度用None表示。

    对于字典或张量元组/列表类型的参数,可通过递归使用包含规范的映射或序列来指定。

  • strict (bool) -- 启用时(默认),导出函数将通过TorchDynamo追踪程序以确保结果图的正确性。否则,导出程序不会验证图中隐含的假设,可能导致原始模型与导出模型的行为差异。这在用户需要绕过追踪器错误或逐步启用模型安全性时很有用。注意这不会导致最终IR规范不同,无论此处传递何值,模型都将以相同方式序列化。

警告:此选项为实验性功能,使用时需自行承担风险。

返回值:返回包含被追踪可调用对象的ExportedProgram

返回类型:ExportedProgram

可接受的输入/输出类型

argskwargs的输入及输出可接受类型包括:

  • 基本类型:torch.Tensorintfloatboolstr
  • 数据类(需先通过调用register_dataclass()注册)
  • 包含上述所有类型的(嵌套)数据结构:dictlisttuplenamedtupleOrderedDict

python 复制代码
torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)

警告: 当前功能处于积极开发阶段,保存的文件可能无法在 PyTorch 新版本中使用。

ExportedProgram 保存到类文件对象中,后续可通过 Python API torch.export.load 加载。

参数说明

  • ep (ExportedProgram) -- 待保存的导出程序
  • f (str | os.PathLike[str] | *IO[bytes ]) -- 需实现 write 和 flush 方法的类文件对象,或包含文件名的字符串
  • extra_files (Optional[Dict[str, Any]]) -- 文件名到内容的映射,这些内容将作为文件的一部分存储
  • opset_version (Optional[Dict[str, int ]]) -- 操作集名称到其版本的映射
  • pickle_protocol ( int ) -- 可指定以覆盖默认协议

示例:

python 复制代码
import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)

python 复制代码
torch.export.load(f, *, extra_files=None, expected_opset_version=None)

警告:当前功能处于积极开发阶段,保存的文件可能无法在较新版本的PyTorch中使用。

加载先前通过 torch.export.save 保存的 ExportedProgram

参数

  • f (str | os.PathLike[str] | *IO[bytes ]) -- 文件类对象(需实现write和flush方法)或包含文件名的字符串。
  • extra_files (Optional[Dict[str, Any]]) -- 此映射中提供的额外文件名将被加载,其内容会存储到给定的映射中。
  • expected_opset_version (Optional[Dict[str, int ]]) -- 操作集名称到预期版本号的映射

返回值 :一个 ExportedProgram 对象

返回类型:ExportedProgram

示例

python 复制代码
import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep(torch.randn(5)))

python 复制代码
torch.export.register_dataclass(cls, *, serialized_type_name=None)

将数据类注册为 torch.export.export() 的有效输入/输出类型。

参数

  • cls (type[Any]) - 要注册的数据类类型
  • serialized_type_name (Optional[str]) - 数据类的序列化名称。这是* this (当需要序列化包含数据类的 pytree TreeSpec 时为必填项)

示例:

python 复制代码
import torch

from dataclasses import dataclass

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

@dataclass 
class OutputDataClass:
    res: torch.Tensor

python 复制代码
torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

class Mod(torch.nn.Module):
    def forward(self, x: InputDataClass) -OutputDataClass:
        res = x.feature + x.bias
        return OutputDataClass(res=res)

ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)

python 复制代码
torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)

Dim() 构造了一个类似于带范围命名符号整数的类型。

该类型可用于描述动态张量维度的多种可能取值。

注意:同一张量的不同动态维度,或不同张量的动态维度,都可以用同一类型来描述。

参数说明

  • name (str) - 用于调试的可读名称
  • min (Optional[int]) - 符号的最小可能值(包含)
  • max (Optional[int]) - 符号的最大可能值(包含)

返回值

返回一个可用于张量动态形状规范的类型。


python 复制代码
torch.export.exported_program.default_decompositions()

这是默认的分解表,包含了所有 ATEN 算子向核心 aten 算子集的分解映射。请将此 API 与 run_decompositions() 配合使用。

返回类型 : CustomDecompTable


python 复制代码
torch.export.dims(*names, min=None, max=None)

用于创建多个 Dim() 类型的工具函数。

返回值:返回由 Dim() 类型组成的元组。

返回类型:tuple[torch.export.dynamic_shapes._Dim, ...]


python 复制代码
class torch.export.dynamic_shapes.ShapesCollection

动态形状构建器。

用于为输入张量分配动态形状规格。

args()是嵌套输入结构时特别有用,相比在dynamic_shapes()规范中复制args()的结构,直接索引输入张量会更方便。


示例:

python 复制代码
args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})

dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}
# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)

python 复制代码
dynamic_shapes(m, args, kwargs=None)

生成与 args()kwargs() 对应的 dynamic_shapes() pytree 结构。


python 复制代码
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)

使用 dynamic_shapes() 导出时,如果规格与模型追踪推断出的约束条件不匹配,可能会因 ConstraintViolation 错误导致导出失败。错误信息通常会提供修正建议------即对 dynamic_shapes() 进行哪些修改才能成功导出。

ConstraintViolation 错误示例信息:

python 复制代码
Suggested fixes:

    dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
    dim = 4  # this specializes to a constant
    dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation

这是一个辅助函数,它接收 ConstraintViolation 错误信息和原始的 dynamic_shapes() 规格参数,返回一个整合了建议修复方案的新 dynamic_shapes() 规格。

使用示例:

python 复制代码
try:
    ep = export(mod, args, dynamic_shapes=dynamic_shapes)
except torch._dynamo.exc.UserError as exc:
    new_shapes = refine_dynamic_shapes_from_suggested_fixes(
        exc.msg, dynamic_shapes
    )
    ep = export(mod, args, dynamic_shapes=new_shapes)

返回类型:Union[dict[str, Any], tuple[Any], list[Any]]


python 复制代码
torch.export.Constraint 

Union[_Constraint, _DerivedConstraint, _RelaxedConstraint]的别名


python 复制代码
class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)

export() 导出的程序包。它包含:

  • 一个表示张量计算的 torch.fx.Graph
  • 包含所有提升参数和缓冲区张量值的 state_dict
  • 各种元数据

可以像调用原始可调用对象那样调用 ExportedProgram,其调用约定与 export() 追踪的版本相同。

要对计算图进行转换时,可通过 .module 属性访问 torch.fx.GraphModule,然后使用 FX 变换重写计算图。完成后,只需再次调用 export() 即可构建正确的 ExportedProgram。


python 复制代码
module()

返回一个自包含的 GraphModule,其中内联了所有参数/缓冲区。

返回类型:Module


python 复制代码
buffers()

返回一个遍历原始模块缓冲区的迭代器。

警告:此 API 是实验性的,且向后兼容。

返回类型:Iterator[Tensor]


python 复制代码
named_buffers()

返回一个遍历原始模块缓冲区的迭代器,同时生成缓冲区的名称和缓冲区本身。

警告:此 API 为实验性质,且向后兼容。

返回类型:Iterator[tuple[str, torch.Tensor]]


python 复制代码
parameters()

返回一个遍历原始模块参数的迭代器。

警告:此 API 为实验性质,且向后兼容。

返回类型:迭代器 [Parameter]


python 复制代码
named_parameters()

返回一个遍历原始模块参数的迭代器,同时生成参数名称和参数本身。

警告:此 API 为实验性质,且向后兼容。

返回类型:Iterator[tuple[str, torch.nn.parameter.Parameter]]


python 复制代码
run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)

对导出的程序运行一系列分解操作,并返回一个新的导出程序。默认情况下,我们会运行 Core ATen 分解来获取 Core ATen Operator Set 中的运算符。

目前暂不支持分解联合图。

参数

  • decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) - 可选参数,用于指定 Aten 运算符的分解行为:(1) 如果为 None,则分解至核心 aten 分解;(2) 如果为空字典,则不分解任何运算符

返回类型ExportedProgram

使用示例

如果不想分解任何操作


python 复制代码
ep = torch.export.export(model, ...)
ep = ep.run_decompositions(decomp_table={}) 

如果你想获取核心 aten 操作符集合(排除特定操作符),可以执行以下操作:


python 复制代码
ep = torch.export.export(model, ...)
decomp_table = torch.export.default_decompositions()
decomp_table[your_op] = your_custom_decomp
ep = ep.run_decompositions(decomp_table=decomp_table) 

python 复制代码
class torch.export.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)

python 复制代码
class torch.export.ExportGraphSignature(input_specs, output_specs)

ExportGraphSignature 定义了导出图的输入/输出签名,这是一个具有更强不变性保证的 fx.Graph。

导出图是纯函数式的,不会通过 getattr 节点访问图内的"状态"(如参数或缓冲区)。相反,export() 确保将参数、缓冲区和常量张量都作为输入从图中提取出来。同样,对缓冲区的任何修改也不会包含在图中,而是将修改后的缓冲区值建模为导出图的额外输出。

所有输入和输出的顺序如下:

python 复制代码
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果导出以下模块:

python 复制代码
class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

生成的图表如下:

python 复制代码
graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

生成的 ExportGraphSignature 将是:

json 复制代码
ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)

python 复制代码
replace_all_uses(old, new)

将签名中所有旧名称的使用替换为新名称。


python 复制代码
get_replace_hook(replace_inputs=False

python 复制代码
class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)

python 复制代码
class torch.export.unflatten.FlatArgsAdapter

根据输入参数 input_spec 进行调整,使其与 target_spec 对齐。


python 复制代码
abstract adapt(target_spec, input_spec, input_args, metadata=None)

注意 :此适配器可能会修改给定的 input_args_with_path

返回类型:list[Any]


python 复制代码
class torch.export.unflatten.InterpreterModule(graph, ty=None)

一个使用 torch.fx.Interpreter 而非 GraphModule 常规代码生成来执行的模块。这种方式能提供更清晰的堆栈跟踪信息,使执行调试更加便捷。


python 复制代码
class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)

一个包含一系列对应模块调用序列的InterpreterModule的模块。每次对该模块的调用都会分派给下一个InterpreterModule,并在最后一个之后循环回第一个。


python 复制代码
torch.export.unflatten.unflatten(module, flat_args_adapter=None)

展开一个ExportedProgram,生成与原始eager模块具有相同模块层级的模块。当你尝试将torch.export与其他期望模块层级(而非torch.export通常生成的扁平图)的系统结合使用时,这会很有用。

注意:展开后模块的args/kwargs不一定与eager模块匹配,因此直接进行模块替换(例如self.submod = new_mod)可能无效。如需替换模块,必须设置torch.export.export()preserve_module_call_signature参数。

参数

  • module (ExportedProgram) -- 待展开的ExportedProgram
  • flat_args_adapter (Optional[FlatArgsAdapter]) -- 当输入TreeSpec与导出模块不匹配时,用于适配扁平参数

返回值

返回UnflattenedModule实例,其模块层级与导出前的原始eager模块相同。

返回类型:UnflattenedModule


python 复制代码
torch.export.passes.move_to_device_pass(ep, location)

将导出的程序移动到指定设备。

参数

  • ep ([ExportedProgram](https://pytorch.org/docs/stable/data.html#torch.export.ExportedProgram "torch.export.ExportedProgram")) -- 需要移动的导出程序。
  • location (Union[torch.device ,* str, Dict[str,* str]]) -- 目标设备。

如果是字符串,会被解析为设备名称。

如果是字典,会被解析为从现有设备到目标设备的映射关系。

返回

移动后的导出程序。

返回类型:ExportedProgram


2025-08-20(三)

相关推荐
小和尚敲代码9 小时前
八字变十字国学api根据日期得到十字加入刻柱干支的api调用
api·十字·八字·国学·刻柱
小烤箱13 小时前
Autoware Universe 感知模块详解 | 第十二节 CUDA 编程基础——CUDA执行模型
自动驾驶·cuda·感知
koo36413 小时前
pytorch深度学习笔记12
pytorch·笔记·深度学习
新诺韦尔API15 小时前
手机三要素验证不通过的原因?
大数据·智能手机·api
koo36417 小时前
pytorch深度学习笔记9
pytorch·笔记·深度学习
创作者mateo19 小时前
PyTorch 入门笔记配套【完整练习代码】
人工智能·pytorch·笔记
创作者mateo20 小时前
PyTorch 入门学习笔记(基础篇)一
pytorch·笔记·学习
victory043120 小时前
pytorch 矩阵乘法和实际存储形状的差异
人工智能·pytorch·矩阵
EchoL、21 小时前
指定GPU设备
pytorch·笔记
m0_6136070121 小时前
小土堆-P3-笔记
pytorch·python·深度学习