PyTorch 2.x 用"编译器化体系(torch.compile)",Triton 是其中最重要的 kernel 生成方式之一,而不是唯一或默认替代 CUDA kernel。
结构图(重点)
text
PyTorch Model
↓
torch.compile()
↓
TorchInductor (Compiler)
↓ ↓
CUDA kernel Triton kernel
↓ ↓
└── PTX ─┘
↓
GPU
1、 PyTorch 2.x 到底发生了什么变化?
PyTorch 在 2.x 引入了:
torch.compile()(核心变革)
结构变成:
text
Eager Mode (PyTorch 1.x)
↓
Graph Capture (TorchDynamo)
↓
Graph Optimization (AOTAutograd)
↓
Backend (TorchInductor)
↓
GPU Code
PyTorch 从"解释执行" → "编译执行"
Triton 在 PyTorch 2.x 里的位置
Triton 主要在:
text
TorchInductor
↓
Kernel generation
↓
Triton / CUDA / CPU backend
👉 Triton 只是 GPU kernel生成器之一
2、PyTorch 2.x 目的包含:自动帮你生成 fused kernel
因为 PyTorch 2.x 做了一件大事:
自动帮你生成 fused kernel
以前:
text
PyTorch op1 → CUDA kernel
PyTorch op2 → CUDA kernel
PyTorch op3 → CUDA kernel
现在:
text
op1 + op2 + op3
↓
fused kernel
↓
Triton or CUDA kernel
👉 变化本质:
| 旧时代 | 新时代 |
|---|---|
| 手写 CUDA kernel | 自动生成 kernel |
| kernel 很多 | kernel fusion |
| 手动优化 | 编译器优化 |
3、 Triton 更简单
① 写 CUDA kernel 太难
CUDA kernel:
cpp
__global__ void kernel(...)
问题:
- warp / block / memory 太复杂
- 写错很难调
- 维护成本高
② Triton 更适合"AI kernel模式"
大模型 kernel特点:
- matrix-heavy
- memory bound
- pattern重复(attention / MLP)
Triton:
python
@triton.jit
def kernel(...):
更像 Python → 更适合研究员写
③ PyTorch 2.x 的目标不是"写 kernel",而是"自动生成 kernel"
核心理念:
用户写模型,系统决定 kernel
(正文完毕)