在DCU上把PyTorch跑起来,我踩过的5个坑和一份踩坑指南
如果你也在国产加速卡上折腾过PyTorch,一定经历过
torch.cuda.is_available()返回False时的心跳骤停。这篇文章不讲虚的,从环境部署到分布式训练,把关键点串一遍。
1. PyTorch的底子:它凭什么成了主流
PyTorch最早是Facebook从Torch7(Lua语言那版)改过来的,核心思路很简单------用Python包一层,保持易用性,底层靠C++/CUDA扛性能。目前GitHub上有97k+ Star。
跟它同期竞争的框架不少,但活到今天且还活跃的没几个了:
| 框架 | Star数 | 维护方 | 现状 |
|---|---|---|---|
| PyTorch | 97.3k | Meta | 动态图起家,2.x后通过torch.compile补齐图优化短板 |
| TensorFlow | 194k | 2.x默认Eager模式,静态图靠tf.function,部署生态强 | |
| Caffe | 34.8k | BVLC | 推理性能不错,但社区基本凉了 |
| MxNet | 20.8k | DMLC/Amazon | 已基本停止发展 |
| PaddlePaddle | 23.6k | 百度 | 面向工业+国产生态,动静混合模式 |
PyTorch能活成今天这样,核心就三块东西,缺一不可:
#mermaid-svg-NfJWWRmVBpl6sI2Q{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-NfJWWRmVBpl6sI2Q .error-icon{fill:#552222;}#mermaid-svg-NfJWWRmVBpl6sI2Q .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-NfJWWRmVBpl6sI2Q .marker{fill:#333333;stroke:#333333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .marker.cross{stroke:#333333;}#mermaid-svg-NfJWWRmVBpl6sI2Q svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-NfJWWRmVBpl6sI2Q p{margin:0;}#mermaid-svg-NfJWWRmVBpl6sI2Q .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .cluster-label text{fill:#333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .cluster-label span{color:#333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .cluster-label span p{background-color:transparent;}#mermaid-svg-NfJWWRmVBpl6sI2Q .label text,#mermaid-svg-NfJWWRmVBpl6sI2Q span{fill:#333;color:#333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .node rect,#mermaid-svg-NfJWWRmVBpl6sI2Q .node circle,#mermaid-svg-NfJWWRmVBpl6sI2Q .node ellipse,#mermaid-svg-NfJWWRmVBpl6sI2Q .node polygon,#mermaid-svg-NfJWWRmVBpl6sI2Q .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-NfJWWRmVBpl6sI2Q .rough-node .label text,#mermaid-svg-NfJWWRmVBpl6sI2Q .node .label text,#mermaid-svg-NfJWWRmVBpl6sI2Q .image-shape .label,#mermaid-svg-NfJWWRmVBpl6sI2Q .icon-shape .label{text-anchor:middle;}#mermaid-svg-NfJWWRmVBpl6sI2Q .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-NfJWWRmVBpl6sI2Q .rough-node .label,#mermaid-svg-NfJWWRmVBpl6sI2Q .node .label,#mermaid-svg-NfJWWRmVBpl6sI2Q .image-shape .label,#mermaid-svg-NfJWWRmVBpl6sI2Q .icon-shape .label{text-align:center;}#mermaid-svg-NfJWWRmVBpl6sI2Q .node.clickable{cursor:pointer;}#mermaid-svg-NfJWWRmVBpl6sI2Q .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .arrowheadPath{fill:#333333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-NfJWWRmVBpl6sI2Q .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-NfJWWRmVBpl6sI2Q .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-NfJWWRmVBpl6sI2Q .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-NfJWWRmVBpl6sI2Q .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-NfJWWRmVBpl6sI2Q .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-NfJWWRmVBpl6sI2Q .cluster text{fill:#333;}#mermaid-svg-NfJWWRmVBpl6sI2Q .cluster span{color:#333;}#mermaid-svg-NfJWWRmVBpl6sI2Q div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-NfJWWRmVBpl6sI2Q .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-NfJWWRmVBpl6sI2Q rect.text{fill:none;stroke-width:0;}#mermaid-svg-NfJWWRmVBpl6sI2Q .icon-shape,#mermaid-svg-NfJWWRmVBpl6sI2Q .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-NfJWWRmVBpl6sI2Q .icon-shape p,#mermaid-svg-NfJWWRmVBpl6sI2Q .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-NfJWWRmVBpl6sI2Q .icon-shape .label rect,#mermaid-svg-NfJWWRmVBpl6sI2Q .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-NfJWWRmVBpl6sI2Q .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-NfJWWRmVBpl6sI2Q .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-NfJWWRmVBpl6sI2Q :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 张量计算引擎(Tensor)
类似NumPy的Tensor操作
CPU + GPU高效实现
自动求导机制(Autograd)
autograd.Variable
自动微分
免除手动计算导数
神经网络高级接口(NN)
nn.Module
定义和运行神经网络
模块化接口
说白了:Tensor负责干活,Autograd负责自动算梯度,NN负责把网络搭起来。三层配合,写模型就像搭积木。
2. DCU上的PyTorch:安装不是pip install torch就完事的
在国产加速卡(DCU)上装PyTorch,跟NVIDIA GPU上最大的区别是:你不能直接用官方PyTorch。官方PyTorch绑的是CUDA后端,DCU走的是另一套路线------基于ROCm生态的HIP接口。
安装有四种路子:
| 方式 | 适用场景 | 要点 |
|---|---|---|
| pip安装 | 日常开发首选 | 从DAS源或下载whl包安装,配置镜像源加速 |
| Anaconda | 多环境隔离 | 创建独立环境,注意venv路径 |
| 源码编译 | 定制化需求 | 需设置环境变量、安装依赖、运行编译脚本 |
| Docker | 快速复现 / CI | 加载镜像后直接跑,环境零污染 |
2.1 pip安装:最常用的方式
公网源地址:http://pypi.sourcefind.cn:666/source/packages,在DAS PyPI上选择对应软件栈版本的仓库和torch包,直接复制安装命令即可。
bash
# 示例(版本号已脱敏处理)
pip install torch==x.x.x+das \
-i http://pypi.sourcefind.cn:666/release/dtk-xxx-rc1/+simple/ \
--trusted-host pypi.sourcefind.cn
⚠️ 坑1 :不同DTK版本的包不通用,选错版本装完会直接报符号未定义。务必确认das后缀和DTK版本匹配。
2.2 环境验证:三行命令确认是否就绪
bash
# 1. 加载DTK环境(物理机和部分容器需手动执行)
source /opt/dtk/env.sh
# 2. 查看加速卡信息
rocminfo | grep gfx
# 3. 查看卡状态(类似nvidia-smi)
hy-smi
然后用Python验证PyTorch是否认到卡:
python
import torch
print(torch.cuda.is_available()) # 应为 True
print(torch.__version__) # 如 x.x.x+das
print(torch.version.__hcu_version__) # 带das后缀的HCU版本号
⚠️ 坑2 :source /opt/dtk/env.sh这一步太容易忘。容器里有时自动加载了,物理机必手动执行。忘了的话hy-smi能用但PyTorch认不到卡------因为动态库路径没设对。
2.3 三方库支持:这些常用库都能跑
DCU生态对主流PyTorch三方库的移植情况相当不错:
| 项目 | 软件包 | 支持状态 |
|---|---|---|
| 计算机视觉 | TorchVision | ROCm官方已支持 |
| 3D视觉 | pytorch3d | 已完成移植 |
| 混精计算 | APEX | 部分移植(不含BNP模块),ROCm官方已支持 |
| 分布式训练优化 | DeepSpeed | 已完成移植 |
| 推理优化 | torch-ni | 已完成移植 |
| 音频处理 | torchaudio | 已完成移植 |
| 检测/分割/关键点 | MMCV全家桶 | 已完成移植,ROCm官方已支持 |
| 目标检测 | Detectron2 | 已完成移植 |
| 分布式MoE | fastmoe | 已完成移植 |
3. torch.compile:一行代码让模型快一截
PyTorch 2.x最大的亮点就是torch.compile。它的思路是:你照常写Eager模式的代码,它在底层帮你做JIT编译优化,不动你一行业务逻辑。
#mermaid-svg-HP5T60MyCnXwv959{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-HP5T60MyCnXwv959 .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-HP5T60MyCnXwv959 .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-HP5T60MyCnXwv959 .error-icon{fill:#552222;}#mermaid-svg-HP5T60MyCnXwv959 .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-HP5T60MyCnXwv959 .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-HP5T60MyCnXwv959 .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-HP5T60MyCnXwv959 .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-HP5T60MyCnXwv959 .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-HP5T60MyCnXwv959 .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-HP5T60MyCnXwv959 .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-HP5T60MyCnXwv959 .marker{fill:#333333;stroke:#333333;}#mermaid-svg-HP5T60MyCnXwv959 .marker.cross{stroke:#333333;}#mermaid-svg-HP5T60MyCnXwv959 svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-HP5T60MyCnXwv959 p{margin:0;}#mermaid-svg-HP5T60MyCnXwv959 .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-HP5T60MyCnXwv959 .cluster-label text{fill:#333;}#mermaid-svg-HP5T60MyCnXwv959 .cluster-label span{color:#333;}#mermaid-svg-HP5T60MyCnXwv959 .cluster-label span p{background-color:transparent;}#mermaid-svg-HP5T60MyCnXwv959 .label text,#mermaid-svg-HP5T60MyCnXwv959 span{fill:#333;color:#333;}#mermaid-svg-HP5T60MyCnXwv959 .node rect,#mermaid-svg-HP5T60MyCnXwv959 .node circle,#mermaid-svg-HP5T60MyCnXwv959 .node ellipse,#mermaid-svg-HP5T60MyCnXwv959 .node polygon,#mermaid-svg-HP5T60MyCnXwv959 .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-HP5T60MyCnXwv959 .rough-node .label text,#mermaid-svg-HP5T60MyCnXwv959 .node .label text,#mermaid-svg-HP5T60MyCnXwv959 .image-shape .label,#mermaid-svg-HP5T60MyCnXwv959 .icon-shape .label{text-anchor:middle;}#mermaid-svg-HP5T60MyCnXwv959 .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-HP5T60MyCnXwv959 .rough-node .label,#mermaid-svg-HP5T60MyCnXwv959 .node .label,#mermaid-svg-HP5T60MyCnXwv959 .image-shape .label,#mermaid-svg-HP5T60MyCnXwv959 .icon-shape .label{text-align:center;}#mermaid-svg-HP5T60MyCnXwv959 .node.clickable{cursor:pointer;}#mermaid-svg-HP5T60MyCnXwv959 .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-HP5T60MyCnXwv959 .arrowheadPath{fill:#333333;}#mermaid-svg-HP5T60MyCnXwv959 .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-HP5T60MyCnXwv959 .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-HP5T60MyCnXwv959 .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-HP5T60MyCnXwv959 .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-HP5T60MyCnXwv959 .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-HP5T60MyCnXwv959 .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-HP5T60MyCnXwv959 .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-HP5T60MyCnXwv959 .cluster text{fill:#333;}#mermaid-svg-HP5T60MyCnXwv959 .cluster span{color:#333;}#mermaid-svg-HP5T60MyCnXwv959 div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-HP5T60MyCnXwv959 .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-HP5T60MyCnXwv959 rect.text{fill:none;stroke-width:0;}#mermaid-svg-HP5T60MyCnXwv959 .icon-shape,#mermaid-svg-HP5T60MyCnXwv959 .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-HP5T60MyCnXwv959 .icon-shape p,#mermaid-svg-HP5T60MyCnXwv959 .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-HP5T60MyCnXwv959 .icon-shape .label rect,#mermaid-svg-HP5T60MyCnXwv959 .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-HP5T60MyCnXwv959 .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-HP5T60MyCnXwv959 .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-HP5T60MyCnXwv959 :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 原始模型
Eager模式
torch.compile()
编译优化
算子融合
Conv+ReLU合并为单算子
CUDA Graph提取
多次Kernel Launch
压缩为一次下发
减少显存读写瓶颈
提高计算密度
消除CPU-GPU
通信调度延迟
加速后的模型
使用方式极其简单:
python
# 定义模型
model = MyModel().cuda()
# 一行代码包裹
compiled_model = torch.compile(model)
# 后续训练/推理逻辑完全不变
output = compiled_model(input)
常用参数一览:
| 参数 | 作用 | 推荐场景 |
|---|---|---|
fullgraph=True |
强制生成完整计算图 | 模型结构固定,追求极致性能 |
dynamic=True |
支持动态shape输入 | 输入尺寸不固定的NLP任务 |
mode="max-autotune" |
自动搜索最优kernel配置 | 推理部署场景,编译慢但跑得快 |
mode="reduce-overhead" |
减少CPU调度开销 | 小模型高频推理 |
⚠️ 坑3 :mode="max-autotune"第一次编译极慢(有时候几分钟),但编译完的性能提升对得住等待。如果只是快速验证功能,先用默认模式。
4. 分布式通信:后端怎么选,mpirun怎么配
PyTorch分布式训练的核心通信模块是torch.distributed。三种后端各有利弊:
#mermaid-svg-qOj0R7pnybb8tkLh{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-qOj0R7pnybb8tkLh .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-qOj0R7pnybb8tkLh .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-qOj0R7pnybb8tkLh .error-icon{fill:#552222;}#mermaid-svg-qOj0R7pnybb8tkLh .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-qOj0R7pnybb8tkLh .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-qOj0R7pnybb8tkLh .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-qOj0R7pnybb8tkLh .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-qOj0R7pnybb8tkLh .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-qOj0R7pnybb8tkLh .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-qOj0R7pnybb8tkLh .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-qOj0R7pnybb8tkLh .marker{fill:#333333;stroke:#333333;}#mermaid-svg-qOj0R7pnybb8tkLh .marker.cross{stroke:#333333;}#mermaid-svg-qOj0R7pnybb8tkLh svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-qOj0R7pnybb8tkLh p{margin:0;}#mermaid-svg-qOj0R7pnybb8tkLh .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-qOj0R7pnybb8tkLh .cluster-label text{fill:#333;}#mermaid-svg-qOj0R7pnybb8tkLh .cluster-label span{color:#333;}#mermaid-svg-qOj0R7pnybb8tkLh .cluster-label span p{background-color:transparent;}#mermaid-svg-qOj0R7pnybb8tkLh .label text,#mermaid-svg-qOj0R7pnybb8tkLh span{fill:#333;color:#333;}#mermaid-svg-qOj0R7pnybb8tkLh .node rect,#mermaid-svg-qOj0R7pnybb8tkLh .node circle,#mermaid-svg-qOj0R7pnybb8tkLh .node ellipse,#mermaid-svg-qOj0R7pnybb8tkLh .node polygon,#mermaid-svg-qOj0R7pnybb8tkLh .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-qOj0R7pnybb8tkLh .rough-node .label text,#mermaid-svg-qOj0R7pnybb8tkLh .node .label text,#mermaid-svg-qOj0R7pnybb8tkLh .image-shape .label,#mermaid-svg-qOj0R7pnybb8tkLh .icon-shape .label{text-anchor:middle;}#mermaid-svg-qOj0R7pnybb8tkLh .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-qOj0R7pnybb8tkLh .rough-node .label,#mermaid-svg-qOj0R7pnybb8tkLh .node .label,#mermaid-svg-qOj0R7pnybb8tkLh .image-shape .label,#mermaid-svg-qOj0R7pnybb8tkLh .icon-shape .label{text-align:center;}#mermaid-svg-qOj0R7pnybb8tkLh .node.clickable{cursor:pointer;}#mermaid-svg-qOj0R7pnybb8tkLh .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-qOj0R7pnybb8tkLh .arrowheadPath{fill:#333333;}#mermaid-svg-qOj0R7pnybb8tkLh .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-qOj0R7pnybb8tkLh .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-qOj0R7pnybb8tkLh .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-qOj0R7pnybb8tkLh .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-qOj0R7pnybb8tkLh .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-qOj0R7pnybb8tkLh .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-qOj0R7pnybb8tkLh .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-qOj0R7pnybb8tkLh .cluster text{fill:#333;}#mermaid-svg-qOj0R7pnybb8tkLh .cluster span{color:#333;}#mermaid-svg-qOj0R7pnybb8tkLh div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-qOj0R7pnybb8tkLh .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-qOj0R7pnybb8tkLh rect.text{fill:none;stroke-width:0;}#mermaid-svg-qOj0R7pnybb8tkLh .icon-shape,#mermaid-svg-qOj0R7pnybb8tkLh .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-qOj0R7pnybb8tkLh .icon-shape p,#mermaid-svg-qOj0R7pnybb8tkLh .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-qOj0R7pnybb8tkLh .icon-shape .label rect,#mermaid-svg-qOj0R7pnybb8tkLh .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-qOj0R7pnybb8tkLh .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-qOj0R7pnybb8tkLh .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-qOj0R7pnybb8tkLh :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} torch.distributed
分布式通信
Gloo
RCCL
(DCU上的NCCL等价实现)
MPI
CPU训练
多机通信
GPU训练
需UCX支持
GPU分布式训练
首选
超算集群
Slurm集成
一句话总结选型:GPU分布式训练用RCCL,CPU用Gloo,超算环境用MPI。
初始化代码长这样:
python
import torch.distributed as dist
dist.init_process_group(
backend='nccl', # 后端:nccl / gloo / mpi
init_method='tcp://...', # 初始化方式:tcp / file / env
world_size=4, # 总进程数
rank=0, # 当前进程序号
timeout=timedelta(minutes=30) # 超时,nccl默认10分钟,gloo默认30分钟
)
4.1 mpirun多节点启动
先用Slurm申请资源(关键参数已脱敏处理):
bash
#!/bin/bash
#SBATCH -p <队列名>
#SBATCH -N 2 # 节点数
#SBATCH --ntasks-per-node=8 # 每节点任务数
#SBATCH --cpus-per-task=16 # 每任务CPU核数
#SBATCH --gres=dcu:8 # 每节点加速卡数
#SBATCH --exclusive # 独占节点保证性能
申请到节点后启动分布式任务:
bash
mpirun --allow-run-as-root \
-np 16 \
-H <node1>:8,<node2>:8 \
-x ROCM_PATH \
-x NCCL_SOCKET_IFNAME=ib0 \
-x NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1 \
-x NCCL_NET_GDR_LEVEL=4 \
-x NCCL_MAX_NCHANNELS=16 \
-x LD_LIBRARY_PATH \
/path/to/your/program -b 4096 -e 1G -f 2 -g 1 -d half
关键参数说明:
| 参数 | 含义 |
|---|---|
-np 16 |
启动16个进程 |
-H node1:8,node2:8 |
每节点8个进程 |
-x |
将环境变量传递给远程进程 |
NCCL_SOCKET_IFNAME=ib0 |
指定InfiniBand网卡 |
NCCL_NET_GDR_LEVEL=4 |
启用GPUDirect RDMA |
⚠️ 坑4 :多节点通信时网卡指定错了一个字符就跑不起来。ib0还是ibp0还是别的,先用ibstat确认实际网卡名。另外NCCL_IB_HCA的端口号格式(mlx5_0:1)在不同机器上可能不同,别直接抄。
5. 调优三板斧:numa绑定、DataLoader、Profiler
5.1 基础优化项
**第一步:CPU拓扑绑定。**用lscpu看NUMA节点分布,然后:
bash
numactl --cpunodebind=<节点ID> --membind=<节点ID> <任务启动命令>
跨NUMA节点访问内存延迟高很多,不绑的话大模型训练会有明显的吞吐波动。
第二步:DataLoader参数调优。
python
DataLoader(
dataset,
batch_size=64,
num_workers=4, # 数据加载线程数
pin_memory=True, # 预分配锁页内存
prefetch_factor=4, # 每个worker预加载batch数
persistent_workers=True # worker常驻,避免反复fork
)
第三步:开启cuDNN benchmark。
python
torch.backends.cudnn.benchmark = True
⚠️ 坑5 :cudnn.benchmark=True在输入shape固定时效果明显,但如果每个batch的输入尺寸都在变(比如NLP里padding长度飘忽),开了反而会拖慢------因为每次都要重新搜索最优算法。
5.2 pdb调试:分布式场景下的救命工具
PyTorch分布式出问题时,print大法经常不够用。pdb(Python Debugger)是最轻量的调试方案:
python
# 方式1:代码内设置断点
import pdb; pdb.set_trace()
# 或直接用Python 3.7+的breakpoint()
breakpoint()
# 方式2:命令行启动
python -m pdb code.py
常用命令速查:
| 分类 | 命令 | 作用 |
|---|---|---|
| 执行控制 | n (next) |
单步执行 |
s (step) |
进入函数内部 | |
c (continue) |
执行到下一个断点 | |
r (return) |
执行到当前函数返回 | |
q (quit) |
退出调试 | |
| 信息查看 | p / pp |
打印变量值 |
whatis var |
查看变量类型 | |
w (where) |
当前调用栈 | |
a (argument) |
打印函数参数 | |
| 断点管理 | b <行号> |
设置断点 |
cl <文件:行号> |
清除断点 | |
l (list) |
显示上下文代码 |
5.3 torch.profiler:找到性能瓶颈的最终手段
#mermaid-svg-Yq1vTtvBVPDZI8tA{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-Yq1vTtvBVPDZI8tA .error-icon{fill:#552222;}#mermaid-svg-Yq1vTtvBVPDZI8tA .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-Yq1vTtvBVPDZI8tA .marker{fill:#333333;stroke:#333333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .marker.cross{stroke:#333333;}#mermaid-svg-Yq1vTtvBVPDZI8tA svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-Yq1vTtvBVPDZI8tA p{margin:0;}#mermaid-svg-Yq1vTtvBVPDZI8tA .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .cluster-label text{fill:#333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .cluster-label span{color:#333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .cluster-label span p{background-color:transparent;}#mermaid-svg-Yq1vTtvBVPDZI8tA .label text,#mermaid-svg-Yq1vTtvBVPDZI8tA span{fill:#333;color:#333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .node rect,#mermaid-svg-Yq1vTtvBVPDZI8tA .node circle,#mermaid-svg-Yq1vTtvBVPDZI8tA .node ellipse,#mermaid-svg-Yq1vTtvBVPDZI8tA .node polygon,#mermaid-svg-Yq1vTtvBVPDZI8tA .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-Yq1vTtvBVPDZI8tA .rough-node .label text,#mermaid-svg-Yq1vTtvBVPDZI8tA .node .label text,#mermaid-svg-Yq1vTtvBVPDZI8tA .image-shape .label,#mermaid-svg-Yq1vTtvBVPDZI8tA .icon-shape .label{text-anchor:middle;}#mermaid-svg-Yq1vTtvBVPDZI8tA .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-Yq1vTtvBVPDZI8tA .rough-node .label,#mermaid-svg-Yq1vTtvBVPDZI8tA .node .label,#mermaid-svg-Yq1vTtvBVPDZI8tA .image-shape .label,#mermaid-svg-Yq1vTtvBVPDZI8tA .icon-shape .label{text-align:center;}#mermaid-svg-Yq1vTtvBVPDZI8tA .node.clickable{cursor:pointer;}#mermaid-svg-Yq1vTtvBVPDZI8tA .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .arrowheadPath{fill:#333333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-Yq1vTtvBVPDZI8tA .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-Yq1vTtvBVPDZI8tA .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-Yq1vTtvBVPDZI8tA .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-Yq1vTtvBVPDZI8tA .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-Yq1vTtvBVPDZI8tA .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-Yq1vTtvBVPDZI8tA .cluster text{fill:#333;}#mermaid-svg-Yq1vTtvBVPDZI8tA .cluster span{color:#333;}#mermaid-svg-Yq1vTtvBVPDZI8tA div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-Yq1vTtvBVPDZI8tA .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-Yq1vTtvBVPDZI8tA rect.text{fill:none;stroke-width:0;}#mermaid-svg-Yq1vTtvBVPDZI8tA .icon-shape,#mermaid-svg-Yq1vTtvBVPDZI8tA .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-Yq1vTtvBVPDZI8tA .icon-shape p,#mermaid-svg-Yq1vTtvBVPDZI8tA .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-Yq1vTtvBVPDZI8tA .icon-shape .label rect,#mermaid-svg-Yq1vTtvBVPDZI8tA .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-Yq1vTtvBVPDZI8tA .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-Yq1vTtvBVPDZI8tA .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-Yq1vTtvBVPDZI8tA :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 开启Profiler
记录CPU/CUDA活动
记录Tensor形状
record_shapes=True
追踪显存分配
profile_memory=True
记录调用栈
with_stack=True
按schedule分步采集
prof.step()
输出统计表
key_averages().table()
导出Chrome Trace
export_chrome_trace()
perfetto.dev可视化
典型用法:
python
import torch.profiler as profiler
from torch.profiler import ProfilerActivity, schedule
prof_schedule = schedule(
wait=5, # 等5步预热
warmup=2, # 预热2步
active=3, # 采集3步
repeat=1 # 重复1轮
)
with profiler.profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True, # 记录每个Tensor形状
schedule=prof_schedule,
profile_memory=True, # 追踪显存分配/释放
with_stack=True # 记录Python调用栈
) as prof:
# 你的训练/推理代码
for step, data in enumerate(dataloader):
output = model(data)
loss.backward()
optimizer.step()
prof.step()
# 输出统计表
print(prof.key_averages().table(
sort_by="self_cuda_time_total",
row_limit=10
))
# 导出Chrome Trace文件
prof.export_chrome_trace("trace.json")
# 把trace.json拖到 https://ui.perfetto.dev/ 可视化分析
输出示例(数据已脱敏):
Name Self CUDA% CUDA total # of Calls
vectorized_elementwise 12.55% xxx us 8
MemcpyDeviceToHost 9.51% xxx us 2
reduce_kernel 8.37% xxx us 2
unrolled_elementwise 6.46% xxx us 4
从输出能快速定位哪些算子最耗时------MemcpyDeviceToHost占比高说明CPU-GPU数据传输是瓶颈,reduce_kernel高说明规约操作需要优化。
6. 总结
本文把在DCU上用PyTorch从头装到调优的关键节点串了一遍:
- 安装:走DAS源,注意DTK版本匹配,别用官方PyTorch
- 环境验证 :
source /opt/dtk/env.sh别忘,torch.cuda.is_available()确认 - torch.compile :一行代码换性能,但
max-autotune第一次编译要有耐心 - 分布式:GPU训练无脑选RCCL,多节点注意网卡名和IB配置
- 调优:NUMA绑定 + DataLoader参数 + profiler定位瓶颈,三步走
完整源码和更多细节可以参考公开源:http://pypi.sourcefind.cn:666/source/packages
声明 :本文基于公开技术资料与个人实践整理,所有具体版本号、内部节点名、性能数据均已做脱敏处理。文中mermaid图均为原创绘制。
(内容由AI生成,仅供参考)