可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
AI松子6661 分钟前
PyTorch-混合精度训练(amp)
人工智能·pytorch·python
MDLZH1 分钟前
Pytorch性能调优简单总结
人工智能·pytorch·python
GIS数据转换器1 小时前
基于GIS的智慧旅游调度指挥平台
运维·人工智能·物联网·无人机·旅游·1024程序员节
沧澜sincerely2 小时前
数据挖掘概述
人工智能·数据挖掘
数数科技的数据干货3 小时前
从爆款到厂牌:解读游戏工业化的业务持续增长道路
运维·数据库·人工智能
amhjdx6 小时前
星巽短剧以科技赋能影视创新,构建全球短剧新生态!
人工智能·科技
听风南巷6 小时前
机器人全身控制WBC理论及零空间原理解析(数学原理解析版)
人工智能·数学建模·机器人
美林数据Tempodata7 小时前
“双新”指引,AI驱动:工业数智应用生产性实践创新
大数据·人工智能·物联网·实践中心建设·金基地建设
电科_银尘7 小时前
【大语言模型】-- 私有化部署
人工智能·语言模型·自然语言处理
翔云 OCR API9 小时前
人工智能驱动下的OCR API技术演进与实践应用
人工智能·ocr