python
@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, ...)
...
_torch_custom_op_wrapper是pytorch中用于封装自定义CUDA算子的装饰器,功能是将底层注册到torch.ops的算子(flash_attn::_flash_attn_varlen_forward)绑定到这个python函数。其中指定算子的元信息有:mutates_args=()表示算子不修改输入参数;devices_types="cuda"限定仅支持CUDA设备
把算子注册成pytorch custom op有什么好处
- 支持
torch.compile- 自定义op会被当作一个单独的节点出现在计算图里,而不是被内联展开成大量python / C++ 调用
- 编译器能正确识别、保留这个op,并在图优化时正确处理,而不是当成未知黑盒导致报错或回退到eager
- 对FlashAttention这类复杂CUDA内核,注册成custom op是让
torch.compile稳定工作的常见做法
- 支持
torch.export/ TorchDynamo- 导出模型时,需要知道每个op的输入输出、设备类型、是否修改输入等
custom_op的schema、mutages_args等会提供这些元信息,便于torch.export生成正确的导出图
- FakerTensor / 元设备执行
- 用
register_fake提供在meta device上的实现,可以不真正分配GPU显存 - 常用于:
- 形状推导(shape inference)
torch.compile的tracing阶段(不需要真实数据)- 某些分布式/大模型规划逻辑
- 用
- 统一到Pytorch的op系统
- 通过
torch.ops.<namespace>.<op_name>调用,和其他内置op用法一致 - 便于做op级别的分析、profiling、调度等
- 通过
- 对autograd的可控性
- 可以显式声明
mutates_args(例如哪些输入会被in-place修改) - 配合自定义backward,让autograd正确传播梯度、正确处理需要梯度的输入
- 可以显式声明
把FlashAttention注册成PyTorch custom op后,FlashAttention就是一整块融合的CUDA内核,作为整体暴露给PyTorch。对于torch.compile来说,这就是一个原子操作,要么作为单个节点被正确处理,要么只能当做黑盒触发graph break或报错
FlashAttention中的两种调用区别
python
if torch.__version__ >= "2.4.0":
_wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
_wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
这两段代码在算子本体(CUDA调用)上是一样的,差别在于在PyTorch中怎么暴露、怎么和torch.compile等配合。
torch.__version__>=2.4.0时:用@_torch_custom_op_wrapper把函数注册为Pytorch custom op,_wrapped_flash_attn_varlen_forward则指向torch.ops里注册好的那个op,和custom_op/register_fake走的是同一套机制,和torch.compile/torch.export/ FakeTensor等兼容torch.__version__<2.4.0时:没有通过torch.library.custom_op挂到torch.ops上,对2.4之后那套编译/图捕获的支持不如新路径