可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
不老刘14 分钟前
新一代图像生成工具:Nano Banana Pro 带来更自然的创作体验
人工智能·google·gemini·nano banana pro
袁庭新26 分钟前
人人都能学AI,人人都要学AI
人工智能·aigc
Tzarevich27 分钟前
前端调用大语言模型:基于 Vite 的工程化实践与 HTTP 请求详解
人工智能
Soonyang Zhang37 分钟前
MoeDistributeDispatch算子代码阅读
人工智能·算子·ascendc
sanggou40 分钟前
Windsurf AI IDE 完全使用指南
ide·人工智能
2501_941870562 小时前
人工智能与未来的工作:自动化与人类协作的新时代
大数据·人工智能
Blurpath2 小时前
2025 年用ChatGPT+代理构建AI驱动的智能爬虫
人工智能·爬虫·chatgpt·ip代理·住宅ip·动态住宅代理·轮换ip
啦啦啦在冲冲冲2 小时前
lora矩阵的初始化为啥B矩阵为0呢,为啥不是A呢
深度学习·机器学习·矩阵
极客BIM工作室2 小时前
大模型中的Scaling Law:AI的“增长密码“
人工智能