使用accumulate step节省显卡内存

使用前提:

单卡,模型+batch=1的数据能跑起来

使用accumulate step的意思就是,每次forward较小的batch,如batch=4,每4steps再更新一次参数,训练结果等效于batch=16

先跑一次原先的模型

复制代码
python NLinear_exp_full.py --accu_step 1 --batch 16 
epoch: 0
time comsuming: 1.8598144054412842
training epoch:0:0.0%
time comsuming: 2.137087106704712
training epoch:0:80.64516129032258%
time comsuming: 2.2242424488067627
time comsuming: 2.294013500213623
test epoch:0:0.0%
episode 0 mae 23.900234 rmse 66.41403 smape 0.934281
epoch: 1
time comsuming: 3.2021634578704834
training epoch:1:0.0%
time comsuming: 3.477159261703491
training epoch:1:80.64516129032258%
time comsuming: 3.560976505279541
time comsuming: 3.624363422393799
test epoch:1:0.0%
episode 1 mae 22.137833 rmse 64.748055 smape 0.79881644
epoch: 2
time comsuming: 3.982663869857788
training epoch:2:0.0%
time comsuming: 4.26115345954895
training epoch:2:80.64516129032258%
time comsuming: 4.350359678268433
time comsuming: 4.427008628845215
test epoch:2:0.0%
episode 2 mae 21.542023 rmse 64.10915 smape 0.68798375
epoch: 3
time comsuming: 4.786099910736084
training epoch:3:0.0%
time comsuming: 5.036171913146973
training epoch:3:80.64516129032258%
time comsuming: 5.121201038360596
time comsuming: 5.197283744812012
test epoch:3:0.0%
episode 3 mae 21.322206 rmse 64.079384 smape 0.6753313
epoch: 4
time comsuming: 5.5672008991241455
training epoch:4:0.0%
time comsuming: 5.830775260925293
training epoch:4:80.64516129032258%
time comsuming: 5.919378757476807
time comsuming: 5.9778666496276855

再跑一次batch设置为4,且accumulate step为4的情况

复制代码
python NLinear_exp_full.py --accu_step 4 --batch 4 
time comsuming: 1.9860742092132568
training epoch:0:0.0%
time comsuming: 2.221600294113159
training epoch:0:20.161290322580644%
time comsuming: 2.453077554702759
training epoch:0:40.32258064516129%
time comsuming: 2.675966262817383
training epoch:0:60.483870967741936%
time comsuming: 2.832383394241333
training epoch:0:80.64516129032258%
time comsuming: 3.0732641220092773
time comsuming: 3.1844491958618164
test epoch:0:0.0%
time comsuming: 3.4134249687194824
test epoch:0:72.99270072992701%
episode 0 mae 23.900234 rmse 66.41403 smape 0.934281
epoch: 1
time comsuming: 4.225269079208374
training epoch:1:0.0%
time comsuming: 4.442946434020996
training epoch:1:20.161290322580644%
time comsuming: 4.611685752868652
training epoch:1:40.32258064516129%
time comsuming: 4.845811367034912
training epoch:1:60.483870967741936%
time comsuming: 5.074229001998901
training epoch:1:80.64516129032258%
time comsuming: 5.326176166534424
time comsuming: 5.397624492645264
test epoch:1:0.0%
time comsuming: 5.633365869522095
test epoch:1:72.99270072992701%
episode 1 mae 22.137833 rmse 64.748055 smape 0.79881644
epoch: 2
time comsuming: 5.991377592086792
training epoch:2:0.0%
time comsuming: 6.217101097106934
training epoch:2:20.161290322580644%
time comsuming: 6.363693714141846
training epoch:2:40.32258064516129%
time comsuming: 6.590087175369263
training epoch:2:60.483870967741936%
time comsuming: 6.823684215545654
training epoch:2:80.64516129032258%
time comsuming: 7.081570625305176
time comsuming: 7.148298978805542
test epoch:2:0.0%
time comsuming: 7.377046823501587
test epoch:2:72.99270072992701%
episode 2 mae 21.542023 rmse 64.10915 smape 0.68798375
epoch: 3
time comsuming: 7.766062021255493
training epoch:3:0.0%
time comsuming: 7.996231317520142
training epoch:3:20.161290322580644%
time comsuming: 8.161593675613403
training epoch:3:40.32258064516129%
time comsuming: 8.388957738876343
training epoch:3:60.483870967741936%
time comsuming: 8.618509769439697
training epoch:3:80.64516129032258%
time comsuming: 8.876739978790283
time comsuming: 8.95041275024414
test epoch:3:0.0%
time comsuming: 9.18027663230896

显存占比: 514MB VS 494MB

相关推荐
小彭律师1 分钟前
电动汽车充电设施可调能力聚合评估与预测
人工智能·深度学习·机器学习
_waylau5 分钟前
【HarmonyOS NEXT+AI】问答05:ArkTS和仓颉编程语言怎么选?
人工智能·华为·harmonyos·arkts·鸿蒙·仓颉
老实人y9 分钟前
TIME - MoE 模型代码 3.2——Time-MoE-main/time_moe/datasets/time_moe_dataset.py
人工智能·python·机器学习·icl·icp
极客智谷13 分钟前
Spring AI 系列——使用大模型对文本内容分类归纳并标签化输出
人工智能·spring·分类
夏子曦24 分钟前
AI——认知建模工具:ACT-R
人工智能·机器学习·ai
studyer_domi37 分钟前
Matlab 基于Hough变换的人眼虹膜定位方法
人工智能·计算机视觉
qq_436962181 小时前
AI数据分析中的伪需求场景:现状、挑战与突破路径
人工智能·数据挖掘·数据分析·ai数据分析
豌豆花下猫1 小时前
Python 3.14 新特性盘点,更新了些什么?
后端·python·ai
flying_13141 小时前
面试常问系列(一)-神经网络参数初始化-之-softmax
深度学习·神经网络·算法·机器学习·面试
Python私教1 小时前
Python函数:从基础到进阶的完整指南
java·服务器·python