【LLaVA-NeXT】LLaVATrainer说明

LLaVATrainer

python 复制代码
class llava.train.llava_trainer.LLaVATrainer(Trainer)

用于训练 LLaVA (Large Language and Vision Assistant) 多模态模型的训练器类,继承自 transformers.Trainer

该类在标准 Transformer Trainer 基础上扩展了以下功能:

  • 支持 MeZO (Memory-efficient Zeroth-Order Optimization) 零阶优化训练模式
  • 提供多种基于长度和模态的数据采样策略
  • 支持 DeepSpeedFSDP 分布式训练
  • 提供针对多模态适配器 (MM Adapter) 的特定检查点保存功能

参数

该类接受所有 transformers.Trainer 支持的关键字参数,同时支持以下额外参数(通过 args 传入):

参数 类型 默认值 描述
trainer_mode str "regular" 训练模式。可选 "regular"(常规反向传播训练)或 "zo"(MeZO 零阶优化训练)。
zo_eps float 1e-3 MeZO 超参数 epsilon,控制参数扰动的幅度。
zo_num_directions int 1 MeZO 优化中使用的随机方向数量。
group_by_length bool False 是否按序列长度分组采样。
group_by_modality_length bool False 是否按模态长度分组采样。
group_by_modality_length_auto bool False 是否使用自动模态长度分组采样。
group_by_varlen bool False 是否使用可变长度分组采样。
mm_projector_lr float, optional None 多模态投影层的独立学习率。
mm_vision_tower_lr float, optional None 视觉编码器的独立学习率。

属性

属性 类型 描述
trainer_mode str 当前训练模式("regular""zo")。
zo_eps float MeZO epsilon 超参数。
zo_num_directions int MeZO 随机方向数量。
trainable_params List[Tuple[str, Parameter]] 可训练参数列表,包含参数名称和参数本身。
mezo_update_history List[Dict] MeZO 更新历史记录,用于检查点恢复。

方法

zo_perturb_parameters

python 复制代码
zo_perturb_parameters(scaling_factor: float = 1.0) -> None

使用随机向量 z z z 扰动模型参数。

参数:

  • scaling_factor (float) -- 扰动的缩放因子。正值表示正向扰动,负值表示反向扰动。

示例:

python 复制代码
# 正向扰动
trainer.zo_perturb_parameters(scaling_factor=1.0)

# 反向扰动(恢复原始参数后再扰动)
trainer.zo_perturb_parameters(scaling_factor=-2.0)

zo_forward

python 复制代码
zo_forward(model: nn.Module, inputs: Dict) -> torch.Tensor

在推理模式下计算前向传播损失。

参数:

  • model (nn.Module) -- 需要计算损失的模型。
  • inputs (Dict) -- 输入批次数据。

返回:

  • torch.Tensor -- 计算得到的损失值(已 detach)。

zo_step

python 复制代码
zo_step(model: nn.Module, inputs: Dict) -> torch.Tensor

使用 MeZO 算法执行单步梯度估计。通过正向和反向扰动的损失差来近似梯度。

参数:

  • model (nn.Module) -- 模型实例。
  • inputs (Dict) -- 输入批次数据。

返回:

  • torch.Tensor -- 归一化后的损失值。

注意事项:

该方法在 gradient_accumulation_steps 期间累积多个方向的梯度估计,在 zo_update 中统一应用。


zo_update

python 复制代码
zo_update(learning_rate: float) -> None

根据累积的梯度估计更新模型参数。

参数:

  • learning_rate (float) -- 当前学习率。

注意事项:

  • 该方法自动处理 weight decay
  • biaslayer_normlayernorm 参数不会应用 weight decay
  • 调用后会清空累积的梯度估计

save_model

python 复制代码
save_model(output_dir: Optional[str] = None, _internal_call: bool = False)

保存模型检查点。当使用 MeZO 模式时,会额外保存轻量级的 MeZO 状态检查点。

参数:

  • output_dir (str, optional ) -- 保存路径。默认使用 args.output_dir
  • _internal_call (bool) -- 是否为内部调用。

