【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
福运常在4 分钟前
股票数据API(19)次新股池数据
java·python·maven
多看书少吃饭8 分钟前
Vue3 + Java + Python 打造企业级大模型知识库(含 SSE 流式对话完整源码)
java·python·状态模式
胖祥11 分钟前
onnx之优化器
人工智能·深度学习
AI服务老曹11 分钟前
源码交付与低代码重构:企业级 AI 视频管理平台的二次开发实战
人工智能·低代码·重构
L-影12 分钟前
下篇:一棵树能长成多少种样子?——AI中决策树的类型与作用,以及它凭什么活了六十年还没过气
人工智能·算法·决策树·ai
jovi_AI电报13 分钟前
你还把 ChatGPT 当白月光,别人已经让它出来上班了
人工智能
Z.风止14 分钟前
Large Model-learning(2)
开发语言·笔记·python·leetcode
蓝天守卫者联盟114 分钟前
玩具喷涂废气治理厂家:行业现状、技术路径与选型指南
大数据·运维·人工智能·python
m0_7381207215 分钟前
我的创作纪念日0328
java·网络·windows·python·web安全·php
智慧化智能化数字化方案21 分钟前
架构进阶——解读企业数字化转型L1-L5数据架构设计方法论及案例【附全文阅读】
人工智能·企业数字化转型·l1-l5数据架构设计方法论