`transformers` 的 `per_device_train_batch_size` 不支持小于 1 的浮点数值,llamafactory 支持

transformersper_device_train_batch_size 不支持小于 1 的浮点数值(比如 0.5)。这个参数必须是正整数int 类型),代表每个设备上的样本数量。

但你可以通过以下方式实现等效于 batch_size < 1 的效果:

方案 1:梯度累积(Gradient Accumulation)⭐推荐

这是标准做法,通过 gradient_accumulation_steps 来模拟小 batch:

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    per_device_train_batch_size=1,           # 最小为1
    gradient_accumulation_steps=4,           # 每4步才更新一次参数
    # 等效 batch_size = 1 * 4 = 4(如果是多卡,再乘以卡数)
)

如果你想要等效 batch_size = 0.5,可以反过来:

  • per_device_train_batch_size=1
  • 每 2 个 step 处理一个样本(需要自定义训练循环)

方案 2:使用 DataLoaderbatch_sampler

如果你确实需要物理上每个 batch 只有 0.5 个样本(比如一条数据拆两次 forward),需要自定义 Trainer 或重写 compute_loss

python 复制代码
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # 把 input_ids 拆成两半,分别 forward
        input_ids = inputs["input_ids"]
        half_len = input_ids.shape[1] // 2
        
        # 前半部分
        outputs1 = model(input_ids=input_ids[:, :half_len])
        # 后半部分  
        outputs2 = model(input_ids=input_ids[:, half_len:])
        
        loss = (outputs1.loss + outputs2.loss) / 2
        return (loss, outputs) if return_outputs else loss

方案 3:DeepSpeed / FSDP 的自动微批处理

如果你用 DeepSpeed ZeRO-3FSDP ,它们支持 train_micro_batch_size_per_gpu 配合 gradient_accumulation_steps

json 复制代码
{
    "train_batch_size": 4,
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 4
}

常见误区 ❌

python 复制代码
# 错误!会报错或自动转为1
per_device_train_batch_size=0.5  # 不允许!

总结

需求 正确做法
想要更小的"等效batch size" gradient_accumulation_steps 调大
显存不够但想 batch_size=1 gradient_checkpointing + 8-bit 优化器
物理上必须拆分单个样本 自定义 Trainer.compute_loss

LLaMA Factory 中,通过增大梯度累积步数和减小批量大小,可以有效控制显存使用

。这个组合策略的目的是在保持总批量大小(per_device_train_batch_size * gradient_accumulation_steps)不变或近似的前提下,降低每次前向-反向传播所需的峰值显存。

策略解析:

• per_device_train_batch_size 降到 1/2:这意味着每次前向传播处理的样本数减半,能直接降低存储激活值(activation)所需显存,因为激活值与批次大小成近似正比。

• gradient_accumulation_steps 设为 4/8:这表示累积 4 次或 8 次小批量计算(每次使用减半的批量)的梯度后,再执行一次参数更新。这相当于模拟了一个更大的"有效批量大小"。

复制代码
◦   有效批量大小 = per_device_train_batch_size * gradient_accumulation_steps

◦   (调整前)有效批量大小 = 原大小 * 1

◦   (调整后)有效批量大小 = (原大小/2)  4 = 原大小  2

◦   (调整后)有效批量大小 = (原大小/2)  8 = 原大小  4

• 最终影响:

复制代码
◦   总显存消耗降低:通过减小单次处理的批量,显著降低了存储激活和中间变量的峰值显存。

◦   收敛行为可能变化:有效批量大小会影响训练的稳定性和收敛性。批量越大,通常噪声越小,但可能收敛到更尖锐的极小点,泛化性有时会稍差。将 gradient_accumulation_steps 设为 4 或 8 会增加有效批量大小,可能加快收敛速度,但需要留意学习率与批量大小的关系(如线性缩放规则)。

◦   训练时间略有增加:梯度累积需要更多次的前向/后向计算才更新一次权重,可能会轻微增加每个 epoch 的训练时间,但通常对最终训练时间影响不大。

