(10-5-03)大模型时代的人形机器人感知:基于RoboBrain大模型的人形机器人通用智能感知系统(3)模型训练

10.5.3 模型训练

文件train/train.py是基于本实例的多模态(图文/视频)大语言模型训练脚本,核心功能包括定义模型、数据、训练三大类参数配置类,适配多模态模型不同组件(视觉塔、MLP适配器、语言模型等)的调优参数;提供了零冗余优化(Zero)兼容的参数处理、PEFT/LoRA权重提取、模型安全保存等工具函数,以及分词器与嵌入层自适应调整的功能;针对Llama2、Gemma、Qwen、Llama3、MPT等不同架构模型实现了对应的对话数据预处理逻辑,涵盖了多模态token(图片/视频)处理、对话prompt模板适配、训练标签掩码(仅保留模型回复部分为训练目标,忽略人类输入)等关键步骤,最终支撑多模态语言模型的训练全流程。

文件train/train.py的主要代码如下所示。

(1)下面代码的功能是定义多模态模型训练的核心参数配置类,统一管理模型基础信息、多模态组件(视觉塔、MLP 适配器、视觉重采样器等)的微调 / 冻结策略、模型版本及特殊格式配置(如图片token样式)等参数,为多模态模型的初始化和训练策略控制提供标准化参数支撑。

python 复制代码
@dataclass
class ModelArguments:
    # 模型名称或预训练权重路径(示例值为facebook/opt-125m)
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    # 模型类名称,用于初始化模型类,格式为XXXXForCausalLM。例如目前XXXX可选LlavaLlama、LlavaMixtral、LlavaMistral、Llama
    model_class_name: Optional[str] = field(default=None, metadata={"help": "用于初始化模型类,格式为XXXXForCausalLM。例如目前XXXX可选LlavaLlama、LlavaMixtral、LlavaMistral、Llama"})

    # 多模态模型可微调的部分,会覆盖之前的微调设置,可选值如"mm_mlp_adapter"、"mm_vision_tower,mm_mlp_adapter"等
    mm_tunable_parts: Optional[str] = field(
        default=None, metadata={"help": '可选值为"mm_mlp_adapter"、"mm_vision_resampler"、"mm_vision_tower,mm_mlp_adapter,mm_language_model"等,用于指定多模态模型要微调的部分'}
    )
    # 模型版本
    version: Optional[str] = field(default="v0")
    # 是否冻结模型主干网络
    freeze_backbone: bool = field(default=False)
    # 是否微调多模态MLP适配器
    tune_mm_mlp_adapter: bool = field(default=False)
    # 是否微调多模态视觉重采样器
    tune_mm_vision_resampler: bool = field(default=False)
    # 视觉塔模型的路径
    vision_tower: Optional[str] = field(default=None)
    # 视觉塔预训练权重路径(默认使用最后一层的权重)
    vision_tower_pretrained: Optional[str] = field(default=None)  # 默认使用最后一层

    # 是否解冻多模态视觉塔
    unfreeze_mm_vision_tower: bool = field(default=False)
    # 是否解冻语言模型
    unfreeze_language_model: bool = field(default=False)
    # 选择视觉塔的特征输出层(默认选最后一层)
    mm_vision_select_layer: Optional[int] = field(default=-1)  # 默认使用最后一层
    # 多模态MLP适配器的预训练权重路径
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    # 多模态投影器的类型(默认线性投影)
    mm_projector_type: Optional[str] = field(default="linear")
    # 是否使用图片开始/结束特殊token
    mm_use_im_start_end: bool = field(default=False)

(2)下面代码的功能是为HuggingFace Trainer提供安全的模型保存逻辑,适配多模态模型的差异化保存需求(如仅保存适配器权重),兼容DeepSpeed Zero优化策略避免权重保存异常,同时按训练场景分类保存完整模型或仅多模态适配器权重,降低存储开销。

