pytorch custom op的简单介绍

python 复制代码
@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, ...)
	...

_torch_custom_op_wrapper是pytorch中用于封装自定义CUDA算子的装饰器,功能是将底层注册到torch.ops的算子(flash_attn::_flash_attn_varlen_forward)绑定到这个python函数。其中指定算子的元信息有:mutates_args=()表示算子不修改输入参数;devices_types="cuda"限定仅支持CUDA设备

把算子注册成pytorch custom op有什么好处
  1. 支持torch.compile
    • 自定义op会被当作一个单独的节点出现在计算图里,而不是被内联展开成大量python / C++ 调用
    • 编译器能正确识别、保留这个op,并在图优化时正确处理,而不是当成未知黑盒导致报错或回退到eager
    • 对FlashAttention这类复杂CUDA内核,注册成custom op是让torch.compile稳定工作的常见做法
  2. 支持torch.export / TorchDynamo
    • 导出模型时,需要知道每个op的输入输出、设备类型、是否修改输入等
    • custom_op的schema、mutages_args等会提供这些元信息,便于torch.export生成正确的导出图
  3. FakerTensor / 元设备执行
    • register_fake提供在meta device上的实现,可以不真正分配GPU显存
    • 常用于:
      • 形状推导(shape inference)
      • torch.compile的tracing阶段(不需要真实数据)
      • 某些分布式/大模型规划逻辑
  4. 统一到Pytorch的op系统
    • 通过torch.ops.<namespace>.<op_name>调用,和其他内置op用法一致
    • 便于做op级别的分析、profiling、调度等
  5. 对autograd的可控性
    • 可以显式声明mutates_args(例如哪些输入会被in-place修改)
    • 配合自定义backward,让autograd正确传播梯度、正确处理需要梯度的输入

把FlashAttention注册成PyTorch custom op后,FlashAttention就是一整块融合的CUDA内核,作为整体暴露给PyTorch。对于torch.compile来说,这就是一个原子操作,要么作为单个节点被正确处理,要么只能当做黑盒触发graph break或报错

FlashAttention中的两种调用区别
python 复制代码
if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward

这两段代码在算子本体(CUDA调用)上是一样的,差别在于在PyTorch中怎么暴露、怎么和torch.compile等配合。

  1. torch.__version__>=2.4.0时:用@_torch_custom_op_wrapper把函数注册为Pytorch custom op,_wrapped_flash_attn_varlen_forward则指向torch.ops里注册好的那个op,和custom_op / register_fake 走的是同一套机制,和torch.compile / torch.export / FakeTensor等兼容
  2. torch.__version__<2.4.0时:没有通过torch.library.custom_op挂到torch.ops上,对2.4之后那套编译/图捕获的支持不如新路径
相关推荐
Ulyanov1 分钟前
打造现代化雷达电子对抗仿真界面 第二篇:雷达电子对抗仿真系统核心功能实现
前端·python·信息可视化·数据可视化·系统仿真·雷达电子战
财经资讯数据_灵砚智能4 分钟前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年4月12日
人工智能·python·信息可视化·自然语言处理·ai编程
β添砖java4 分钟前
从函数到神经网络【AI入门01】(b站飞天闪客~~
人工智能
永霖光电_UVLED5 分钟前
宽带圆偏振光(CPL)探测器的技术归纳、以及对未来应用
人工智能·生成对抗网络·汽车·娱乐·激光
二等饼干~za89866812 分钟前
云罗 GEO 优化系统源码厂家测评报告
大数据·网络·数据库·人工智能·django
天地沧海12 分钟前
AI测试用例检查
人工智能
GISer_Jing15 分钟前
前端视频多模态:编解码、传输、渲染全链路详解
前端·人工智能·音视频
乔公子搬砖17 分钟前
告别识别率焦虑:视频 AI 工程化实战 —— 检测→判定→聚合→治理全链路拆解
人工智能·yolo·决策树·计算机视觉·视觉检测
视觉&物联智能18 分钟前
【杂谈】-人工智能疲劳是真实存在的,但它并非你想象的那样
人工智能·ai·chatgpt·agi·deepseek
GlobalInfo20 分钟前
工业控制类芯片市场份额、市场占有率、行业调研报告2026
大数据·人工智能·物联网