源码看MGM

众所周不知,mgm 源自于llava,那么它们在代码上有什么区别呢?

它们的投影层一致

llava_llama与mgm_llama

主要是forward函数的区别

llava

python 复制代码
def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
)

mgm_llama

python 复制代码
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                images_aux
            )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

关于输入,只有这点不同,将image_size改为了image_aux 改变了输出

mgm调用了self.model的前向传播方法,并且在返回时进行了更复杂的处理,包括对输出进行处理生成logits,并计算交叉熵损失函数,区分了训练和推理的模式返回不同输出。

对于返回值:

return_dict 传递自self.config.use_return_dict

如果返回字典,则与llava相同,返回自回归模型输出基类,分类概率logits,以及其他可选的输出,包括

  • loss语言建模损失(用于下一个标记预测);
  • past_key_values 包含预先计算的隐藏状态(self-attention块中的k和v),可用于(参见 past_key_values 输入)加速顺序解码;
  • hidden_states模型每层输出的隐藏状态,加上可选的初始嵌入输出;
  • attentions 在softmax后的注意力权重,用于计算自我注意力头部的加权平均值。

如果返回的并不是字典:

python 复制代码
output = (logits,) + outputs[1:] # `outputs`变量预计是模型返回的元组或另一个序列,其中可能包含各种元素,例如隐藏状态、注意力图或past_key_values,具体取决于模型的配置
return (loss,) + output if loss is not None else output 

则返回由loss和其他输出组成的元组

  • 如果loss不是None,则将其添加到output元组前面,使返回元组的第一个元素成为损失值。这在您想要跟踪损失以进行优化的训练场景中很常见。
  • 如果lossNone,则在推理过程中可能会出现这种情况,即没有提供实际目标(标签)来计算损失,则该函数仅返回output元组。

processor

mgm在vision tower中eva encoder中引入clip-p14的processor

同时,还定义了一个新的processor:video processor适配视频输入(作为图像序列)

eva encoder

eva 其实就是eva clip,使用了对比学习的encoder-only模型,在mgm中它直接设定了网络。

相关推荐
ModestCoder_14 分钟前
强化学习 Policy 的 Tracking 能力全解析,以Legged_gym为例解说Policy的训练流程
人工智能·算法·自然语言处理·机器人·具身智能
小白程序员成长日记35 分钟前
2025.12.02 力扣每日一题
数据结构·算法·leetcode
永远都不秃头的程序员(互关)39 分钟前
在vscodeC语言多文件编译实战指南
c语言·数据结构·算法
立志成为大牛的小牛1 小时前
数据结构——五十三、处理冲突的方法——拉链法(王道408)
数据结构·学习·考研·算法
吃着火锅x唱着歌1 小时前
LeetCode 3583.统计特殊三元组
算法·leetcode·职场和发展
FPGA_无线通信1 小时前
OFDM 频偏补偿和相位跟踪(2)
算法·fpga开发
SHOJYS1 小时前
思维难度较大 贪心优化背包 [USACO22DEC] Bribing Friends G
数据结构·算法·深度优先
啊董dong1 小时前
课后作业-2025年12月07号作业
数据结构·c++·算法·深度优先·noi
无限进步_2 小时前
C语言宏的魔法:探索offsetof与位交换的奇妙世界
c语言·开发语言·windows·后端·算法·visual studio
Lucky“经营分析”2 小时前
经营分析师-《经营分析能力》
算法