pytorch custom op的简单介绍

python 复制代码
@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, ...)
	...

_torch_custom_op_wrapper是pytorch中用于封装自定义CUDA算子的装饰器,功能是将底层注册到torch.ops的算子(flash_attn::_flash_attn_varlen_forward)绑定到这个python函数。其中指定算子的元信息有:mutates_args=()表示算子不修改输入参数;devices_types="cuda"限定仅支持CUDA设备

把算子注册成pytorch custom op有什么好处
  1. 支持torch.compile
    • 自定义op会被当作一个单独的节点出现在计算图里,而不是被内联展开成大量python / C++ 调用
    • 编译器能正确识别、保留这个op,并在图优化时正确处理,而不是当成未知黑盒导致报错或回退到eager
    • 对FlashAttention这类复杂CUDA内核,注册成custom op是让torch.compile稳定工作的常见做法
  2. 支持torch.export / TorchDynamo
    • 导出模型时,需要知道每个op的输入输出、设备类型、是否修改输入等
    • custom_op的schema、mutages_args等会提供这些元信息,便于torch.export生成正确的导出图
  3. FakerTensor / 元设备执行
    • register_fake提供在meta device上的实现,可以不真正分配GPU显存
    • 常用于:
      • 形状推导(shape inference)
      • torch.compile的tracing阶段(不需要真实数据)
      • 某些分布式/大模型规划逻辑
  4. 统一到Pytorch的op系统
    • 通过torch.ops.<namespace>.<op_name>调用,和其他内置op用法一致
    • 便于做op级别的分析、profiling、调度等
  5. 对autograd的可控性
    • 可以显式声明mutates_args(例如哪些输入会被in-place修改)
    • 配合自定义backward,让autograd正确传播梯度、正确处理需要梯度的输入

把FlashAttention注册成PyTorch custom op后,FlashAttention就是一整块融合的CUDA内核,作为整体暴露给PyTorch。对于torch.compile来说,这就是一个原子操作,要么作为单个节点被正确处理,要么只能当做黑盒触发graph break或报错

FlashAttention中的两种调用区别
python 复制代码
if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward

这两段代码在算子本体(CUDA调用)上是一样的,差别在于在PyTorch中怎么暴露、怎么和torch.compile等配合。

  1. torch.__version__>=2.4.0时:用@_torch_custom_op_wrapper把函数注册为Pytorch custom op,_wrapped_flash_attn_varlen_forward则指向torch.ops里注册好的那个op,和custom_op / register_fake 走的是同一套机制,和torch.compile / torch.export / FakeTensor等兼容
  2. torch.__version__<2.4.0时:没有通过torch.library.custom_op挂到torch.ops上,对2.4之后那套编译/图捕获的支持不如新路径
相关推荐
科技小花20 小时前
全球化深水区,数据治理成为企业出海 “核心竞争力”
大数据·数据库·人工智能·数据治理·数据中台·全球化
X56611 天前
如何在 Laravel 中正确保存嵌套动态表单数据(主服务与子服务)
jvm·数据库·python
zhuiyisuifeng1 天前
2026前瞻:GPTimage2镜像官网或将颠覆视觉创作
人工智能·gpt
徐健峰1 天前
GPT-image-2 热门玩法实战(一):AI 看手相 — 一张手掌照片生成专业手相分析图
人工智能·gpt
weixin_370976351 天前
AI的终极赛跑:进入AGI,还是泡沫破灭?
大数据·人工智能·agi
Slow菜鸟1 天前
AI学习篇(五) | awesome-design-md 使用说明
人工智能·学习
ZhengEnCi1 天前
03ab-PyTorch安装教程 📚
python
冬奇Lab1 天前
RAG 系列(五):Embedding 模型——语义理解的核心
人工智能·llm·aigc
深小乐1 天前
AI 周刊【2026.04.27-05.03】:Anthropic 9000亿美元估值、英伟达死磕智能体、中央重磅定调AI
人工智能
码点滴1 天前
什么时候用 DeepSeek V4,而不是 GPT-5/Claude/Gemini?
人工智能·gpt·架构·大模型·deepseek