可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
反向跟单策略4 分钟前
期货反向跟单-2025年回顾及2026年展望
大数据·人工智能·学习·数据分析·区块链
yunhuibin25 分钟前
GoogLeNet学习
人工智能·python·深度学习·神经网络·学习
luoganttcc41 分钟前
Taalas 将人工智能模型蚀刻到晶体管上,以提升推理能力
人工智能·fpga开发
冬奇Lab1 小时前
一天一个开源项目(第33篇):MyCodeAgent - 面向学习的 Claude Code 风格代码代理框架
人工智能·开源·资讯
deephub1 小时前
RAG 中分块重叠的 8 个隐性开销与权衡策略
人工智能·大语言模型·rag·检索
Ethan Hunt丶1 小时前
MSVTNet: 基于多尺度视觉Transformer的运动想象EEG分类模型
人工智能·深度学习·算法·transformer·脑机接口
康康的AI博客2 小时前
智能情感分析与品牌策略优化:如何通过AI洞察提升企业市场响应力
大数据·数据库·人工智能
亚古数据2 小时前
法国公司的类型:探索法国企业的多样形态
大数据·人工智能·亚古数据·法国公司
星爷AG I2 小时前
12-11 印象加工(AGI基础理论)
人工智能·agi
gs801402 小时前
赋予 AI 大模型“联网”超能力:Serper (Google Search API) 深度解析与实战
人工智能