【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
冬奇Lab4 小时前
Workflow 系列(04):Multi-Agent 协调——编排器边界、并发控制与上下文隔离
人工智能·工作流引擎
冬奇Lab4 小时前
每日一个开源项目(第147篇):HyperGraphRAG - 用超图表示 N 元关系,RAG 的第三代范式
人工智能·开源·graphql
甲维斯5 小时前
Github + 阿里云oss实现类似codex的自动更新!
人工智能
阿里云大数据AI技术7 小时前
光轮智能 × 阿里云:共建 Physical AI 云上数据、评测与持续学习基础设施
人工智能·机器学习
机器之心7 小时前
实锤了:Claude Code偷查用户,时区、中国AI实验室全是关键词
人工智能·openai
网易云信7 小时前
Cursor点燃个人开发者,企业级AI为何频频受挫?Agent工厂从提效工具到AI员工的跃迁
人工智能·开源
网易云信7 小时前
解锁触手可及的温暖:网易智企 x Wander Puffs AI 云游泡芙
人工智能
转转技术团队7 小时前
从 PRD 到可验证代码:AI 需求开发闭环实践
人工智能