pytorch custom op的简单介绍

python 复制代码
@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, ...)
	...

_torch_custom_op_wrapper是pytorch中用于封装自定义CUDA算子的装饰器,功能是将底层注册到torch.ops的算子(flash_attn::_flash_attn_varlen_forward)绑定到这个python函数。其中指定算子的元信息有:mutates_args=()表示算子不修改输入参数;devices_types="cuda"限定仅支持CUDA设备

把算子注册成pytorch custom op有什么好处
  1. 支持torch.compile
    • 自定义op会被当作一个单独的节点出现在计算图里,而不是被内联展开成大量python / C++ 调用
    • 编译器能正确识别、保留这个op,并在图优化时正确处理,而不是当成未知黑盒导致报错或回退到eager
    • 对FlashAttention这类复杂CUDA内核,注册成custom op是让torch.compile稳定工作的常见做法
  2. 支持torch.export / TorchDynamo
    • 导出模型时,需要知道每个op的输入输出、设备类型、是否修改输入等
    • custom_op的schema、mutages_args等会提供这些元信息,便于torch.export生成正确的导出图
  3. FakerTensor / 元设备执行
    • register_fake提供在meta device上的实现,可以不真正分配GPU显存
    • 常用于:
      • 形状推导(shape inference)
      • torch.compile的tracing阶段(不需要真实数据)
      • 某些分布式/大模型规划逻辑
  4. 统一到Pytorch的op系统
    • 通过torch.ops.<namespace>.<op_name>调用,和其他内置op用法一致
    • 便于做op级别的分析、profiling、调度等
  5. 对autograd的可控性
    • 可以显式声明mutates_args(例如哪些输入会被in-place修改)
    • 配合自定义backward,让autograd正确传播梯度、正确处理需要梯度的输入

把FlashAttention注册成PyTorch custom op后,FlashAttention就是一整块融合的CUDA内核,作为整体暴露给PyTorch。对于torch.compile来说,这就是一个原子操作,要么作为单个节点被正确处理,要么只能当做黑盒触发graph break或报错

FlashAttention中的两种调用区别
python 复制代码
if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward

这两段代码在算子本体(CUDA调用)上是一样的,差别在于在PyTorch中怎么暴露、怎么和torch.compile等配合。

  1. torch.__version__>=2.4.0时:用@_torch_custom_op_wrapper把函数注册为Pytorch custom op,_wrapped_flash_attn_varlen_forward则指向torch.ops里注册好的那个op,和custom_op / register_fake 走的是同一套机制,和torch.compile / torch.export / FakeTensor等兼容
  2. torch.__version__<2.4.0时:没有通过torch.library.custom_op挂到torch.ops上,对2.4之后那套编译/图捕获的支持不如新路径
相关推荐
不懒不懒2 小时前
【实战案例:基于特征匹配的指纹识别系统开发】
人工智能·opencv·计算机视觉
chushiyunen2 小时前
uv使用笔记(python包的管理工具)
笔记·python·uv
曲幽2 小时前
FastAPI状态共享秘籍:别再让中间件、依赖和路由“各自为政”了!
python·fastapi·web·request·state·depends·middleware
风清扬【coder】2 小时前
Anaconda 被误删后抢救手册:数据恢复 + 环境重建应急流程
python·数据恢复·anaconda·环境重建
2401_884563242 小时前
进阶技巧与底层原理
jvm·数据库·python
2401_873204652 小时前
使用Pandas进行数据分析:从数据清洗到可视化
jvm·数据库·python
ZGi.ai2 小时前
生产级 Agent 编排 从单一 LLM 调用到多智能体工作流的工程设计
大数据·数据库·人工智能
木斯佳2 小时前
前端八股文面经大全:阿里云AI应用开发一面(2026-03-20)·面经深度解析
前端·人工智能·阿里云·ai·智能体·流式打印
l1t2 小时前
DeepSeek 辅助编写python程序求解欧拉计划932题:2025数
开发语言·python·欧拉计划