_save_checkpoint

python 复制代码
_save_checkpoint(model, trial, metrics=None) -> None

保存训练检查点。该方法重写了父类的检查点保存逻辑,以支持仅保存多模态适配器 (MM Adapter) 权重的场景。

参数:

  • model -- 需要保存的模型实例。
  • trial -- 超参数搜索试验对象(用于确定输出目录)。
  • metrics (Dict, optional) -- 评估指标字典。

行为说明:

当满足以下任一条件时,仅保存适配器权重:

  • args.tune_mm_mlp_adapter=True
  • args.mm_tunable_parts 仅包含 "mm_mlp_adapter""mm_vision_resampler"

在这种情况下,会保存:

  • 模型配置文件 (config.json)
  • 适配器权重文件 (mm_projector.bin)

保存的权重包括:

  • mm_projector 相关参数
  • vision_resampler 相关参数
  • 如果 use_im_start_end=True,还包括 embed_tokensembed_in

其他情况下,调用父类 Trainer._save_checkpoint() 进行完整模型保存。

注意事项:

  • 该方法支持 DeepSpeed ZeRO-3 模式,会正确收集分布在多个 GPU 上的参数
  • 仅在主进程(local_rank == 0local_rank == -1)上执行实际的保存操作

示例:

python 复制代码
# 仅微调 MM Adapter 时的配置
training_args.tune_mm_mlp_adapter = True

# 或者通过 mm_tunable_parts 指定
training_args.mm_tunable_parts = "mm_mlp_adapter"

# 训练过程中的检查点将只包含适配器权重
# 保存路径示例: output_dir/checkpoint-1000/mm_projector.bin

create_optimizer

python 复制代码
create_optimizer() -> torch.optim.Optimizer

创建优化器。支持为不同模块设置独立学习率(如 mm_projectorvision_tower)。

返回:

  • torch.optim.Optimizer -- 配置好的优化器实例。

注意事项:

在 MeZO 模式下,会创建一个虚拟优化器(dummy optimizer),实际参数更新由 zo_update 方法执行。


get_train_dataloader

python 复制代码
get_train_dataloader() -> DataLoader

创建并返回训练数据加载器。

返回:

  • torch.utils.data.DataLoader -- 训练数据加载器。

示例

基本使用

python 复制代码
from llava.train.llava_trainer import LLaVATrainer
from transformers import TrainingArguments

# 配置训练参数
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    num_train_epochs=1,
    group_by_modality_length=True,  # 启用模态长度分组
)

# 创建训练器
trainer = LLaVATrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

# 开始训练
trainer.train()

# 保存模型
trainer.save_model("./final_model")

使用 MeZO 训练模式

python 复制代码
from llava.train.llava_trainer import LLaVATrainer

# 配置 MeZO 相关参数
training_args.trainer_mode = "zo"
training_args.zo_eps = 1e-3
training_args.zo_num_directions = 1

# 创建训练器
trainer = LLaVATrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
)

# MeZO 模式训练
trainer.train()

设置模块独立学习率

python 复制代码
# 为多模态投影层和视觉编码器设置独立学习率
training_args.mm_projector_lr = 1e-4
training_args.mm_vision_tower_lr = 2e-6

trainer = LLaVATrainer(
    model=model,
    args=training_args,
    ...
)

参见

相关推荐
郑小路1 年前
LLaVA模型讲解与总结
llava
万里鹏程转瞬至1 年前
论文阅读:LLaVA-OneVision: Easy Visual Task Transfer
论文阅读·多模态·llava
William.csj1 年前
大模型——LLaVA和LLaMA的介绍和区别
llama·llava
alxe_made2 年前
VLM系列文章1-LLaVA
llm·vllm·llava
自律版光追2 年前
【书生·浦语大模型实战营第二期】XTuner微调LLM:1.8B、多模态、Agent——学习笔记4
笔记·学习·微调·internlm·llava·书生·浦语·xtuner
代码讲故事2 年前
LLaVA:GPT-4V(ision) 的新开源替代品
chatgpt·aigc·gpt4·llama·模型·gpt-4v·llava