LLM训练中batchsize与过拟合和泛化的关系

一般来说batch size越大训练会越稳定,然而实际情况下,结论可能违反直觉。

更大的batch size有可能导致过拟合,而小batch size却有利于提升泛化性能。

这里尝试从多个角度分析这个问题,所有结论参考和收集自网络资料。

1 较小batch size有助于泛化?

这里从梯度估计的角度分析小batch_size和大batch_size对模型训练的影响。

1.1 梯度估计噪声

当batch_size较小时,每次更新,基于少量样本的梯度估计,此时噪声较大。

这种噪声类似于一种隐式的正则化,阻止模型过于精确地拟合训练数据。

其作用类似于给梯度下降添加随机扰动,帮助逃离局部极小值,泛化能力通常更好,原因如下:

1)小batch_size找到的解,往往位于平坦的极小值盆地

2)平坦极小值对参数扰动不敏感,此时泛化能力更强

3)大批量容易收敛到尖锐的极小值,对训练数据过拟合,测试时性能下降

1.2 训练动态差异

小批量即小batch_size,此时在相同epoch下,需要更多更新次数,从而会进行更多的优化探索

大批量即大batch_size,此时梯度估计更准确,但可能过早收敛到次优点。

2 较大batch size可能导致过拟合?

2.1 梯度探索

大批量梯度估计能更快速的接近整个训练集的梯度,但因为训练过程缺乏噪声,容易过拟合训练数据,当然包括噪声和异常。

从梯度探索的角度看,大批量训练路径更直接、更确定,减少了对损失函数的更多探索,可能错过更好的泛化解。

2.2 正则化探索

从正则化的角度看,小批量训练相当于通过随机采样在每个批次上进行了数据增强,即正则化。

大批量减弱了这种随机性带来的正则化效果,虽然模型能很快收敛,但可能收敛到次优点。

3 研究证据支持与实际权衡

3.1 研究支撑

Keskar等人(2017)明确展示大批量训练导致泛化差距,使用大批量训练深度神经网络,虽然能加速训练,但往往会损害模型的最终泛化性能。

另外,Hoffer等人(2017)提出"训练长度假设",认为大批量需要更多迭代。大批量训练导致模型泛化能力下降的根本原因,并非批量大小本身,而是因为其减少了参数更新的次数(即训练长度不足)。只要通过增加训练周期(Epoch)来补足更新次数,就可以消除所谓的"泛化鸿沟"。

通常在实践中,Batch Size = 32/64通常是很好的起点。

3.2 小批量的优势

小批量训练,有更好的泛化性能,更少的内存需求,更频繁的模型更新,对梯度噪声更鲁棒。

但小批量训练也有缺点,由于迭代次数增多,训练速度可能较慢,但每次迭代快较快。

另外,小批量训练时梯度噪声可能使训练过程不稳定,需要仔细调整学习率。

所以,资源受限时,优先小批量(如32、64)。

相反,大批量时GPU利用率高,训练更稳定,梯度方差小,能更快拟合训练集高。

大批量的缺点,就是需要更多显存,同时由于小批量带来的正则化效应,训练后的模型可能泛化能力差,另外也需要调整学习率策略。

如果追求快速迭代原型,可以中等批量(128、256)或更大批量,但需要配合正则化和采用带更多随机扰动的优化器,以及采用多种学习率调整策略,比如线性缩放、余弦退火等调度器、学习率预热,帮助大批量训练逃离尖锐极小值。

3.3 渐进式调整策略

另外,也可以同时采用小批量和大批量,但采用 渐进式调整策略。

1)开始时使用小批量(如32)进行超参数搜索

2)找到最佳配置后,逐步增大批量

3)同步调整学习率(线性缩放)

4)监控验证集性能,必要时增加正则化

3.4 特殊情况与例外

极小的批量(<8)会使BatchNorm统计估计不准确。

此时可考虑GroupNorm或LayerNorm,批量归一化(BatchNorm)。

对比学习、自监督学习,通常需要大批量才能获得好的表示,因为小批量随机扰动太大,导致模型训练不稳定很难收敛。采用大批量是,需要通过动量编码器、大量负样本等技术弥补。

reference


On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima

https://arxiv.org/abs/1609.04836v1

Train longer, generalize better: closing the generalization gap in large batch training of neural networks

https://arxiv.org/abs/1705.08741

The Limit of the Batch Size

https://ar5iv.labs.arxiv.org/html/2006.08517

相关推荐
ccLianLian2 小时前
Segment Anything Model
人工智能·深度学习·计算机视觉
week_泽2 小时前
第10课:从零构建生产级AI Agent服务技术方案 - 学习笔记_10
人工智能·笔记·学习·ai agent
lynnlovemin2 小时前
AI时代信息安全:从挑战突围到智能防御体系构建
人工智能·信息安全
西柚小萌新2 小时前
【计算机视觉CV:标注工具】--labelimg+labelme
人工智能·计算机视觉
躺平的赶海人2 小时前
PyTorch 安装指南:快速开启深度学习之旅
人工智能·pytorch·深度学习
muddjsv2 小时前
什么是算法?——现代视角下的一次凝视
算法
IT_陈寒2 小时前
Vue3性能优化实战:5个被低估的API让我减少了40%的代码量
前端·人工智能·后端
laplace01232 小时前
智能体经典范式构建
算法·langchain·大模型·agent
小雨下雨的雨2 小时前
Flutter鸿蒙共赢——色彩的流变:流体梯度网格与现代视觉重构
算法·flutter·华为·重构·交互·harmonyos·鸿蒙