可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
格林威1 分钟前
工业相机图像高速存储(C#版):内存映射文件方法,附堡盟相机C#实战代码!
开发语言·人工智能·数码相机·计算机视觉·c#·工业相机·堡盟相机
人工智能训练5 分钟前
Qwen3.5 开源全解析:从 0.8B 到 397B,代际升级 + 全场景选型指南
linux·运维·服务器·人工智能·开源·ai编程
南滑散修5 分钟前
机器学习(一)-数学基础
人工智能·机器学习
Lab_AI6 分钟前
iLabPower LES与SDH科学数据基因组平台赋能光电材料研发与生产,鼎材科技与创腾科技进一步深化合作
大数据·人工智能·oled·材料设计·光电材料研发·材料创新·材料研发
prince_zxill7 分钟前
Raspberry Pi边缘AI:运行轻量级机器学习模型
人工智能·机器学习
放下华子我只抽RuiKe57 分钟前
机器学习全景指南-基石篇——预测连续值的线性回归
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理·线性回归
前端技术9 分钟前
【鸿蒙实战】从零打造智能物联网家居控制系统:HarmonyOS Next分布式能力的完美诠释
java·前端·人工智能·分布式·物联网·前端框架·harmonyos
草莓熊Lotso9 分钟前
MySQL 数据库基础入门:从概念到实战
linux·运维·服务器·数据库·c++·人工智能·mysql
Thomas.Sir11 分钟前
DeepSeek:开源AI的破局者
人工智能·gpt-4·deepseek
小圣贤君12 分钟前
从「选中一段」到「整章润色」:编辑器里的 AI 润色是怎么做出来的
人工智能·electron·编辑器·vue3·ai写作·deepseek·写小说