PyTorch 学习笔记(10) : PyTorch torch.library

一、模块概述

torch.library 是 PyTorch 提供的用于扩展核心算子库的 API 集合,主要功能包括:

功能 说明
测试自定义算子 验证算子注册是否正确
创建自定义算子 在 Python 中定义新算子
扩展现有算子 为 C++ 注册的算子添加 Python 实现

二、测试自定义算子

2.1 torch.library.opcheck() - 算子正确性检测

核心作用:验证自定义算子的元数据和属性是否符合 PyTorch 规范。

检测内容

python 复制代码
test_schema              # 检查 schema 与实现是否匹配
test_autograd_registration  # 检查自动微分是否正确注册
test_faketensor          # 检查 FakeTensor 内核是否正确
test_aot_dispatch_dynamic # 检查与 torch.compile 的兼容性

使用示例

python 复制代码
import torch
from torch.library import custom_op

@custom_op("mylib::numpy_mul", mutates_args=())
def numpy_mul(x: torch.Tensor, y: float) -> torch.Tensor:
    x_np = x.numpy(force=True)
    z_np = x_np * y
    return torch.from_numpy(z_np).to(x.device)

@numpy_mul.register_fake
def _(x, y):
    return torch.empty_like(x)

# 测试多种输入场景
sample_inputs = [
    (torch.randn(3), 3.14),
    (torch.randn(2, 3, device='cuda'), 2.718),
    (torch.randn(1, 10, requires_grad=True), 1.234),
]

for args in sample_inputs:
    torch.library.opcheck(numpy_mul, args)  # ✅ 通过则无输出

⚠️ 注意opchecktorch.autograd.gradcheck() 互补,前者测试 API 使用正确性,后者测试梯度数学正确性。


三、创建自定义算子

3.1 torch.library.custom_op() - 通用自定义算子

核心参数

参数 说明
name 命名空间格式:"namespace::name",如 "mylib::my_linear"
mutates_args 被修改的参数名列表,必须准确指定
device_types 支持的设备类型:"cpu""cuda"
schema 算子 schema 字符串(推荐自动推断)

基础示例

python 复制代码
from torch.library import custom_op
import numpy as np

# 示例1:简单算子
@custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    x_np = x.cpu().numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np).to(device=x.device)

# 示例2:原地修改算子
@custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
def numpy_sin_inplace(x: torch.Tensor) -> None:
    x_np = x.numpy()
    np.sin(x_np, out=x_np)  # 原地修改

# 示例3:工厂函数(无输入 Tensor)
@custom_op("mylib::bar", mutates_args={}, device_types="cpu")
def bar(device: torch.device) -> torch.Tensor:
    return torch.ones(3)

3.2 torch.library.triton_op() - Triton 内核封装

适用场景 :当实现包含 Triton 内核 时使用,允许 torch.compile 优化 Triton 代码。

python 复制代码
from torch.library import triton_op, wrap_triton
import triton
from triton import language as tl

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@triton_op("mylib::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    
    # 必须用 wrap_triton 包装
    wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
    return output

# 可与 torch.compile 配合使用
@torch.compile
def f(x, y):
    return add(x, y)

3.3 torch.library.wrap_triton() - Triton 内核追踪包装

作用 :使 Triton 内核可被 make_fxtorch.export 捕获到计算图中。


四、扩展现有算子

4.1 register_kernel() - 注册设备特定实现

python 复制代码
@custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    x_np = x.numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np)

# 为 CUDA 添加实现
@torch.library.register_kernel("mylib::numpy_sin", "cuda")
def _(x):
    x_np = x.cpu().numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np).to(device=x.device)

4.2 register_fake() / impl_abstract() - 注册 FakeTensor 实现

别名impl_abstract 在 PyTorch 2.4 后重命名为 register_fake

作用 :定义算子在无数据 Tensor(FakeTensor/Meta Tensor)上的行为,支持编译和导出

python 复制代码
# 示例1:常规算子
@torch.library.register_fake("mylib::custom_linear")
def _(x, weight, bias):
    assert x.dim() == 2 and weight.dim() == 2
    assert x.shape[1] == weight.shape[1]
    return (x @ weight.t()) + bias  # 返回元数据正确的空 Tensor

# 示例2:数据依赖形状(如 nonzero)
@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
    ctx = torch.library.get_ctx()  # 获取上下文
    nnz = ctx.new_dynamic_size()    # 创建动态符号整数
    return x.new_empty([nnz, x.dim()], dtype=torch.int64)

4.3 register_autograd() - 注册反向传播

python 复制代码
@custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    x_np = x.cpu().numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np).to(device=x.device)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)  # 保存前向所需数据

def backward(ctx, grad):
    x, = ctx.saved_tensors
    return grad * x.cos()     # 返回梯度

