
10.5.3 模型训练
文件train/train.py是基于本实例的多模态(图文/视频)大语言模型训练脚本,核心功能包括定义模型、数据、训练三大类参数配置类,适配多模态模型不同组件(视觉塔、MLP适配器、语言模型等)的调优参数;提供了零冗余优化(Zero)兼容的参数处理、PEFT/LoRA权重提取、模型安全保存等工具函数,以及分词器与嵌入层自适应调整的功能;针对Llama2、Gemma、Qwen、Llama3、MPT等不同架构模型实现了对应的对话数据预处理逻辑,涵盖了多模态token(图片/视频)处理、对话prompt模板适配、训练标签掩码(仅保留模型回复部分为训练目标,忽略人类输入)等关键步骤,最终支撑多模态语言模型的训练全流程。
文件train/train.py的主要代码如下所示。
(1)下面代码的功能是定义多模态模型训练的核心参数配置类,统一管理模型基础信息、多模态组件(视觉塔、MLP 适配器、视觉重采样器等)的微调 / 冻结策略、模型版本及特殊格式配置(如图片token样式)等参数,为多模态模型的初始化和训练策略控制提供标准化参数支撑。
python
@dataclass
class ModelArguments:
# 模型名称或预训练权重路径(示例值为facebook/opt-125m)
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
# 模型类名称,用于初始化模型类,格式为XXXXForCausalLM。例如目前XXXX可选LlavaLlama、LlavaMixtral、LlavaMistral、Llama
model_class_name: Optional[str] = field(default=None, metadata={"help": "用于初始化模型类,格式为XXXXForCausalLM。例如目前XXXX可选LlavaLlama、LlavaMixtral、LlavaMistral、Llama"})
# 多模态模型可微调的部分,会覆盖之前的微调设置,可选值如"mm_mlp_adapter"、"mm_vision_tower,mm_mlp_adapter"等
mm_tunable_parts: Optional[str] = field(
default=None, metadata={"help": '可选值为"mm_mlp_adapter"、"mm_vision_resampler"、"mm_vision_tower,mm_mlp_adapter,mm_language_model"等,用于指定多模态模型要微调的部分'}
)
# 模型版本
version: Optional[str] = field(default="v0")
# 是否冻结模型主干网络
freeze_backbone: bool = field(default=False)
# 是否微调多模态MLP适配器
tune_mm_mlp_adapter: bool = field(default=False)
# 是否微调多模态视觉重采样器
tune_mm_vision_resampler: bool = field(default=False)
# 视觉塔模型的路径
vision_tower: Optional[str] = field(default=None)
# 视觉塔预训练权重路径(默认使用最后一层的权重)
vision_tower_pretrained: Optional[str] = field(default=None) # 默认使用最后一层
# 是否解冻多模态视觉塔
unfreeze_mm_vision_tower: bool = field(default=False)
# 是否解冻语言模型
unfreeze_language_model: bool = field(default=False)
# 选择视觉塔的特征输出层(默认选最后一层)
mm_vision_select_layer: Optional[int] = field(default=-1) # 默认使用最后一层
# 多模态MLP适配器的预训练权重路径
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
# 多模态投影器的类型(默认线性投影)
mm_projector_type: Optional[str] = field(default="linear")
# 是否使用图片开始/结束特殊token
mm_use_im_start_end: bool = field(default=False)
(2)下面代码的功能是为HuggingFace Trainer提供安全的模型保存逻辑,适配多模态模型的差异化保存需求(如仅保存适配器权重),兼容DeepSpeed Zero优化策略避免权重保存异常,同时按训练场景分类保存完整模型或仅多模态适配器权重,降低存储开销。
python
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""收集模型状态字典并保存到磁盘,适配多模态模型的差异化保存需求"""
# 判断是否仅需要保存多模态适配器相关权重
if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter:
check_only_save_mm_adapter_tunnable = True
# 仅当可微调部分为mm_mlp_adapter或mm_vision_resampler时,仅保存适配器
elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)):
check_only_save_mm_adapter_tunnable = True
else:
check_only_save_mm_adapter_tunnable = False
# 等待所有进程同步,保证多卡训练时保存一致
trainer.accelerator.wait_for_everyone()
torch.cuda.synchronize()
rank0_print(f"仅保存投影器权重: {check_only_save_mm_adapter_tunnable}")
if check_only_save_mm_adapter_tunnable:
# 仅保存适配器权重,匹配的关键词包括mm_projector和vision_resampler
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
# 获取Zero优化兼容的适配器权重(处理分布式权重分片)
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
# 保存模型配置文件
trainer.model.config.save_pretrained(output_dir)
# 拆分保存路径,按checkpoint分类保存适配器权重
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"))
else:
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
return
# 如果使用DeepSpeed训练,直接调用内置保存方法
if trainer.deepspeed:
trainer.save_model(output_dir)
return
# 常规保存:收集模型状态字典并转CPU保存(避免GPU显存占用)
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
(3)下面代码的功能是针对Llama2架构的对话数据进行标准化预处理,适配多模态(含图片token)场景,通过Llama2专属prompt模板生成训练用对话文本,对文本token化后掩码掉人类输入部分的标签(仅保留模型回复作为训练目标),确保训练聚焦于模型生成逻辑。
python
def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
# 复制默认对话模板,适配Llama2的prompt格式规范
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# 应用prompt模板,标准化多轮对话的格式
conversations = []
for i, source in enumerate(sources):
# 过滤第一条非人类发送的消息(保证对话以人类提问开头)
if roles[source[0]["from"]] != conv.roles[0]:
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
# 校验对话角色是否符合交替规则(人类-模型-人类...)
assert role == conv.roles[j % 2], f"第{i}条数据的角色匹配异常"
conv.append_message(role, sentence["value"])
# 生成Llama2格式的标准化prompt文本
conversations.append(conv.get_prompt())
# 对对话文本进行Token化,适配多模态(图片token)场景
if has_image:
# 处理含图片token的prompt,生成input_ids(适配多模态token解析)
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
# 纯文本prompt的token化(按最长序列填充、超长截断)
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
# 复制input_ids作为标签,后续掩码非训练目标部分
targets = input_ids.clone()
# 校验对话分隔符样式为Llama2专属格式
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
# 掩码标签:仅保留模型回复部分,人类输入设为IGNORE_INDEX(不参与损失计算)
sep = "[/INST] "
for conversation, target in zip(conversations, targets):
# 计算有效序列长度(排除padding token)
total_len = int(target.ne(tokenizer.pad_token_id).sum())
# 按Llama2分隔符拆分多轮对话
rounds = conversation.split(conv.sep2)
cur_len = 1
# 掩码bos_token所在位置
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
# 计算单轮对话长度和人类输入部分长度(适配多模态token)
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
# 掩码人类输入部分的标签(不计算损失)
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
# 掩码超出有效长度的部分
target[cur_len:] = IGNORE_INDEX
# 校验token化长度,不匹配则全掩码(避免训练异常)
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"警告:token化长度不匹配:{cur_len} vs. {total_len},已忽略该条数据")
return dict(
input_ids=input_ids,
labels=targets,
)
(4)下面代码的功能是动态调整分词器和模型嵌入层的大小,适配新增的特殊token(如图片/视频 token);对新增token的嵌入权重采用原有嵌入权重的均值初始化,避免新增token导致的训练不稳定,保证嵌入层参数的合理性。
python
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""调整分词器和嵌入层大小(非优化版本,可能导致嵌入层大小无法被64整除)"""
# 向分词器添加新的特殊token,返回新增token的数量
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
# 调整模型嵌入层的大小,适配新增后的分词器词汇表
model.resize_token_embeddings(len(tokenizer))
# 若有新增token,对其嵌入权重进行均值初始化(基于原有权重)
if num_new_tokens > 0:
# 获取模型输入嵌入层和输出嵌入层的权重数据
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
# 计算原有嵌入权重的均值(排除新增token对应的权重位置)
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
# 用原有权重的均值初始化新增token的嵌入权重
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
(5)下面代码的功能是对多模态训练数据进行格式标准化预处理,修复图片token的非规范位置(保证token在句首),适配im_start/im_end特殊token的使用场景,同时清理数据中的冗余噪声文本,确保多模态输入(图片/视频)的token格式统一,避免训练时格式异常。
python
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
# 判断当前训练是否为多模态场景
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
# 遍历每条对话数据,标准化多模态token格式
for source in sources:
for sentence in source:
# 统计单条语句中的图片token数量(仅处理单图片场景)
num_im = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
# 修复图片token不在句首的问题(规范格式:图片token在前,文本在后)
if num_im == 1 and DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
# 移除原有图片token并清理首尾空格
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
# 将图片token置于句首,保证多模态输入格式统一
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
sentence["value"] = sentence["value"].strip()
# 适配带mmtag的对话版本,用<Image>标签包裹图片token
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>")
# 替换图片token为带im_start/im_end的格式(若开启该配置)
replace_token = DEFAULT_IMAGE_TOKEN
if data_args.mm_use_im_start_end:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
# 清理VideoInstruct-100k数据集的冗余噪声文本(临时数据清洗方案)
sentence["value"] = sentence["value"].replace("QA_GT_caption_based_noisy", "")
return sources
到此为止,本实例的核心功能介绍完毕,机器人可以调用本实例大模型进行推理测试,例如图10-8展示了机器人操作任务轨迹的示意图,分别展示了"拿起刀(pick up the knife)"与"够到香蕉(reach for the banana)"两个任务场景下的动作规划:图中蓝色路径(标注"reason")是机器人的动作轨迹,星星标记("s")代表机器人操作的起点,粉色方块("e")对应任务目标的终点,直观体现了机器人接收语言任务指令后,从初始位置到目标物体的动作路径规划过程。

图10-8 机器人操作任务轨迹的示意图