As is (not so) widely known, MGM is derived from LLaVA, so what exactly differs in their code?
Their projection layers are identical.
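Since the projector is the shared piece, a minimal sketch may help. This assumes the standard LLaVA-style `mlp2x_gelu` design (two linear layers with a GELU in between); the layer sizes below are illustrative, not the actual config values:

```python
import torch
import torch.nn as nn

# Minimal sketch of the LLaVA-style "mlp2x_gelu" projector shared by both
# codebases: two linear layers with a GELU in between, mapping vision features
# (mm_hidden_size) into the LLM embedding space (hidden_size). Sizes here are
# illustrative, not the actual config values.
mm_hidden_size, hidden_size = 1024, 4096
projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)

vision_features = torch.randn(1, 576, mm_hidden_size)  # e.g. 24x24 patch tokens
visual_tokens = projector(vision_features)             # (1, 576, hidden_size)
```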
llava_llama vs. mgm_llama
The main difference lies in the forward function.
llava
```python
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[List[List[int]]] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

    if inputs_embeds is None:
        # Encode the images and splice the visual tokens into the text
        # embeddings; masks, positions, and labels are expanded to match.
        (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            labels,
            images,
            image_sizes
        )

    # Logits, loss, and return shape are all delegated to the parent
    # LlamaForCausalLM.forward.
    return super().forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )
```
mgm_llama
```python
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    images_aux: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        # Note the last argument: images_aux (high-resolution auxiliary
        # images) replaces LLaVA's image_sizes.
        (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            labels,
            images,
            images_aux
        )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = outputs[0]

    if self.config.pretraining_tp > 1:
        # Tensor-parallel LM head: split the weight, project each slice,
        # then concatenate the partial logits.
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
```
On the input side, that is the only change: LLaVA's `image_sizes` argument is replaced by `images_aux`. The output handling also changes: instead of delegating to `super().forward`, MGM calls the forward pass of `self.model` directly and does more work on the return path, projecting the hidden states into logits, computing the cross-entropy loss, and returning different outputs for training versus inference.
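To make the shift-by-one loss concrete, here is a tiny self-contained version of the same next-token cross-entropy computation, with made-up shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy version of the shift-by-one loss in mgm_llama's forward: position i's
# logits are scored against the label at position i + 1.
batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

shift_logits = logits[..., :-1, :].contiguous()  # drop the last position
shift_labels = labels[..., 1:].contiguous()      # drop the first label

loss_fct = CrossEntropyLoss()  # ignore_index=-100 by default, matching HF labels
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)  # scalar tensor
```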
Regarding the return value:
`return_dict` defaults to `self.config.use_return_dict`.
If a dict is returned, the behavior matches LLaVA: a `CausalLMOutputWithPast` (the base class for causal LM outputs) carrying the token logits and several optional fields:
- `loss`: the language-modeling loss (for next-token prediction);
- `past_key_values`: precomputed hidden states (the keys and values in the self-attention blocks) that can be fed back in (see the `past_key_values` input) to speed up sequential decoding;
- `hidden_states`: the hidden states at the output of each layer, plus the optional initial embedding output;
- `attentions`: the post-softmax attention weights, used to compute the weighted average in the self-attention heads.
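As a usage sketch of the dict-style return: any Hugging Face causal LM exposes the same fields, so a tiny public checkpoint (`sshleifer/tiny-gpt2`, unrelated to MGM) serves for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any HF causal LM returns these fields; "sshleifer/tiny-gpt2" is a tiny
# public checkpoint used purely for illustration, unrelated to MGM.
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

enc = tok("a tiny example", return_tensors="pt")
out = model(**enc, labels=enc["input_ids"], output_hidden_states=True, return_dict=True)

print(out.loss)                # LM loss (labels were provided)
print(out.logits.shape)        # (batch, seq_len, vocab_size)
print(len(out.hidden_states))  # num_layers + 1, including the embedding output
cache = out.past_key_values    # feed back in to speed up sequential decoding
```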
If a dict is not returned:
```python
# `outputs` is the tuple (or other sequence) returned by the model; depending
# on the configuration it may also contain hidden states, attention maps, or
# past_key_values.
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
```
then a tuple of `loss` and the other outputs is returned:
- If `loss` is not `None`, it is prepended to the `output` tuple, so the first element of the returned tuple is the loss value. This is the usual case in training, where the loss is tracked for optimization.
- If `loss` is `None`, which can happen at inference time when no target labels are provided, the function simply returns the `output` tuple.
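And the tuple-style return under the same toy setup as above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Same tiny public checkpoint as above, now with return_dict=False.
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
enc = tok("a tiny example", return_tensors="pt")

out = model(**enc, labels=enc["input_ids"], return_dict=False)
loss, logits = out[0], out[1]   # loss is prepended when labels are given

out = model(**enc, return_dict=False)  # no labels: inference-style call
logits = out[0]                        # the tuple now starts at the logits
```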
processor
In the vision tower, MGM's EVA encoder reuses the processor from clip-p14.
It also defines a new processor, a video processor, which adapts video input by treating it as a sequence of images; a sketch of the idea follows.
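A minimal sketch of that idea, treating a clip as a list of frames and pushing each one through a stock CLIP image processor; this illustrates the concept only and is not MGM's actual implementation:

```python
from PIL import Image
from transformers import CLIPImageProcessor

# Sketch: a "video processor" that treats a clip as a sequence of images and
# runs every frame through the same CLIP image processor. Illustrative only;
# this is not MGM's actual video processor.
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

def process_video(frames):
    """frames: list of PIL.Image, one per sampled frame."""
    batch = image_processor(images=frames, return_tensors="pt")
    return batch["pixel_values"]  # (num_frames, 3, 224, 224), an image sequence

frames = [Image.new("RGB", (336, 336)) for _ in range(8)]
video_tensor = process_video(frames)  # torch.Size([8, 3, 224, 224])
```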
eva encoder
EVA is in fact EVA-CLIP, an encoder-only model trained with contrastive learning; in MGM the network is defined directly in code rather than loaded from a config.
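For orientation, EVA-CLIP checkpoints can be loaded through OpenCLIP; the model tag and pretrained weights below are assumptions chosen for illustration, not necessarily what MGM hard-codes:

```python
import torch
import open_clip

# Sketch: loading an EVA-CLIP vision encoder via OpenCLIP. The model tag and
# pretrained weights are assumptions for illustration; MGM itself defines the
# EVA network directly rather than loading it this way.
model, _, preprocess = open_clip.create_model_and_transforms(
    "EVA02-L-14", pretrained="merged2b_s4b_b131k"
)
vision_encoder = model.visual  # encoder-only: keep just the image tower

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)
    features = vision_encoder(dummy)  # pooled, projected image embedding
print(features.shape)
```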