LLM训练中batchsize与过拟合和泛化的关系

一般来说batch size越大训练会越稳定,然而实际情况下,结论可能违反直觉。

更大的batch size有可能导致过拟合,而小batch size却有利于提升泛化性能。

这里尝试从多个角度分析这个问题,所有结论参考和收集自网络资料。

1 较小batch size有助于泛化?

这里从梯度估计的角度分析小batch_size和大batch_size对模型训练的影响。

1.1 梯度估计噪声

当batch_size较小时,每次更新,基于少量样本的梯度估计,此时噪声较大。

这种噪声类似于一种隐式的正则化,阻止模型过于精确地拟合训练数据。

其作用类似于给梯度下降添加随机扰动,帮助逃离局部极小值,泛化能力通常更好,原因如下:

1)小batch_size找到的解,往往位于平坦的极小值盆地

2)平坦极小值对参数扰动不敏感,此时泛化能力更强

3)大批量容易收敛到尖锐的极小值,对训练数据过拟合,测试时性能下降

1.2 训练动态差异

小批量即小batch_size,此时在相同epoch下,需要更多更新次数,从而会进行更多的优化探索

大批量即大batch_size,此时梯度估计更准确,但可能过早收敛到次优点。

2 较大batch size可能导致过拟合?

2.1 梯度探索

大批量梯度估计能更快速的接近整个训练集的梯度,但因为训练过程缺乏噪声,容易过拟合训练数据,当然包括噪声和异常。

从梯度探索的角度看,大批量训练路径更直接、更确定,减少了对损失函数的更多探索,可能错过更好的泛化解。

2.2 正则化探索

从正则化的角度看,小批量训练相当于通过随机采样在每个批次上进行了数据增强,即正则化。

大批量减弱了这种随机性带来的正则化效果,虽然模型能很快收敛,但可能收敛到次优点。

3 研究证据支持与实际权衡

3.1 研究支撑

Keskar等人(2017)明确展示大批量训练导致泛化差距,使用大批量训练深度神经网络,虽然能加速训练,但往往会损害模型的最终泛化性能。

另外,Hoffer等人(2017)提出"训练长度假设",认为大批量需要更多迭代。大批量训练导致模型泛化能力下降的根本原因,并非批量大小本身,而是因为其减少了参数更新的次数(即训练长度不足)。只要通过增加训练周期(Epoch)来补足更新次数,就可以消除所谓的"泛化鸿沟"。

通常在实践中,Batch Size = 32/64通常是很好的起点。

3.2 小批量的优势

小批量训练,有更好的泛化性能,更少的内存需求,更频繁的模型更新,对梯度噪声更鲁棒。

但小批量训练也有缺点,由于迭代次数增多,训练速度可能较慢,但每次迭代快较快。

另外,小批量训练时梯度噪声可能使训练过程不稳定,需要仔细调整学习率。

所以,资源受限时,优先小批量(如32、64)。

相反,大批量时GPU利用率高,训练更稳定,梯度方差小,能更快拟合训练集高。

大批量的缺点,就是需要更多显存,同时由于小批量带来的正则化效应,训练后的模型可能泛化能力差,另外也需要调整学习率策略。

如果追求快速迭代原型,可以中等批量(128、256)或更大批量,但需要配合正则化和采用带更多随机扰动的优化器,以及采用多种学习率调整策略,比如线性缩放、余弦退火等调度器、学习率预热,帮助大批量训练逃离尖锐极小值。

3.3 渐进式调整策略

另外,也可以同时采用小批量和大批量,但采用 渐进式调整策略。

1)开始时使用小批量(如32)进行超参数搜索

2)找到最佳配置后,逐步增大批量

3)同步调整学习率(线性缩放)

4)监控验证集性能,必要时增加正则化

3.4 特殊情况与例外

极小的批量(<8)会使BatchNorm统计估计不准确。

此时可考虑GroupNorm或LayerNorm,批量归一化(BatchNorm)。

对比学习、自监督学习,通常需要大批量才能获得好的表示,因为小批量随机扰动太大,导致模型训练不稳定很难收敛。采用大批量是,需要通过动量编码器、大量负样本等技术弥补。

reference


On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima

https://arxiv.org/abs/1609.04836v1

Train longer, generalize better: closing the generalization gap in large batch training of neural networks

https://arxiv.org/abs/1705.08741

The Limit of the Batch Size

https://ar5iv.labs.arxiv.org/html/2006.08517

相关推荐
LJ97951119 小时前
当AI遇上媒体发布:企业传播的下一站
大数据·人工智能
大傻^9 小时前
LangChain4j 核心抽象:ChatMessage、UserMessage 与模型无关设计
人工智能·rag·langchain4j
智算菩萨9 小时前
基于多模态基础模型迈向通用人工智能:BriVL模型深度解析
论文阅读·人工智能·ai·语言模型·agi
小鹿软件办公9 小时前
OpenAI 补齐产品线:GPT-5.4 Mini 与 Nano 正式发布
人工智能·openai
Jordannnnnnnn9 小时前
复试打卡day30
算法
qq_233772719 小时前
元——人工智能
人工智能
大傻^9 小时前
SpringAI 2.0 可观测性体系:AI 操作追踪、指标监控与评估框架
人工智能·springai·指标监控·评估框架
郝学胜-神的一滴9 小时前
贪心策略实战Leetcode 860题:柠檬水找零问题的优雅解法
数据结构·c++·算法·leetcode·职场和发展
GIS数据转换器9 小时前
小龙虾(OpenClaw) 在低空经济领域的应用
大数据·人工智能·无人机·智慧城市·制造
用户69371750013849 小时前
OS级AI Agent:手机操作系统的下一个战场
android·前端·人工智能