python 复制代码
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """收集模型状态字典并保存到磁盘,适配多模态模型的差异化保存需求"""
    # 判断是否仅需要保存多模态适配器相关权重
    if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter:
        check_only_save_mm_adapter_tunnable = True
    # 仅当可微调部分为mm_mlp_adapter或mm_vision_resampler时,仅保存适配器
    elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)):
        check_only_save_mm_adapter_tunnable = True
    else:
        check_only_save_mm_adapter_tunnable = False

    # 等待所有进程同步,保证多卡训练时保存一致
    trainer.accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    rank0_print(f"仅保存投影器权重: {check_only_save_mm_adapter_tunnable}")
    if check_only_save_mm_adapter_tunnable:
        # 仅保存适配器权重,匹配的关键词包括mm_projector和vision_resampler
        keys_to_match = ["mm_projector", "vision_resampler"]
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(["embed_tokens", "embed_in"])

        # 获取Zero优化兼容的适配器权重(处理分布式权重分片)
        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        # 保存模型配置文件
        trainer.model.config.save_pretrained(output_dir)

        # 拆分保存路径,按checkpoint分类保存适配器权重
        current_folder = output_dir.split("/")[-1]
        parent_folder = os.path.dirname(output_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith("checkpoint-"):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
        return

    # 如果使用DeepSpeed训练,直接调用内置保存方法
    if trainer.deepspeed:
        trainer.save_model(output_dir)
        return

    # 常规保存:收集模型状态字典并转CPU保存(避免GPU显存占用)
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

(3)下面代码的功能是针对Llama2架构的对话数据进行标准化预处理,适配多模态(含图片token)场景,通过Llama2专属prompt模板生成训练用对话文本,对文本token化后掩码掉人类输入部分的标签(仅保留模型回复作为训练目标),确保训练聚焦于模型生成逻辑。

python 复制代码
def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
    # 复制默认对话模板,适配Llama2的prompt格式规范
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # 应用prompt模板,标准化多轮对话的格式
    conversations = []
    for i, source in enumerate(sources):
        # 过滤第一条非人类发送的消息(保证对话以人类提问开头)
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # 校验对话角色是否符合交替规则(人类-模型-人类...)
            assert role == conv.roles[j % 2], f"第{i}条数据的角色匹配异常"
            conv.append_message(role, sentence["value"])
        # 生成Llama2格式的标准化prompt文本
        conversations.append(conv.get_prompt())

    # 对对话文本进行Token化,适配多模态(图片token)场景
    if has_image:
        # 处理含图片token的prompt,生成input_ids(适配多模态token解析)
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
    else:
        # 纯文本prompt的token化(按最长序列填充、超长截断)
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    # 复制input_ids作为标签,后续掩码非训练目标部分
    targets = input_ids.clone()

    # 校验对话分隔符样式为Llama2专属格式
    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # 掩码标签:仅保留模型回复部分,人类输入设为IGNORE_INDEX(不参与损失计算)
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        # 计算有效序列长度(排除padding token)
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # 按Llama2分隔符拆分多轮对话
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        # 掩码bos_token所在位置
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # 计算单轮对话长度和人类输入部分长度(适配多模态token)
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # 掩码人类输入部分的标签(不计算损失)
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # 掩码超出有效长度的部分
        target[cur_len:] = IGNORE_INDEX

        # 校验token化长度,不匹配则全掩码(避免训练异常)
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(f"警告:token化长度不匹配:{cur_len} vs. {total_len},已忽略该条数据")

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

(4)下面代码的功能是动态调整分词器和模型嵌入层的大小,适配新增的特殊token(如图片/视频 token);对新增token的嵌入权重采用原有嵌入权重的均值初始化,避免新增token导致的训练不稳定,保证嵌入层参数的合理性。

python 复制代码
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """调整分词器和嵌入层大小(非优化版本,可能导致嵌入层大小无法被64整除)"""
    # 向分词器添加新的特殊token,返回新增token的数量
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    # 调整模型嵌入层的大小,适配新增后的分词器词汇表
    model.resize_token_embeddings(len(tokenizer))

    # 若有新增token,对其嵌入权重进行均值初始化(基于原有权重)
    if num_new_tokens > 0:
        # 获取模型输入嵌入层和输出嵌入层的权重数据
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # 计算原有嵌入权重的均值(排除新增token对应的权重位置)
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        # 用原有权重的均值初始化新增token的嵌入权重
        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

(5)下面代码的功能是对多模态训练数据进行格式标准化预处理,修复图片token的非规范位置(保证token在句首),适配im_start/im_end特殊token的使用场景,同时清理数据中的冗余噪声文本,确保多模态输入(图片/视频)的token格式统一,避免训练时格式异常。

python 复制代码
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
    # 判断当前训练是否为多模态场景
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    # 遍历每条对话数据,标准化多模态token格式
    for source in sources:
        for sentence in source:
            # 统计单条语句中的图片token数量(仅处理单图片场景)
            num_im = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
            # 修复图片token不在句首的问题(规范格式:图片token在前,文本在后)
            if num_im == 1 and DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
                # 移除原有图片token并清理首尾空格
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
                # 将图片token置于句首,保证多模态输入格式统一
                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
                sentence["value"] = sentence["value"].strip()
                # 适配带mmtag的对话版本,用<Image>标签包裹图片token
                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>")
            
            # 替换图片token为带im_start/im_end的格式(若开启该配置)
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

            # 清理VideoInstruct-100k数据集的冗余噪声文本(临时数据清洗方案)
            sentence["value"] = sentence["value"].replace("QA_GT_caption_based_noisy", "")

    return sources

到此为止,本实例的核心功能介绍完毕,机器人可以调用本实例大模型进行推理测试,例如图10-8展示了机器人操作任务轨迹的示意图,分别展示了"拿起刀(pick up the knife)"与"够到香蕉(reach for the banana)"两个任务场景下的动作规划:图中蓝色路径(标注"reason")是机器人的动作轨迹,星星标记("s")代表机器人操作的起点,粉色方块("e")对应任务目标的终点,直观体现了机器人接收语言任务指令后,从初始位置到目标物体的动作路径规划过程。

图10-8 机器人操作任务轨迹的示意图

相关推荐
陈天伟教授1 小时前
人工智能应用- 预测新冠病毒传染性:04. 中国:强力措施遏制疫情
前端·人工智能·安全·xss·csrf
双星系统2 小时前
OpenClaw本地部署完全指南:2026年让AI真正“动手干活”
人工智能
火山引擎开发者社区2 小时前
ArkClaw“虾塘”再进化,开启“无忧养虾”
人工智能
火山引擎开发者社区2 小时前
从 Vibe Coding 到 Agentic Engineering:ArkClaw + Supabase,打造你的私有化 Agent 工厂
人工智能
七牛云行业应用2 小时前
GPT-5.4 mini 与 nano 深度评测:核心差异、API 成本实测与选型指南
人工智能·openai·api调用·gpt-5.4·大模型降本
cxr8282 小时前
PaperclipAI 组织关系与智能体协作指南
数据库·人工智能·架构·ai智能体·openclaw
大傻^3 小时前
Spring AI Alibaba RAG实战:基于向量存储的检索增强生成
java·人工智能·spring
Physicist in Geophy.3 小时前
claude code workflow
人工智能