pytorch custom op的简单介绍

python 复制代码
@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, ...)
	...

_torch_custom_op_wrapper是pytorch中用于封装自定义CUDA算子的装饰器,功能是将底层注册到torch.ops的算子(flash_attn::_flash_attn_varlen_forward)绑定到这个python函数。其中指定算子的元信息有:mutates_args=()表示算子不修改输入参数;devices_types="cuda"限定仅支持CUDA设备

把算子注册成pytorch custom op有什么好处
  1. 支持torch.compile
    • 自定义op会被当作一个单独的节点出现在计算图里,而不是被内联展开成大量python / C++ 调用
    • 编译器能正确识别、保留这个op,并在图优化时正确处理,而不是当成未知黑盒导致报错或回退到eager
    • 对FlashAttention这类复杂CUDA内核,注册成custom op是让torch.compile稳定工作的常见做法
  2. 支持torch.export / TorchDynamo
    • 导出模型时,需要知道每个op的输入输出、设备类型、是否修改输入等
    • custom_op的schema、mutages_args等会提供这些元信息,便于torch.export生成正确的导出图
  3. FakerTensor / 元设备执行
    • register_fake提供在meta device上的实现,可以不真正分配GPU显存
    • 常用于:
      • 形状推导(shape inference)
      • torch.compile的tracing阶段(不需要真实数据)
      • 某些分布式/大模型规划逻辑
  4. 统一到Pytorch的op系统
    • 通过torch.ops.<namespace>.<op_name>调用,和其他内置op用法一致
    • 便于做op级别的分析、profiling、调度等
  5. 对autograd的可控性
    • 可以显式声明mutates_args(例如哪些输入会被in-place修改)
    • 配合自定义backward,让autograd正确传播梯度、正确处理需要梯度的输入

把FlashAttention注册成PyTorch custom op后,FlashAttention就是一整块融合的CUDA内核,作为整体暴露给PyTorch。对于torch.compile来说,这就是一个原子操作,要么作为单个节点被正确处理,要么只能当做黑盒触发graph break或报错

FlashAttention中的两种调用区别
python 复制代码
if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward

这两段代码在算子本体(CUDA调用)上是一样的,差别在于在PyTorch中怎么暴露、怎么和torch.compile等配合。

  1. torch.__version__>=2.4.0时:用@_torch_custom_op_wrapper把函数注册为Pytorch custom op,_wrapped_flash_attn_varlen_forward则指向torch.ops里注册好的那个op,和custom_op / register_fake 走的是同一套机制,和torch.compile / torch.export / FakeTensor等兼容
  2. torch.__version__<2.4.0时:没有通过torch.library.custom_op挂到torch.ops上,对2.4之后那套编译/图捕获的支持不如新路径
相关推荐
我登哥MVP2 分钟前
VS Code 安装 Claude Code 并接入 DeepSeek V4 Model
人工智能·python·node.js·agent·codex·deepseek·claude code
unique3 分钟前
AI Native 调研报告
人工智能
云烟成雨TD3 分钟前
Spring AI Alibaba 1.x 系列【73】两步 RAG
java·人工智能·spring
ai产品老杨5 分钟前
解耦视频高并发与边缘计算AI布控:基于Docker的高性能安防平台,破局GB28181/RTSP协议兼容与源码交付痛点
人工智能·音视频·边缘计算
CHrisFC6 分钟前
LIMS 系统 AI 建设路径:从自动化到智能化的演进之路
运维·人工智能·自动化
饼干哥哥7 分钟前
一口气搭了300个AI Agents并发处理跨境运营的dirty work
人工智能
AI行业学习7 分钟前
CC‑Switch v3.16.1-下载、配置、安装(2026‑06‑01 最新官方版)
开发语言·人工智能·windows·python
小糖学代码8 分钟前
机器学习:5.深度学习
人工智能·深度学习·机器学习
unity工具人8 分钟前
python+yolov8 图像识别-测试案例
python·opencv·yolo
lipku9 分钟前
LiveTalking 更新:集成 vLLM-Omni TTS服务
python·开源·数字人·vllm·实时数字人