可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
短视频矩阵源码定制3 分钟前
矩阵系统源头厂家
大数据·人工智能·矩阵
老赵聊算法、大模型备案4 分钟前
《人工智能拟人化互动服务管理暂行办法(征求意见稿)》深度解读:AI“拟人”时代迎来首个专项监管框架
人工智能·算法·安全·aigc
亚马逊云开发者14 分钟前
使用 Kiro AI IDE 开发 Amazon CDK 部署架构:从模糊需求到三层堆栈的协作实战
人工智能
心无旁骛~15 分钟前
ModelEngine Nexent 智能体从创建到部署全流程深度体验:自动化利器让 AI 开发效率拉满!
运维·人工智能·自动化
老徐电商数据笔记19 分钟前
数据仓库工程师在AI时代的走向探究
数据仓库·人工智能
小鸡吃米…24 分钟前
机器学习——生命周期
人工智能·python·机器学习
hzp66627 分钟前
GhostCache 的新型缓存侧信道攻击
人工智能·黑客·网络攻击·ghostcache
mubei-12328 分钟前
TF-IDF / BM25:经典的传统信息检索算法
人工智能·检索算法
databook33 分钟前
回归分析全家桶(16种回归模型实现方式总结)
人工智能·python·机器学习
天竺鼠不该去劝架34 分钟前
传统财务管理瓶颈:财务机器人如何提升效率
大数据·数据库·人工智能