可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
三寸3374 分钟前
硬刚GPT 5.1,Grok 4.1来了,所有用户免费使用!
人工智能·ai·ai编程
苍何8 分钟前
Gemini3 强势来袭,这次前端真的死了。。。
人工智能
悟空CRM服务15 分钟前
我用一条命令部署了完整CRM系统!
java·人工智能·开源·开源软件
组合缺一16 分钟前
Solon AI 开发学习 - 1导引
java·人工智能·学习·ai·openai·solon
A-刘晨阳19 分钟前
《华为数据之道》发行五周年暨《数据空间探索与实践》新书发布会召开,共探AI时代数据治理新路径
人工智能·华为
人工小情绪27 分钟前
大模型运行的基本机制
人工智能
brave and determined32 分钟前
可编程逻辑器件学习(day24):异构计算:突破算力瓶颈的未来之路
人工智能·嵌入式硬件·深度学习·学习·算法·fpga·asic
南山安36 分钟前
让 LLM 与外界对话:使用 Function Calling 实现天气查询工具
人工智能·后端·python
用户51914958484538 分钟前
信号、Shell与Docker:层层嵌套的陷阱剖析
人工智能·aigc
文心快码BaiduComate43 分钟前
Comate Figma2Code智能体升级,畅享Figma2Code不受限
人工智能·程序员·前端框架