可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
AI_小站3 小时前
6个GitHub爆火的免费大模型教程,助你快速进阶AI编程
人工智能·langchain·github·知识图谱·agent·llama·rag
xindoo3 小时前
GitHub Trending霸榜!深度解析AI Coding辅助神器 Superpowers
人工智能·github
时间之里3 小时前
【深度学习】:RF-DETR与yolo对比
人工智能·深度学习·yolo
北京阿法龙科技有限公司3 小时前
数智化升级:AR 智能眼镜驱动工业运维效能革新
人工智能
风落无尘3 小时前
《智能重生:从垃圾堆到AI工程师》——第二章 概率与生存
大数据·人工智能
j_xxx404_3 小时前
Linux:静态链接与动态链接深度解析
linux·运维·服务器·c++·人工智能
收获不止数据库3 小时前
达梦9发布会归来:AI 时代,我们需要一款什么样的数据库?
数据库·人工智能·ai·语言模型·数据分析
hhb_6184 小时前
AI全栈编程生存指南
人工智能
AI-Frontiers4 小时前
transformer进阶之路:#2 工作原理详解
人工智能·深度学习·transformer
科研前沿4 小时前
2026 数字孪生前沿科技:全景迭代报告 —— 镜像视界生成式孪生(Generative DT)技术白皮书
大数据·人工智能·科技·算法·音视频·空间计算