可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
min1811234561 分钟前
PC端零基础跨职能流程图制作教程
大数据·人工智能·信息可视化·架构·流程图
愚公搬代码14 分钟前
【愚公系列】《AI+直播营销》015-直播的选品策略(设计直播产品矩阵)
人工智能·线性代数·矩阵
静听松涛13319 分钟前
中文PC端多人协作泳道图制作平台
大数据·论文阅读·人工智能·搜索引擎·架构·流程图·软件工程
学历真的很重要39 分钟前
LangChain V1.0 Context Engineering(上下文工程)详细指南
人工智能·后端·学习·语言模型·面试·职场和发展·langchain
IT=>小脑虎40 分钟前
Python零基础衔接进阶知识点【详解版】
开发语言·人工智能·python
UnderTurrets1 小时前
A_Survey_on_3D_object_Affordance
pytorch·深度学习·计算机视觉·3d
koo3641 小时前
pytorch深度学习笔记13
pytorch·笔记·深度学习
黄焖鸡能干四碗1 小时前
智能制造工业大数据应用及探索方案(PPT文件)
大数据·运维·人工智能·制造·需求分析
高洁011 小时前
CLIP 的双编码器架构是如何优化图文关联的?(3)
深度学习·算法·机器学习·transformer·知识图谱
世岩清上1 小时前
乡村振兴主题展厅本土化材料运用与地域文化施工表达
大数据·人工智能·乡村振兴·展厅