LLM训练中batchsize与过拟合和泛化的关系

一般来说batch size越大训练会越稳定,然而实际情况下,结论可能违反直觉。

更大的batch size有可能导致过拟合,而小batch size却有利于提升泛化性能。

这里尝试从多个角度分析这个问题,所有结论参考和收集自网络资料。

1 较小batch size有助于泛化?

这里从梯度估计的角度分析小batch_size和大batch_size对模型训练的影响。

1.1 梯度估计噪声

当batch_size较小时,每次更新,基于少量样本的梯度估计,此时噪声较大。

这种噪声类似于一种隐式的正则化,阻止模型过于精确地拟合训练数据。

其作用类似于给梯度下降添加随机扰动,帮助逃离局部极小值,泛化能力通常更好,原因如下:

1)小batch_size找到的解,往往位于平坦的极小值盆地

2)平坦极小值对参数扰动不敏感,此时泛化能力更强

3)大批量容易收敛到尖锐的极小值,对训练数据过拟合,测试时性能下降

1.2 训练动态差异

小批量即小batch_size,此时在相同epoch下,需要更多更新次数,从而会进行更多的优化探索

大批量即大batch_size,此时梯度估计更准确,但可能过早收敛到次优点。

2 较大batch size可能导致过拟合?

2.1 梯度探索

大批量梯度估计能更快速的接近整个训练集的梯度,但因为训练过程缺乏噪声,容易过拟合训练数据,当然包括噪声和异常。

从梯度探索的角度看,大批量训练路径更直接、更确定,减少了对损失函数的更多探索,可能错过更好的泛化解。

2.2 正则化探索

从正则化的角度看,小批量训练相当于通过随机采样在每个批次上进行了数据增强,即正则化。

大批量减弱了这种随机性带来的正则化效果,虽然模型能很快收敛,但可能收敛到次优点。

3 研究证据支持与实际权衡

3.1 研究支撑

Keskar等人(2017)明确展示大批量训练导致泛化差距,使用大批量训练深度神经网络,虽然能加速训练,但往往会损害模型的最终泛化性能。

另外,Hoffer等人(2017)提出"训练长度假设",认为大批量需要更多迭代。大批量训练导致模型泛化能力下降的根本原因,并非批量大小本身,而是因为其减少了参数更新的次数(即训练长度不足)。只要通过增加训练周期(Epoch)来补足更新次数,就可以消除所谓的"泛化鸿沟"。

通常在实践中,Batch Size = 32/64通常是很好的起点。

3.2 小批量的优势

小批量训练,有更好的泛化性能,更少的内存需求,更频繁的模型更新,对梯度噪声更鲁棒。

但小批量训练也有缺点,由于迭代次数增多,训练速度可能较慢,但每次迭代快较快。

另外,小批量训练时梯度噪声可能使训练过程不稳定,需要仔细调整学习率。

所以,资源受限时,优先小批量(如32、64)。

相反,大批量时GPU利用率高,训练更稳定,梯度方差小,能更快拟合训练集高。

大批量的缺点,就是需要更多显存,同时由于小批量带来的正则化效应,训练后的模型可能泛化能力差,另外也需要调整学习率策略。

如果追求快速迭代原型,可以中等批量(128、256)或更大批量,但需要配合正则化和采用带更多随机扰动的优化器,以及采用多种学习率调整策略,比如线性缩放、余弦退火等调度器、学习率预热,帮助大批量训练逃离尖锐极小值。

3.3 渐进式调整策略

另外,也可以同时采用小批量和大批量,但采用 渐进式调整策略。

1)开始时使用小批量(如32)进行超参数搜索

2)找到最佳配置后,逐步增大批量

3)同步调整学习率(线性缩放)

4)监控验证集性能,必要时增加正则化

3.4 特殊情况与例外

极小的批量(<8)会使BatchNorm统计估计不准确。

此时可考虑GroupNorm或LayerNorm,批量归一化(BatchNorm)。

对比学习、自监督学习,通常需要大批量才能获得好的表示,因为小批量随机扰动太大,导致模型训练不稳定很难收敛。采用大批量是,需要通过动量编码器、大量负样本等技术弥补。

reference


On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima

https://arxiv.org/abs/1609.04836v1

Train longer, generalize better: closing the generalization gap in large batch training of neural networks

https://arxiv.org/abs/1705.08741

The Limit of the Batch Size

https://ar5iv.labs.arxiv.org/html/2006.08517

相关推荐
Lsx_5 分钟前
前端视角下认识 AI Agent 和 LangChain
前端·人工智能·agent
aiguangyuan6 分钟前
使用LSTM进行情感分类:原理与实现剖析
人工智能·python·nlp
季明洵7 分钟前
C语言实现单链表
c语言·开发语言·数据结构·算法·链表
shandianchengzi12 分钟前
【小白向】错位排列|图文解释公考常见题目错位排列的递推式Dn=(n-1)(Dn-2+Dn-1)推导方式
笔记·算法·公考·递推·排列·考公
I_LPL12 分钟前
day26 代码随想录算法训练营 回溯专题5
算法·回溯·hot100·求职面试·n皇后·解数独
Yeats_Liao13 分钟前
评估体系构建:基于自动化指标与人工打分的双重验证
运维·人工智能·深度学习·算法·机器学习·自动化
cpp_250117 分钟前
P9586 「MXOI Round 2」游戏
数据结构·c++·算法·题解·洛谷
深圳市恒星物联科技有限公司18 分钟前
水质流量监测仪:复合指标监测的管网智能感知设备
大数据·网络·人工智能
浅念-21 分钟前
C语言编译与链接全流程:从源码到可执行程序的幕后之旅
c语言·开发语言·数据结构·经验分享·笔记·学习·算法
断眉的派大星30 分钟前
均值为0,方差为1:数据的“标准校服”
人工智能·机器学习·均值算法