LLaVATrainer
python
class llava.train.llava_trainer.LLaVATrainer(Trainer)
用于训练 LLaVA (Large Language and Vision Assistant) 多模态模型的训练器类,继承自 transformers.Trainer。
该类在标准 Transformer Trainer 基础上扩展了以下功能:
- 支持 MeZO (Memory-efficient Zeroth-Order Optimization) 零阶优化训练模式
- 提供多种基于长度和模态的数据采样策略
- 支持 DeepSpeed 和 FSDP 分布式训练
- 提供针对多模态适配器 (MM Adapter) 的特定检查点保存功能
参数
该类接受所有 transformers.Trainer 支持的关键字参数,同时支持以下额外参数(通过 args 传入):
| 参数 | 类型 | 默认值 | 描述 |
|---|---|---|---|
trainer_mode |
str |
"regular" |
训练模式。可选 "regular"(常规反向传播训练)或 "zo"(MeZO 零阶优化训练)。 |
zo_eps |
float |
1e-3 |
MeZO 超参数 epsilon,控制参数扰动的幅度。 |
zo_num_directions |
int |
1 |
MeZO 优化中使用的随机方向数量。 |
group_by_length |
bool |
False |
是否按序列长度分组采样。 |
group_by_modality_length |
bool |
False |
是否按模态长度分组采样。 |
group_by_modality_length_auto |
bool |
False |
是否使用自动模态长度分组采样。 |
group_by_varlen |
bool |
False |
是否使用可变长度分组采样。 |
mm_projector_lr |
float, optional |
None |
多模态投影层的独立学习率。 |
mm_vision_tower_lr |
float, optional |
None |
视觉编码器的独立学习率。 |
属性
| 属性 | 类型 | 描述 |
|---|---|---|
trainer_mode |
str |
当前训练模式("regular" 或 "zo")。 |
zo_eps |
float |
MeZO epsilon 超参数。 |
zo_num_directions |
int |
MeZO 随机方向数量。 |
trainable_params |
List[Tuple[str, Parameter]] |
可训练参数列表,包含参数名称和参数本身。 |
mezo_update_history |
List[Dict] |
MeZO 更新历史记录,用于检查点恢复。 |
方法
zo_perturb_parameters
python
zo_perturb_parameters(scaling_factor: float = 1.0) -> None
使用随机向量 z z z 扰动模型参数。
参数:
- scaling_factor (
float) -- 扰动的缩放因子。正值表示正向扰动,负值表示反向扰动。
示例:
python
# 正向扰动
trainer.zo_perturb_parameters(scaling_factor=1.0)
# 反向扰动(恢复原始参数后再扰动)
trainer.zo_perturb_parameters(scaling_factor=-2.0)
zo_forward
python
zo_forward(model: nn.Module, inputs: Dict) -> torch.Tensor
在推理模式下计算前向传播损失。
参数:
- model (
nn.Module) -- 需要计算损失的模型。 - inputs (
Dict) -- 输入批次数据。
返回:
torch.Tensor-- 计算得到的损失值(已 detach)。
zo_step
python
zo_step(model: nn.Module, inputs: Dict) -> torch.Tensor
使用 MeZO 算法执行单步梯度估计。通过正向和反向扰动的损失差来近似梯度。
参数:
- model (
nn.Module) -- 模型实例。 - inputs (
Dict) -- 输入批次数据。
返回:
torch.Tensor-- 归一化后的损失值。
注意事项:
该方法在
gradient_accumulation_steps期间累积多个方向的梯度估计,在zo_update中统一应用。
zo_update
python
zo_update(learning_rate: float) -> None
根据累积的梯度估计更新模型参数。
参数:
- learning_rate (
float) -- 当前学习率。
注意事项:
- 该方法自动处理 weight decay
bias、layer_norm和layernorm参数不会应用 weight decay- 调用后会清空累积的梯度估计
save_model
python
save_model(output_dir: Optional[str] = None, _internal_call: bool = False)
保存模型检查点。当使用 MeZO 模式时,会额外保存轻量级的 MeZO 状态检查点。
参数:
- output_dir (
str, optional ) -- 保存路径。默认使用args.output_dir。 - _internal_call (
bool) -- 是否为内部调用。
_save_checkpoint
python
_save_checkpoint(model, trial, metrics=None) -> None
保存训练检查点。该方法重写了父类的检查点保存逻辑,以支持仅保存多模态适配器 (MM Adapter) 权重的场景。
参数:
- model -- 需要保存的模型实例。
- trial -- 超参数搜索试验对象(用于确定输出目录)。
- metrics (
Dict, optional) -- 评估指标字典。
行为说明:
当满足以下任一条件时,仅保存适配器权重:
args.tune_mm_mlp_adapter=Trueargs.mm_tunable_parts仅包含"mm_mlp_adapter"或"mm_vision_resampler"
在这种情况下,会保存:
- 模型配置文件 (
config.json) - 适配器权重文件 (
mm_projector.bin)
保存的权重包括:
mm_projector相关参数vision_resampler相关参数- 如果
use_im_start_end=True,还包括embed_tokens和embed_in
其他情况下,调用父类 Trainer._save_checkpoint() 进行完整模型保存。
注意事项:
- 该方法支持 DeepSpeed ZeRO-3 模式,会正确收集分布在多个 GPU 上的参数
- 仅在主进程(
local_rank == 0或local_rank == -1)上执行实际的保存操作
示例:
python
# 仅微调 MM Adapter 时的配置
training_args.tune_mm_mlp_adapter = True
# 或者通过 mm_tunable_parts 指定
training_args.mm_tunable_parts = "mm_mlp_adapter"
# 训练过程中的检查点将只包含适配器权重
# 保存路径示例: output_dir/checkpoint-1000/mm_projector.bin
create_optimizer
python
create_optimizer() -> torch.optim.Optimizer
创建优化器。支持为不同模块设置独立学习率(如 mm_projector 和 vision_tower)。
返回:
torch.optim.Optimizer-- 配置好的优化器实例。
注意事项:
在 MeZO 模式下,会创建一个虚拟优化器(dummy optimizer),实际参数更新由
zo_update方法执行。
get_train_dataloader
python
get_train_dataloader() -> DataLoader
创建并返回训练数据加载器。
返回:
torch.utils.data.DataLoader-- 训练数据加载器。
示例
基本使用
python
from llava.train.llava_trainer import LLaVATrainer
from transformers import TrainingArguments
# 配置训练参数
training_args = TrainingArguments(
output_dir="./output",
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
learning_rate=2e-5,
num_train_epochs=1,
group_by_modality_length=True, # 启用模态长度分组
)
# 创建训练器
trainer = LLaVATrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
)
# 开始训练
trainer.train()
# 保存模型
trainer.save_model("./final_model")
使用 MeZO 训练模式
python
from llava.train.llava_trainer import LLaVATrainer
# 配置 MeZO 相关参数
training_args.trainer_mode = "zo"
training_args.zo_eps = 1e-3
training_args.zo_num_directions = 1
# 创建训练器
trainer = LLaVATrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
)
# MeZO 模式训练
trainer.train()
设置模块独立学习率
python
# 为多模态投影层和视觉编码器设置独立学习率
training_args.mm_projector_lr = 1e-4
training_args.mm_vision_tower_lr = 2e-6
trainer = LLaVATrainer(
model=model,
args=training_args,
...
)
参见
transformers.Trainer-- 基类文档LengthGroupedSampler-- 长度分组采样器LLaVADPOTrainer-- 用于 DPO (Direct Preference Optimization) 训练的变体