多台服务器同时训练llamfactory 大语言模型 国家超算中心 Slurm 是目前全球最主流的开源、高性能计算(HPC)集群资源管理与作业调度系统
完全可以。10 张 64GB Blackwell 跨 5 节点分布式训练 14B+LoRA 不仅可行,而且单卡显存非常宽裕。
1. 显存核算(打消顾虑)
| 项目 | 占用估算 | 说明 |
|---|---|---|
| 14B 基座 (BF16) | ~28 GB | 参数本身 |
| LoRA 可训练参数 (r=128) | ~0.2 GB | 低秩矩阵 |
| 优化器状态 (AdamW) | ~0.6 GB | 只存 LoRA 的 momentum |
| 激活值 (seq=768, bs=4) | ~4-8 GB | 与长度正相关 |
| KV Cache / 碎片 | ~2 GB | 预留 |
| 单卡总计 | < 40 GB | 64GB 卡剩余 24GB+ |
结论 :单卡就能独立放下完整模型+LoRA,不需要 ZeRO-3 切分参数。多节点训练本质是纯数据并行(DDP),通信量极小(只同步 LoRA 梯度的几百 MB)。
2. 网络瓶颈评估(关键前提)
5 台机器训练,节点间网络是唯一可能卡你的地方:
| 网络类型 | 是否可用 | 建议 |
|---|---|---|
| InfiniBand (200Gbps+) | ✅ 最优 | 直接跑,无瓶颈 |
| RDMA RoCE (25Gbps+) | ✅ 可用 | 设置 NCCL 环境变量即可 |
| 普通以太网 (1Gbps) | ⚠️ 能用但慢 | 增大 gradient accumulation,减少通信频次 |
| 跨公网/异地 | ❌ 不建议 | 延迟太高,NCCL 会超时 |
检测命令(每台执行):
bash
# 看网卡速度
ethtool eth0 | grep Speed
# 如果是 25000Mb/s 以上,放心跑;如果是 1000Mb/s,看下方降级方案
3. 方案 A:多节点分布式训练一个任务(推荐)
3.1 训练代码最小改动
你只需要在原有 train.py 上加几行分布式初始化:
python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig
from datasets import load_dataset
def main():
# 分布式环境变量由 torchrun 自动注入
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# 初始化进程组(NCCL 是 NVIDIA 多卡通信后端)
dist.init_process_group(backend="nccl")
global_rank = dist.get_rank()
world_size = dist.get_world_size() # 应该是 10
# 只在主节点打印,避免 10 份日志刷屏
if global_rank == 0:
print(f"Total GPUs: {world_size}, Local rank: {local_rank}")
# 加载模型(跟之前一样)
model = AutoModelForCausalLM.from_pretrained(
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
torch_dtype=torch.bfloat16,
device_map={"": local_rank}, # 关键:指定当前卡
trust_remote_code=True
)
# LoRA
lora_config = LoraConfig(
r=128, lora_alpha=256,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05, bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# 关键:包装为 DDP。注意 device_ids 必须设,否则默认用所有可见卡
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
# 加载数据
dataset = load_dataset("json", data_files="week1_mixed.jsonl", split="train")
# 关键:DistributedSampler 保证每台机器只处理自己那批数据
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
# 注意:这里的 batch_size 是单卡 batch_size!
# 总 global batch = 4 * 10 * grad_accum
dataloader = DataLoader(
dataset,
sampler=sampler,
batch_size=4, # 单卡 batch
num_workers=4,
pin_memory=True
)
# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6, weight_decay=0.1)
# 训练循环
model.train()
for epoch in range(3):
sampler.set_epoch(epoch) # 关键:保证多节点 shuffle 随机性
for step, batch in enumerate(dataloader):
# 数据迁移到当前 GPU
input_ids = batch["input_ids"].to(local_rank)
attention_mask = batch["attention_mask"].to(local_rank)
labels = batch["labels"].to(local_rank)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
# 梯度累积(可选,减少通信频次)
loss = loss / 4 # 假设 grad_accum=4
loss.backward()
if (step + 1) % 4 == 0:
optimizer.step()
optimizer.zero_grad()
if global_rank == 0 and step % 50 == 0:
print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
# 只在主节点保存
if global_rank == 0:
# 保存时要用 model.module(因为 DDP 包装了一层)
model.module.save_pretrained(f"checkpoint_epoch_{epoch}")
print(f"Saved checkpoint at epoch {epoch}")
dist.destroy_process_group()
if __name__ == "__main__":
main()
3.2 一键启动脚本
节点 1(主节点,假设 IP 为 192.168.1.10):
bash
# node1.sh
export NCCL_DEBUG=INFO # 调试用,跑通后可删
export NCCL_IB_DISABLE=0 # 如果有 IB,保持 0;如果纯以太网,设 1
export NCCL_SOCKET_IFNAME=eth0 # 改成你的实际网卡名(ifconfig 看)
torchrun \
--nnodes=5 \
--nproc_per_node=2 \
--rdzv_id=druggpt_job \
--rdzv_backend=c10d \
--rdzv_endpoint=192.168.1.10:29500 \
train.py
节点 2~5(工作节点):
bash
# node2.sh(其他节点唯一区别是 rdzv_endpoint 指向主节点)
export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
torchrun \
--nnodes=5 \
--nproc_per_node=2 \
--rdzv_id=druggpt_job \
--rdzv_backend=c10d \
--rdzv_endpoint=192.168.1.10:29500 \
train.py
执行方式:
bash
# 在 5 台机器上同时执行各自脚本(用 tmux/screen/ssh 并行)
# 或者如果你用 Kubernetes/Slurm,由调度系统统一启动
4. 方案 B:用 Accelerate 库(更省事)
HuggingFace Accelerate 能自动生成多节点命令,不用手写 DDP:
bash
# 1. 在每台机器上生成配置
accelerate config --config_file ds_config.yaml
# 交互式选择:multi-node, 2 GPUs per node, DeepSpeed/DDP
# 2. 主节点启动(Accelerate 会自动通过 rdzv 发现其他节点)
accelerate launch \
--config_file ds_config.yaml \
--main_process_ip 192.168.1.10 \
--main_process_port 29500 \
--machine_rank 0 \
--num_machines 5 \
--num_processes 10 \
train.py
train.py 几乎不用改 ,把 model = DDP(...) 删掉,Accelerate 会自动处理。
5. 方案 C:如果网络很差(降级)
如果节点间只有千兆网(1Gbps),All-Reduce 会成为瓶颈。此时不要分布式训一个任务,而是:
| 机器 | 任务 | 数据 |
|---|---|---|
| 节点 1-2 | 训 Week 1(全长70%+Domain30%) | 数据分片 A |
| 节点 3-4 | 训 Week 2(全长30%+Domain70%) | 数据分片 B |
| 节点 5 | 训 Week 3(Domain50%+Pocket50%) | 数据分片 C |
这不是分布式,是并行跑实验 。每台机器内部 2 张卡用 torchrun --nproc_per_node=2 做本地 DDP。
结果:5 个实验同时跑,3 周后选最优模型。适合网络差、想快速试错的情况。
6. 关键检查清单(跑前必做)
bash
# 1. 确认 5 台机器 NCCL 能互相发现(在主节点执行)
python -c "import torch; print(torch.cuda.device_count())" # 应输出 2
# 2. 确认端口互通
# 在节点 2 上:telnet 192.168.1.10 29500
# 必须能通,否则 torchrun 会卡死
# 3. 确认所有节点 PyTorch/CUDA/NCCL 版本一致
python -c "import torch; print(torch.__version__, torch.version.cuda)"
# 4. 数据可见性
# 所有节点必须能读取 week1_mixed.jsonl(NFS 挂载或每台本地复制一份)
# 5. 防火墙
sudo ufw allow 29500/tcp # Ubuntu 示例
7. 显存与速度预期
| 配置 | 单卡显存占用 | 全局 Batch Size | 预估速度 |
|---|---|---|---|
| 单卡独立 (1x2) | ~35 GB | 4 | 1x |
| 单机双卡 (1x2 DDP) | ~38 GB | 8 | ~1.9x |
| 5机10卡 (5x2 DDP) | ~38 GB | 40 | ~9x |
10 张卡理论上能把你的训练时间压缩到原来的 1/9 。30 万步如果原来要 7 天,现在约 18 小时一个 epoch。
8. 总结
| 你的情况 | 推荐方案 |
|---|---|
| 节点间有 25Gbps+ / IB | 方案 A:5机10卡分布式训一个任务 |
| 节点间只有千兆网 | 方案 C:每台独立跑不同 Week 实验 |
| 不想改代码 | 方案 B:Accelerate 自动分布式 |
现在最该做的:
- 选一台机器,先单机
torchrun --nproc_per_node=2 train.py跑通 - 确认能跑后,5 台同时启动,加
--nnodes=5 - 观察 NCCL 日志,确保
global rank 0-9全部上线
需要我提供 DeepSpeed ZeRO-1 的多节点配置文件 (进一步压榨通信效率),或者 Slurm 集群的 sbatch 提交脚本 吗?
Slurm (全称 Simple Linux Utility for Resource Management )是目前全球最主流的开源、高性能计算(HPC)集群资源管理与作业调度系统,被超算 TOP500 中超过半数的系统采用。
一、核心定位与作用
Slurm 作为集群"操作系统",主要解决三大问题:
- 资源分配:为用户分配计算节点、CPU、GPU、内存、网络等资源(独占/共享)。
- 作业执行:提供框架启动、运行、监控并行/串行作业(如 MPI、AI 训练)。
- 调度仲裁:管理作业队列,按策略调度、解决资源竞争。
二、基本架构(核心组件)
- slurmctld (中央控制器)
- 集群大脑:监控资源、管理队列、分配资源、调度作业。
- 支持主备冗余(高可用)。
- slurmd (节点守护)
- 每个计算节点运行一个:接收指令、启动任务、监控节点状态。
- slurmdbd (记账数据库)
- 集中存储作业历史、资源使用、用户配额、账单数据。
- slurmstepd
- 作业步守护:管理单个作业内任务的 I/O、信号、资源限制。
- 用户命令
sbatch:批处理作业提交srun:交互式/实时运行salloc:申请资源并进入会话sinfo:查看集群/节点状态squeue:查看排队/运行作业scancel:取消作业sacct:查询历史记账
三、关键特性
- 极致可扩展 :单集群可管理数百万 CPU 核心、数万节点。
- 高可用:控制器主备、节点故障隔离、无单点失效。
- 丰富调度策略
- FIFO、公平共享(Fair-share)、回填调度(Backfill)
- 优先级、抢占、预留、GANG 调度(分时)
- 资源精细化管理
- 支持 CPU/GPU/内存/网卡/IB 等绑定与隔离
- 分区(Partition)、QoS、用户/账户配额、时限
- 作业高级特性
- 作业依赖、作业数组、数组任务、信号控制
- 标准输出/错误重定向、实时 attach 作业(
sattach)
- 插件化:调度、记账、认证、网络等均为插件,易扩展。
- 开源免费:GPLv2 协议,社区活跃。
四、典型使用流程(用户视角)
-
用
sinfo看可用资源。 -
写脚本(
job.sh),指定资源:bash#!/bin/bash #SBATCH --job-name=my_train #SBATCH --nodes=2 #SBATCH --ntasks-per-node=8 #SBATCH --gres=gpu:4 #SBATCH --time=12:00:00 python train.py -
提交:
sbatch job.sh -
查看:
squeue、sacct -
取消:
scancel <jobid>
五、应用场景
- 超算中心、高校/科研集群
- AI 训练集群(大模型、深度学习)
- 云计算、私有云资源调度
- 大规模科学计算、数值模拟、生物信息、气象、CAE 等
六、与同类对比
- 比 PBS Pro、LSF 更轻量、开源、生态强。
- 比 Kubernetes 更适合长时、批处理、MPI 并行、GPU 密集的 HPC 场景。
要不要我给你一份可直接复制的 Slurm 常用命令速查表(含提交、查询、监控、取消、资源查看)?
这份是超精简、可直接复制粘贴的 Slurm 常用命令速查表,日常HPC/AI训练完全够用。
Slurm 常用命令速查
一、查看集群与资源
bash
# 查看所有节点/分区状态
sinfo
# 只看空闲节点
sinfo -N -l | grep idle
# 查看 GPU 资源(需要集群配置 gres)
sinfo -o "%N %P %T %G"
# 查看节点详细信息
scontrol show node <节点名>
二、作业提交(最常用)
1. 提交批处理脚本
bash
sbatch job.sh
2. 常用 #SBATCH 参数(写在脚本开头)
bash
#!/bin/bash
#SBATCH --job-name=myjob # 作业名
#SBATCH --partition=your_part # 分区
#SBATCH --nodes=1 # 节点数
#SBATCH --ntasks-per-node=8 # 每节点任务数
#SBATCH --gres=gpu:1 # GPU 数量
#SBATCH --mem=32G # 总内存
#SBATCH --time=12:00:00 # 最大运行时间
#SBATCH --output=%j.out # 输出日志
#SBATCH --error=%j.err # 错误日志
#SBATCH --mail-type=END # 结束时发邮件
3. 直接交互式运行
bash
srun -p 分区 --gres=gpu:1 nvidia-smi
4. 申请资源进入节点
bash
salloc -p 分区 --gres=gpu:1 --time=2:00:00
# 进入后:ssh 节点名
三、查看作业
bash
# 查看自己的作业
squeue
# 查看所有人的作业
squeue -a
# 只看自己
squeue -u $USER
# 查看作业详细信息
scontrol show job <jobid>
# 查看历史作业(结束的作业)
sacct
sacct -j <jobid>
sacct -u $USER --format=jobid,jobname,partition,state,elapsed
四、控制作业
bash
# 取消一个作业
scancel <jobid>
# 取消自己所有作业
scancel -u $USER
# 暂停/恢复作业
scontrol hold <jobid>
scontrol release <jobid>
# 修改作业最大时限
scontrol update jobid=<jobid> timelimit=24:00:00
五、查看账户与配额
bash
# 查看账户信息
sacctmgr show assoc user=$USER
# 查看 QoS 限制
sacctmgr show qos
六、常用状态说明
PD:Pending 排队R:Running 运行中CG:Completing 正在结束F:Failed 失败TO:TimeOut 超时CA:Canceled 被取消
如果你告诉我你用的是CPU 还是 GPU 训练 ,我可以直接给你一个可直接用的通用 sbatch 脚本模板。