可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
冬奇Lab12 小时前
每日一个开源项目(第140篇):AgentScope 2.0 - 阿里开源的生产级 Agent 框架
人工智能·开源·agent
冬奇Lab12 小时前
Skill 系列(04):Skill 指标体系——L1/L2/L3 三层监控,让质量下降有据可查
人工智能·开源·llm
IT_陈寒13 小时前
Vite的静态资源打包让我熬夜到三点,这坑千万别跳
前端·人工智能·后端
玩转AI不是事14 小时前
用IndexedDB做AI对话离线缓存实战
人工智能
Asize15 小时前
多模态生图:从 Vite 工程化到前端调用 Qwen Image
javascript·人工智能·后端
MobotStone15 小时前
AI项目越多,为什么越容易失控
人工智能·aigc
十有八七15 小时前
AI时代的置身X内
前端·人工智能
Lkstar15 小时前
A2A协议深度解析|Agent2Agent通信标准,智能体互联网的"HTTP"
人工智能·llm
百度Geek说15 小时前
当代码越来越便宜,什么在变贵?
人工智能
橘子星15 小时前
LLM 无状态架构实践:从原理到代码落地
前端·javascript·人工智能