torch.library.register_autograd("mylib::numpy_sin", backward, setup_context=setup_context)

# 测试
x = torch.randn(3, requires_grad=True)
y = numpy_sin(x)
grad_x = torch.autograd.grad(y, x, torch.ones_like(y))[0]

4.4 register_autocast() - 自动类型转换

python 复制代码
@custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x)

# 注册 CUDA 下的 FP16 自动转换
torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)

# 使用
with torch.autocast("cuda", dtype=torch.float16):
    y = torch.ops.mylib.my_sin(x)  # x 自动转为 fp16,输出也是 fp16

4.5 register_vmap() - 批量映射支持

python 复制代码
@torch.library.register_vmap("mylib::numpy_mul")
def numpy_mul_vmap(info, in_dims, x, y):
    x_bdim, y_bdim = in_dims
    # 调整维度进行广播计算
    x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
    y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
    result = x * y
    return result.movedim(-1, 0), 0  # 返回 (输出, 输出维度)

# 使用
x = torch.randn(3)
y = torch.randn(3)
torch.vmap(numpy_mul)(x, y)  # 批量映射

4.6 register_torch_dispatch() - TorchDispatch 规则

python 复制代码
class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **kwargs)

@torch.library.register_torch_dispatch("mylib::foo", MyMode)
def _(mode, func, types, args, kwargs):
    x, = args
    return x + 1  # 在 MyMode 下行为改变

# 测试
with MyMode():
    y = foo(x)  # 输出 x + 1

五、辅助工具

5.1 infer_schema() - 从类型注解推断 Schema

python 复制代码
def foo_impl(x: torch.Tensor) -> torch.Tensor:
    return x.sin()

schema = torch.library.infer_schema(foo_impl, op_name="foo", mutates_args={})
print(schema)  # 输出: foo(Tensor x) -> Tensor

5.2 get_ctx() - 获取 Fake 实现上下文

仅在 register_fake 内部有效,用于创建动态符号尺寸。

5.3 get_kernel() - 获取已注册的内核

python 复制代码
# 获取 aten::add 的 CPU 内核
kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")

# 用于实现条件分发
original_sin = torch.library.get_kernel("aten::sin", "CPU")

def conditional_sin(dispatch_keys, x):
    if (x < 0).any():
        return original_sin.call_boxed(dispatch_keys, x)
    return torch.zeros_like(x)

六、底层 API(Library 类)

警告:建议优先使用上述高级 API,底层 API 需要理解 PyTorch Dispatcher 机制。

python 复制代码
# 创建库
my_lib = torch.library.Library("mylib", "DEF")  # 定义新算子
my_lib = torch.library.Library("aten", "IMPL")  # 扩展现有算子

# 定义算子
my_lib.define("sum(Tensor self) -> Tensor")

# 注册实现
my_lib.impl("div.Tensor", div_cpu, "CPU")

# 注册 fallback
my_lib.fallback(fallback_kernel, "Autocast")

七、最佳实践总结

场景 推荐 API
快速创建自定义算子 custom_op()
包装 Triton 内核 triton_op() + wrap_triton()
支持 torch.compile 必须实现 register_fake()
支持自动微分 register_autograd()
支持多设备 register_kernel() 按设备注册
支持 torch.vmap register_vmap()
支持自动混合精度 register_autocast()
验证算子正确性 opcheck() + gradcheck()

八、参考链接


💡 提示:本文基于 PyTorch 2.11 官方文档整理,建议结合官方最新文档使用。

相关推荐
小陈phd2 小时前
多模态大模型学习笔记(三十一)—— 基于CCT(Compact Convolutional Transformers)实现中文车牌数据集微调
笔记·学习
zzh0812 小时前
MySQL故障排查与优化笔记
数据库·笔记·mysql
婷婷_1722 小时前
【PCIe 验证每日学习・Day26】PCIe 错误处理与异常恢复机制
网络·学习·程序人生·芯片·原子操作·pcie 验证
AI成长日志2 小时前
【笔面试算法学习专栏】堆与优先队列实战:力扣hot100之215.数组中的第K个最大元素、347.前K个高频元素
学习·算法·leetcode
&&Citrus2 小时前
【CPN 学习笔记(三)】—— Chap3 CPN ML 编程语言 上半部分 3.1 ~ 3.3
笔记·python·学习·cpn·petri网
航Hang*3 小时前
第3章:Linux系统安全管理——第1节:Linux 防火墙部署(firewalld)
linux·服务器·网络·学习·系统安全·vmware
宋小米的csdn3 小时前
网络知识学习路线(实用向)
网络·学习
南境十里·墨染春水3 小时前
linux学习进展 基础命令 vi基础命令
linux·运维·服务器·笔记·学习
Xudde.3 小时前
班级作业笔记报告0x08
笔记·学习·安全·web安全