可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
shayudiandian7 小时前
YOLOv8目标检测项目实战(从训练到部署)
人工智能·yolo·目标检测
陈天伟教授7 小时前
基于学习的人工智能(4)机器学习基本框架
人工智能·学习·机器学习
studytosky7 小时前
深度学习理论与实战:MNIST 手写数字分类实战
人工智能·pytorch·python·深度学习·机器学习·分类·matplotlib
做萤石二次开发的哈哈7 小时前
11月27日直播预告 | 萤石智慧台球厅创新场景化方案分享
大数据·人工智能
AGI前沿7 小时前
AdamW的继任者?AdamHD让LLM训练提速15%,性能提升4.7%,显存再省30%
人工智能·算法·语言模型·aigc
后端小肥肠8 小时前
小佛陀漫画怎么做?深扒中老年高互动赛道,用n8n流水线批量打造
人工智能·aigc·agent
是店小二呀8 小时前
本地绘图工具也能远程协作?Excalidraw+cpolar解决团队跨网画图难题
人工智能
i爱校对8 小时前
爱校对团队服务全新升级
人工智能
KL132881526938 小时前
AI 介绍的东西大概率是不会错的,包括这款酷铂达 VGS耳机
人工智能
vigel19908 小时前
人工智能的7大应用领域
人工智能