可插拔训练加速trick-Scaling PyTorch Model Training With Minimal Code Changes

依赖:

shell 复制代码
pip install lightning

插拔改动:

python 复制代码
from lightning.fabric import Fabric

#...

# 实例化
fabric = Fabric(accelerator='cuda')  
# 混精度用这个,加速明显
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()

#...

# 插拔接入
model, optimizer = fabric.setup(model, optimizer) 
train_dataloader = fabric.setup_dataloaders(train_dataloader)

#...

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            model.train()  
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            fabric.backward(loss)  # 插拔接入,原反向传播:loss.backward()
            optimizer.step()
            #...

参考文献

CVPR 2023 Talk:Scaling PyTorch Model Training With Minimal Code Changes

相关推荐
CCPC不拿奖不改名3 分钟前
循环神经网络RNN:整数索引→稠密向量(嵌入层 / Embedding)详解
人工智能·python·rnn·深度学习·神经网络·自然语言处理·embedding
学好statistics和DS14 分钟前
感知机的对偶形式是怎么来的
深度学习·神经网络·机器学习
石去皿19 分钟前
大模型面试常见问答
人工智能·面试·职场和发展
Java后端的Ai之路34 分钟前
【AI大模型开发】-RAG 技术详解
人工智能·rag
墨香幽梦客34 分钟前
家具ERP口碑榜单,物料配套专用工具推荐
大数据·人工智能
Coder_Boy_42 分钟前
基于SpringAI的在线考试系统-考试系统DDD(领域驱动设计)实现步骤详解
java·数据库·人工智能·spring boot
敏叔V5871 小时前
从人类反馈到直接偏好优化:AI对齐技术的实战演进
人工智能
琅琊榜首20201 小时前
AI赋能短剧创作:从Prompt设计到API落地的全技术指南
人工智能·prompt
测试者家园1 小时前
Prompt、Agent、测试智能体:测试的新机会,还是新焦虑?
人工智能·prompt·智能体·职业和发展·质量效能·智能化测试·软件开发和测试
嗷嗷哦润橘_1 小时前
从萝卜纸巾猫到桌游:“蒸蚌大开门”的设计平衡之旅
人工智能·算法·游戏·概率论·桌游