操作建议:

  1. 调整原则:保持 per_device_train_batch_size * gradient_accumulation_steps 的乘积(即有效批量大小)在合理范围内。通常,你可以先确定一个可接受的有效批量大小(例如 16、32、64),再根据显存限制调整两者比例。
  2. 学习率调整:当你增加有效批量大小(例如从 8 增大到 16 或 32)时,通常需要相应增大学习率(例如按有效批量大小的平方根或线性关系缩放,常见做法是按线性比例放大),以保持相似的收敛特性。在 LLaMA Factory 的配置中,注意调整 learning_rate 参数。
  3. 显存监控:在调整后,务必监控显存使用情况,确保不超出显卡容量,并保留一定余量。
  4. 性能平衡:这是一个显存与训练速度/稳定性之间的权衡。如果显存非常紧张,可以进一步降低 per_device_train_batch_size 并增加 gradient_accumulation_steps,反之亦然。

示例:

若原配置为 per_device_train_batch_size=8, gradient_accumulation_steps=1(有效批量=8),显存不足。可调整为:

• per_device_train_batch_size=4, gradient_accumulation_steps=2(有效批量=8),显存减半,更新频率不变。

• per_device_train_batch_size=2, gradient_accumulation_steps=4(有效批量=8),显存降至约1/4,更新频率不变。

• 如果想尝试更大批量以稳定训练,也可设为 per_device_train_batch_size=4, gradient_accumulation_steps=4(有效批量=16),显存减半,但有效批量加倍,此时可能需要适当提高学习率。

在 LLaMA Factory 的配置文件中,你可以在 ds_train.json 或 train_args.json 等配置文件中找到并修改这些参数。

对于使用极小batch来避免梯度爆炸的情况,你可以考虑以下几种方法:

  1. 减小学习率:由于batch size减小,每次更新的梯度方差会增大,适当降低学习率可以让训练更稳定。

  2. 使用梯度裁剪:对梯度进行裁剪,例如限制梯度的范数(如使用 torch.nn.utils.clip_grad_norm_),防止更新步长过大。

  3. 调整优化器参数:使用带有自适应学习率的优化器(如 Adam、RMSprop),它们对不同参数有适应性,可以在小batch下更稳定。

  4. 归一化技巧:确保数据经过适当的归一化或标准化,有助于稳定梯度。

  5. 增加梯度累积:如果因为内存限制不得不使用小batch,可以通过梯度累积(gradient accumulation)来模拟更大batch的效果。比如每多个step更新一次,将多个小batch的梯度累积后再应用更新。

  6. 监控与调试:在训练过程中监控梯度范数(gradient norm),若发现异常增大,及时调整策略。

如果你有具体场景或框架,我可以给出更针对性的建议。

相关推荐
zl_vslam2 小时前
SLAM中的非线性优-3D图优化之绝对位姿SE3约束四元数形式(十九)
人工智能·算法·计算机视觉·3d
Predestination王瀞潞2 小时前
1.3.1 AI->Tesseract OCR Engine标准(HP、Google):Tesseract OCR Engine
人工智能·ocr
Fleshy数模2 小时前
基于PyTorch的食品图像分类:数据增强与调优实战
人工智能·pytorch·分类
岁岁种桃花儿2 小时前
AI超级智能开发系列从入门到上天第十篇:SpringAI+云知识库服务
linux·运维·数据库·人工智能·oracle·llm
小马_xiaoen2 小时前
2026 AI 开发新风向:Skills 安装量 Top 10 深度解析
人工智能·skill
Surmon2 小时前
AI 代替不了这样的你
人工智能·ai编程
j_xxx404_2 小时前
蓝桥杯基础--时间复杂度
数据结构·c++·算法·蓝桥杯·排序算法
ZGi.ai2 小时前
一个 LLM 网关需要做哪些事? 多模型统一接入的工程设计
人工智能
FL16238631292 小时前
C#版winform实现FaceFusion人脸替换
人工智能