可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
MidJourney中文版7 分钟前
深度报告:中老年AI陪伴机器人需求分析
人工智能·机器人
William.csj19 分钟前
Pytorch/CUDA——flash-attn 库编译的 gcc 版本问题
pytorch·cuda
王上上34 分钟前
【论文阅读41】-LSTM-PINN预测人口
论文阅读·人工智能·lstm
智慧化智能化数字化方案1 小时前
69页全面预算管理体系的框架与落地【附全文阅读】
大数据·人工智能·全面预算管理·智慧财务·智慧预算
PyAIExplorer1 小时前
图像旋转:从原理到 OpenCV 实践
人工智能·opencv·计算机视觉
Wilber的技术分享1 小时前
【机器学习实战笔记 14】集成学习:XGBoost算法(一) 原理简介与快速应用
人工智能·笔记·算法·随机森林·机器学习·集成学习·xgboost
19891 小时前
【零基础学AI】第26讲:循环神经网络(RNN)与LSTM - 文本生成
人工智能·python·rnn·神经网络·机器学习·tensorflow·lstm
burg_xun1 小时前
【Vibe Coding 实战】我如何用 AI 把一张草图变成了能跑的应用
人工智能
酌沧2 小时前
AI做美观PPT:3步流程+工具测评+避坑指南
人工智能·powerpoint
狂师2 小时前
啥是AI Agent!2025年值得推荐入坑AI Agent的五大工具框架!(新手科普篇)
人工智能·后端·程序员