可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
云卓SKYDROID几秒前
无人机电压模块技术剖析
人工智能·无人机·电压·高科技·云卓科技
Codebee7 分钟前
使用Qoder 改造前端UI/UE升级改造实践:从传统界面到现代化体验的华丽蜕变
前端·人工智能
用户51914958484511 分钟前
Apache服务器自动化运维与安全加固脚本详解
人工智能·aigc
yintele17 分钟前
智能AI汽车电子行业,EMS应用相关问题
人工智能·汽车
却道天凉_好个秋24 分钟前
深度学习(四):数据集划分
人工智能·深度学习·数据集
数字冰雹28 分钟前
“图观”端渲染场景编辑器
人工智能·编辑器
里昆28 分钟前
【AI】Tensorflow在jupyterlab中运行要注意的问题
人工智能·python·tensorflow
荼蘼1 小时前
OpenCV 高阶 图像金字塔 用法解析及案例实现
人工智能·opencv·计算机视觉
Clownseven1 小时前
2025云计算趋势:Serverless与AI大模型如何赋能中小企业
人工智能·serverless·云计算
2401_828890641 小时前
使用 BERT 实现意图理解和实体识别
人工智能·python·自然语言处理·bert·transformer