TorchAir:PyTorch 跑在昇腾NPU上的桥梁

PyTorch 模型能不能直接在昇腾上跑?直接回答是不行------PyTorch 内部调用的 CUDA 算子没法在 NPU 上执行。但间接方案是有的:TorchAir 是 CANN 社区维护的 PyTorch 适配层,让 PyTorch 的训练和推理代码在昇腾上透明运行。


TorchAir 的定位

TorchAir 不是 PyTorch 的分支------它不改 PyTorch 的源码。它是一组 PyTorch 的插件和扩展,通过 PyTorch 的 torch.nn.Module 钩子机制和自定义 backend 接口,把 PyTorch 算子调用路由到 CANN 的算子实现上。

复制代码
PyTorch 代码(torch.nn.Linear, torch.matmul, torch.nn.Dropout)
  ↓
TorchAir 适配层:
  ┌────────────────────────────┐
  │ 算子映射:CUDA → CANN      │
  │ 内存管理:cudaMalloc → aclrtMalloc │
  │ 通信后端:NCCL → HCCL      │
  │ 图优化:TorchScript → GE   │
  └────────────────────────────┘
  ↓
CANN Runtime → 昇腾 NPU

用户代码不需要感知底层切换。


PyTorch 如何运行在昇腾上

从用户视角看,从 GPU 切到昇腾就是改两行代码:

python 复制代码
# GPU 训练代码
import torch
model = model.cuda()
input = input.cuda()

# 昇腾训练代码
import torch
import torch_npu    # TorchAir 的入口
model = model.npu()
input = input.npu()

背后的切换是 TorchAir 完成的。它在 PyTorch 的 backend 注册表中注册了 npu 设备类型,并把 torch.nn.functional 中的算子映射到 CANN 实现。

model.npu() 调用的过程中,TorchAir 做了:

  1. 每层参数从 CPU 内存搬到 NPU 显存(通过 aclrtMalloc
  2. 每层的 forward 函数注册了 CANN 算子实现
  3. 层的运行设备标记为 npu------后续的 input.to(device) 都调到 CANN 的显存分配

CANN 如何承接底层执行

TorchAir 的算子映射表把 PyTorch 算子一对一或一对多映射到 CANN 算子:

复制代码
PyTorch 算子                → CANN 实现
──────────────────────────────────────────────────
torch.nn.Linear.forward    → ops-blas 的 GEMM
torch.nn.Conv2d.forward    → ops-nn 的 Conv
torch.nn.Dropout.forward   → ops-rand 的 Bernoulli
torch.matmul               → ops-blas 的 GEMM
torch.nn.Softmax           → ops-nn 的 Softmax
torch.nn.LayerNorm         → ops-nn 的 LayerNorm

映射过程由 TorchAir 在 model 加载时完成------遍历模型的子模块,为每个 nn.Module 替换 forward 函数。

训练场景中反向传播的算子映射同理------TorchAir 会在每个前向算子的 Autograd 函数中注册对应的反向算子。


Graph Engine 如何优化 PyTorch 图

TorchAir 还支持 TorchScript 的图优化路径。当用户启用 torch.jit.tracetorch.jit.script 时,TorchAir 把 TorchScript 的计算图传给 CANN 的 GE 做图优化:

python 复制代码
model_script = torch.jit.script(model)
# TorchAir 内部调用 GE 的优化接口
# GE 做算子融合、内存优化、Layout 转换

GE 对 TorchScript 图的优化跟 ONNX 路径类似------模式匹配 Attention 和 FFN 子图、融合连续算子、预分配内存。GE 优化后的执行计划直接传给 Runtime,不经过 ONNX 中间格式。


Transformer 推理中的适配

使用 TorchAir + CANN 部署 LLaMA-7B 推理时,关键适配点:

Attention 实现替换。 PyTorch 的 torch.nn.MultiheadAttention 是一个逐算子实现的 Attention------Q@K.T → Softmax → S@V。TorchAir 把这条路径替换为 CANN 的 FlashAttention 融合算子------GE 在编译时识别 Attention 子图后替换。不需要用户改代码。

KV Cache 管理。 TorchAir 把 PyTorch 推理循环中 KV Cache 的 Python 列表管理替换为 CANN Runtime 的 Paged Cache 管理。Paged Cache 的显存分配由 Runtime 控制,不经过 Python 内存管理器。

HCCL 后端。 分布式场景中 torch.distributed 的 backend 从 nccl 换成 hccl------TorchAir 在 dist.init_process_group(backend="hccl") 时自动初始化 HCCL 通信域。

TorchAir 的精度对齐

PyTorch 在 GPU 和昇腾上的推理精度可能不完全一致------主要原因是不同硬件上 FP16 的累加顺序不同、算子实现细节(如 Softmax 的指数近似)略有差异。

TorchAir 在算子映射时提供了精度对齐模式------对精度敏感的算子走 FP32 计算路径,不敏感的走 FP16。开发者可以通过环境变量控制:

python 复制代码
torch.npu.set_precision_mode("high")  # FP32 计算 + FP16 通信
torch.npu.set_precision_mode("mixed") # FP16 计算 + FP16 通信

high 模式在 Attention 和 LayerNorm 等精度敏感算子上用 FP32 累加器。

迁移注意事项

从 GPU 切到昇腾时需要注意的差异:

  • torch.cuda.*torch.npu.*:API 名称不同但功能相同
  • model.cuda()model.npu():设备标记不同
  • DataLoader 的 pin_memory 在昇腾上默认不开启------需要显式设置
  • torch.distributed 的 backend 从 ncclhccl

参考仓库

TorchAir 适配层

CANN GE 图执行引擎

相关推荐
Lucky_ldy2 小时前
C语言学习:编译和链接
学习
chimchim662 小时前
Azure ADF(Azure Data Factory 数据工厂)学习
学习·microsoft·azure
小新同学^O^2 小时前
简单学习 --> Transformer架构
学习·架构·transformer
他们叫我阿冠2 小时前
Docker的基础学习
学习·docker·容器
辰海Coding11 小时前
MiniSpring框架学习笔记-解决循环依赖的简化IoC容器
笔记·学习
晓梦林11 小时前
cp520靶场学习笔记
android·笔记·学习
心中有国也有家13 小时前
cann-recipes-infer:昇腾 NPU 推理的“菜谱集合”
经验分享·笔记·学习·算法
Upsy-Daisy13 小时前
AI Agent 项目学习笔记(八):Tool Calling 工具调用机制总览
人工智能·笔记·学习
LuminousCPP14 小时前
数据结构 - 线性表第四篇:C 语言通讯录优化升级全记录(踩坑 + 思考)
c语言·开发语言·数据结构·经验分享·笔记·学习