一、模块概述
torch.library 是 PyTorch 提供的用于扩展核心算子库的 API 集合,主要功能包括:
| 功能 | 说明 |
|---|---|
| 测试自定义算子 | 验证算子注册是否正确 |
| 创建自定义算子 | 在 Python 中定义新算子 |
| 扩展现有算子 | 为 C++ 注册的算子添加 Python 实现 |
二、测试自定义算子
2.1 torch.library.opcheck() - 算子正确性检测
核心作用:验证自定义算子的元数据和属性是否符合 PyTorch 规范。
检测内容:
python
test_schema # 检查 schema 与实现是否匹配
test_autograd_registration # 检查自动微分是否正确注册
test_faketensor # 检查 FakeTensor 内核是否正确
test_aot_dispatch_dynamic # 检查与 torch.compile 的兼容性
使用示例:
python
import torch
from torch.library import custom_op
@custom_op("mylib::numpy_mul", mutates_args=())
def numpy_mul(x: torch.Tensor, y: float) -> torch.Tensor:
x_np = x.numpy(force=True)
z_np = x_np * y
return torch.from_numpy(z_np).to(x.device)
@numpy_mul.register_fake
def _(x, y):
return torch.empty_like(x)
# 测试多种输入场景
sample_inputs = [
(torch.randn(3), 3.14),
(torch.randn(2, 3, device='cuda'), 2.718),
(torch.randn(1, 10, requires_grad=True), 1.234),
]
for args in sample_inputs:
torch.library.opcheck(numpy_mul, args) # ✅ 通过则无输出
⚠️ 注意 :
opcheck与torch.autograd.gradcheck()互补,前者测试 API 使用正确性,后者测试梯度数学正确性。
三、创建自定义算子
3.1 torch.library.custom_op() - 通用自定义算子
核心参数:
| 参数 | 说明 |
|---|---|
name |
命名空间格式:"namespace::name",如 "mylib::my_linear" |
mutates_args |
被修改的参数名列表,必须准确指定 |
device_types |
支持的设备类型:"cpu"、"cuda" 等 |
schema |
算子 schema 字符串(推荐自动推断) |
基础示例:
python
from torch.library import custom_op
import numpy as np
# 示例1:简单算子
@custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
x_np = x.cpu().numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np).to(device=x.device)
# 示例2:原地修改算子
@custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
def numpy_sin_inplace(x: torch.Tensor) -> None:
x_np = x.numpy()
np.sin(x_np, out=x_np) # 原地修改
# 示例3:工厂函数(无输入 Tensor)
@custom_op("mylib::bar", mutates_args={}, device_types="cpu")
def bar(device: torch.device) -> torch.Tensor:
return torch.ones(3)
3.2 torch.library.triton_op() - Triton 内核封装
适用场景 :当实现包含 Triton 内核 时使用,允许 torch.compile 优化 Triton 代码。
python
from torch.library import triton_op, wrap_triton
import triton
from triton import language as tl
@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
tl.store(out_ptr + offsets, x + y, mask=mask)
@triton_op("mylib::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
n_elements = output.numel()
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
# 必须用 wrap_triton 包装
wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
return output
# 可与 torch.compile 配合使用
@torch.compile
def f(x, y):
return add(x, y)
3.3 torch.library.wrap_triton() - Triton 内核追踪包装
作用 :使 Triton 内核可被 make_fx 或 torch.export 捕获到计算图中。
四、扩展现有算子
4.1 register_kernel() - 注册设备特定实现
python
@custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
x_np = x.numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np)
# 为 CUDA 添加实现
@torch.library.register_kernel("mylib::numpy_sin", "cuda")
def _(x):
x_np = x.cpu().numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np).to(device=x.device)
4.2 register_fake() / impl_abstract() - 注册 FakeTensor 实现
别名 :impl_abstract 在 PyTorch 2.4 后重命名为 register_fake。
作用 :定义算子在无数据 Tensor(FakeTensor/Meta Tensor)上的行为,支持编译和导出。
python
# 示例1:常规算子
@torch.library.register_fake("mylib::custom_linear")
def _(x, weight, bias):
assert x.dim() == 2 and weight.dim() == 2
assert x.shape[1] == weight.shape[1]
return (x @ weight.t()) + bias # 返回元数据正确的空 Tensor
# 示例2:数据依赖形状(如 nonzero)
@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
ctx = torch.library.get_ctx() # 获取上下文
nnz = ctx.new_dynamic_size() # 创建动态符号整数
return x.new_empty([nnz, x.dim()], dtype=torch.int64)
4.3 register_autograd() - 注册反向传播
python
@custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
x_np = x.cpu().numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np).to(device=x.device)
def setup_context(ctx, inputs, output):
x, = inputs
ctx.save_for_backward(x) # 保存前向所需数据
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad * x.cos() # 返回梯度
torch.library.register_autograd("mylib::numpy_sin", backward, setup_context=setup_context)
# 测试
x = torch.randn(3, requires_grad=True)
y = numpy_sin(x)
grad_x = torch.autograd.grad(y, x, torch.ones_like(y))[0]
4.4 register_autocast() - 自动类型转换
python
@custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: torch.Tensor) -> torch.Tensor:
return torch.sin(x)
# 注册 CUDA 下的 FP16 自动转换
torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
# 使用
with torch.autocast("cuda", dtype=torch.float16):
y = torch.ops.mylib.my_sin(x) # x 自动转为 fp16,输出也是 fp16
4.5 register_vmap() - 批量映射支持
python
@torch.library.register_vmap("mylib::numpy_mul")
def numpy_mul_vmap(info, in_dims, x, y):
x_bdim, y_bdim = in_dims
# 调整维度进行广播计算
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
return result.movedim(-1, 0), 0 # 返回 (输出, 输出维度)
# 使用
x = torch.randn(3)
y = torch.randn(3)
torch.vmap(numpy_mul)(x, y) # 批量映射
4.6 register_torch_dispatch() - TorchDispatch 规则
python
class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)
@torch.library.register_torch_dispatch("mylib::foo", MyMode)
def _(mode, func, types, args, kwargs):
x, = args
return x + 1 # 在 MyMode 下行为改变
# 测试
with MyMode():
y = foo(x) # 输出 x + 1
五、辅助工具
5.1 infer_schema() - 从类型注解推断 Schema
python
def foo_impl(x: torch.Tensor) -> torch.Tensor:
return x.sin()
schema = torch.library.infer_schema(foo_impl, op_name="foo", mutates_args={})
print(schema) # 输出: foo(Tensor x) -> Tensor
5.2 get_ctx() - 获取 Fake 实现上下文
仅在 register_fake 内部有效,用于创建动态符号尺寸。
5.3 get_kernel() - 获取已注册的内核
python
# 获取 aten::add 的 CPU 内核
kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
# 用于实现条件分发
original_sin = torch.library.get_kernel("aten::sin", "CPU")
def conditional_sin(dispatch_keys, x):
if (x < 0).any():
return original_sin.call_boxed(dispatch_keys, x)
return torch.zeros_like(x)
六、底层 API(Library 类)
警告:建议优先使用上述高级 API,底层 API 需要理解 PyTorch Dispatcher 机制。
python
# 创建库
my_lib = torch.library.Library("mylib", "DEF") # 定义新算子
my_lib = torch.library.Library("aten", "IMPL") # 扩展现有算子
# 定义算子
my_lib.define("sum(Tensor self) -> Tensor")
# 注册实现
my_lib.impl("div.Tensor", div_cpu, "CPU")
# 注册 fallback
my_lib.fallback(fallback_kernel, "Autocast")
七、最佳实践总结
| 场景 | 推荐 API |
|---|---|
| 快速创建自定义算子 | custom_op() |
| 包装 Triton 内核 | triton_op() + wrap_triton() |
| 支持 torch.compile | 必须实现 register_fake() |
| 支持自动微分 | register_autograd() |
| 支持多设备 | register_kernel() 按设备注册 |
| 支持 torch.vmap | register_vmap() |
| 支持自动混合精度 | register_autocast() |
| 验证算子正确性 | opcheck() + gradcheck() |
八、参考链接
💡 提示:本文基于 PyTorch 2.11 官方文档整理,建议结合官方最新文档使用。