PyTorch 模型能不能直接在昇腾上跑?直接回答是不行------PyTorch 内部调用的 CUDA 算子没法在 NPU 上执行。但间接方案是有的:TorchAir 是 CANN 社区维护的 PyTorch 适配层,让 PyTorch 的训练和推理代码在昇腾上透明运行。
TorchAir 的定位
TorchAir 不是 PyTorch 的分支------它不改 PyTorch 的源码。它是一组 PyTorch 的插件和扩展,通过 PyTorch 的 torch.nn.Module 钩子机制和自定义 backend 接口,把 PyTorch 算子调用路由到 CANN 的算子实现上。
PyTorch 代码(torch.nn.Linear, torch.matmul, torch.nn.Dropout)
↓
TorchAir 适配层:
┌────────────────────────────┐
│ 算子映射:CUDA → CANN │
│ 内存管理:cudaMalloc → aclrtMalloc │
│ 通信后端:NCCL → HCCL │
│ 图优化:TorchScript → GE │
└────────────────────────────┘
↓
CANN Runtime → 昇腾 NPU
用户代码不需要感知底层切换。
PyTorch 如何运行在昇腾上
从用户视角看,从 GPU 切到昇腾就是改两行代码:
python
# GPU 训练代码
import torch
model = model.cuda()
input = input.cuda()
# 昇腾训练代码
import torch
import torch_npu # TorchAir 的入口
model = model.npu()
input = input.npu()
背后的切换是 TorchAir 完成的。它在 PyTorch 的 backend 注册表中注册了 npu 设备类型,并把 torch.nn.functional 中的算子映射到 CANN 实现。
model.npu() 调用的过程中,TorchAir 做了:
- 每层参数从 CPU 内存搬到 NPU 显存(通过
aclrtMalloc) - 每层的
forward函数注册了 CANN 算子实现 - 层的运行设备标记为
npu------后续的input.to(device)都调到 CANN 的显存分配
CANN 如何承接底层执行
TorchAir 的算子映射表把 PyTorch 算子一对一或一对多映射到 CANN 算子:
PyTorch 算子 → CANN 实现
──────────────────────────────────────────────────
torch.nn.Linear.forward → ops-blas 的 GEMM
torch.nn.Conv2d.forward → ops-nn 的 Conv
torch.nn.Dropout.forward → ops-rand 的 Bernoulli
torch.matmul → ops-blas 的 GEMM
torch.nn.Softmax → ops-nn 的 Softmax
torch.nn.LayerNorm → ops-nn 的 LayerNorm
映射过程由 TorchAir 在 model 加载时完成------遍历模型的子模块,为每个 nn.Module 替换 forward 函数。
训练场景中反向传播的算子映射同理------TorchAir 会在每个前向算子的 Autograd 函数中注册对应的反向算子。
Graph Engine 如何优化 PyTorch 图
TorchAir 还支持 TorchScript 的图优化路径。当用户启用 torch.jit.trace 或 torch.jit.script 时,TorchAir 把 TorchScript 的计算图传给 CANN 的 GE 做图优化:
python
model_script = torch.jit.script(model)
# TorchAir 内部调用 GE 的优化接口
# GE 做算子融合、内存优化、Layout 转换
GE 对 TorchScript 图的优化跟 ONNX 路径类似------模式匹配 Attention 和 FFN 子图、融合连续算子、预分配内存。GE 优化后的执行计划直接传给 Runtime,不经过 ONNX 中间格式。
Transformer 推理中的适配
使用 TorchAir + CANN 部署 LLaMA-7B 推理时,关键适配点:
Attention 实现替换。 PyTorch 的 torch.nn.MultiheadAttention 是一个逐算子实现的 Attention------Q@K.T → Softmax → S@V。TorchAir 把这条路径替换为 CANN 的 FlashAttention 融合算子------GE 在编译时识别 Attention 子图后替换。不需要用户改代码。
KV Cache 管理。 TorchAir 把 PyTorch 推理循环中 KV Cache 的 Python 列表管理替换为 CANN Runtime 的 Paged Cache 管理。Paged Cache 的显存分配由 Runtime 控制,不经过 Python 内存管理器。
HCCL 后端。 分布式场景中 torch.distributed 的 backend 从 nccl 换成 hccl------TorchAir 在 dist.init_process_group(backend="hccl") 时自动初始化 HCCL 通信域。
TorchAir 的精度对齐
PyTorch 在 GPU 和昇腾上的推理精度可能不完全一致------主要原因是不同硬件上 FP16 的累加顺序不同、算子实现细节(如 Softmax 的指数近似)略有差异。
TorchAir 在算子映射时提供了精度对齐模式------对精度敏感的算子走 FP32 计算路径,不敏感的走 FP16。开发者可以通过环境变量控制:
python
torch.npu.set_precision_mode("high") # FP32 计算 + FP16 通信
torch.npu.set_precision_mode("mixed") # FP16 计算 + FP16 通信
high 模式在 Attention 和 LayerNorm 等精度敏感算子上用 FP32 累加器。
迁移注意事项
从 GPU 切到昇腾时需要注意的差异:
torch.cuda.*→torch.npu.*:API 名称不同但功能相同model.cuda()→model.npu():设备标记不同- DataLoader 的
pin_memory在昇腾上默认不开启------需要显式设置 torch.distributed的 backend 从nccl换hccl