pytorch梯度累积

梯度累加其实是为了变相扩大batch_size,用来解决显存受限问题。

常规训练方式,每次从train_loader读取出一个batch的数据:

python 复制代码
for x,y in train_loader:
	pred = model(x)
	loss = criterion(pred, label)
	# 反向传播
	loss.backward()
	# 根据新的梯度更新网络参数
	optimizer.step()
	# 清空以往梯度,通过下面反向传播重新计算梯度
	optimizer.zero_grad()

pytorch每次forward完都会得到一个用于梯度回传的计算图,pytorch构建的计算图是动态的,其实在每次backward后计算图都会从内存中释放掉,但是梯度不会清空的。所以若不显示的进行optimizer.zero_grad()清空过往梯度这一步操作,backward()的时候就会累加过往梯度。

梯度累加的做法:

python 复制代码
accumulation_steps = 4
for i,(x,y) in enumerate(train_loader):
	pred = model(x)
	loss = criterion(pred, label)
	
	# 相当于对累加后的梯度取平均
	loss = loss/accumulation_steps
	# 反向传播
	loss.backward()

	if (i+1) % accumulation_steps == 0:
		# 根据新的梯度更新网络参数
		optimizer.step()
		# 清空以往梯度,通过下面反向传播重新计算梯度
		optimizer.zero_grad()

代码中设置accumulation_steps = 4,意思就是变相扩大batch_size四倍。因为代码中每隔4次迭代才清空梯度,更新参数。

至于为啥loss = loss/accumulation_steps,因为梯度累加了四次呀,那就要取平均,除以4。那我每次loss取4,其实就相当于最后将累加后的梯度除4咯。同时,因为累计了4个batch,那学习率也应该扩大4倍,让更新的步子跨大点。

看网上的帖子有讨论对BN层是否有影响,因为BN的估算阶段(计算batch内均值、方差)是在forward阶段完成的,那真实的batch_size放大4倍效果肯定是比通过梯度累加放大4倍效果好的,毕竟计算真实的大batch_size内的均值、方差肯定更精确。

还有讨论说通过调低BN参数momentum可以得到更长序列的统计信息,应该意思是能够记忆更久远的统计信息(均值、方差),以逼近真实的扩大batch_size的效果。

参考

pytorch骚操作之梯度累加,变相增大batch size

相关推荐
专注VB编程开发20年2 分钟前
Activex dll创建调用-Python,Node.js, JAVA主流编程语言操作COM对象
java·开发语言·python·node.js·activex dll
moonsheeper2 分钟前
Prompt优化策略
人工智能·机器学习
海岸线科技2 分钟前
离散制造,工单级成本管控的必然
大数据·人工智能·制造
亚控科技3 分钟前
亚控信创SCADA以全栈国产化方案,筑牢航空燃油安全供应生命线
运维·人工智能·安全·kingscada·亚控科技
权泽谦4 分钟前
用大语言模型实现一个离线翻译小程序(无网络也能用)
开发语言·人工智能·语言模型·小程序·php
论文小助手W6856 分钟前
【SAE出版,EI检索】第六届智慧城市工程与公共交通国际学术会议(SCEPT 2026)
人工智能·智慧城市·交通物流
上不如老下不如小1 小时前
2025年第七届全国高校计算机能力挑战赛 决赛 Python组 编程题汇总
开发语言·python
User_芊芊君子2 小时前
AI Ping 深度评测:大模型 API 选型的 “理性决策中枢”,终结经验主义选型时代
人工智能
smile_Iris2 小时前
Day 32 类的定义和方法
开发语言·python
reasonsummer2 小时前
【教学类-89-11】20251209新年篇07——灰色姓名对联(名字做对联,姓氏做横批,福字贴(通义万相AI福字空心字))
python·通义万相