[CLIP-VIT-L + Qwen] 多模态大模型源码阅读 - 语言模型篇(4)

[CLIP-VIT-L + Qwen] 多模态大模型学习笔记 - 语言模型篇(4)

参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE

前情提要

有关MQwenModel的代码请看(多模态大模型源码阅读 - 1多模态大模型源码阅读 - 2多模态大模型源码阅读 - 3

本节中将接着看MQwen.py中的剩余源码,即MQwenLMHeadModel和main函数源码,MQwen.py重构了Qwen大模型中QwenModel的前向传播代码和QwenLMHeadModel的部分代码,以适配视觉编码器CLIP-VIT-L和语言模型Qwen的多模态架构,QwenModel类作为基座模型,QwenLMHeadModel 是基于 QwenModel 的一个扩展,加入了针对特定下游任务的头,在后续我们主要使用重写后的MQwenLMHeadModel作为多模态架构中的语言模型。

源码解读(MQwenLMHeadModel类)

init函数

python 复制代码
  class MQWenLMHeadModel(QWenLMHeadModel):  
    def __init__(self, config, otherConfig):
        super().__init__(config)

        self.transformer = MQWenModel(config, otherConfig)

        if config.bf16:
            self.transformer.bfloat16()
    
        if config.fp16:
            self.transformer.half()

总体含义

初始化一个使用MQwenModel的类的实例,以便后续使用MQwenModel进行前向传播。

逐行解读

config 和 otherconfig一个作为初始化模型的通用配置参数,一个是用户自定义的额外参数传入。

python 复制代码
    def __init__(self, config, otherConfig):

使用通用配置参数初始化父类。

python 复制代码
 super().__init__(config)

初始化基座模型MQwenModel用于前向传播,传递入通用配置参数和自定义参数,赋值给成员变量self.transformer

python 复制代码
self.transformer = MQWenModel(config, otherConfig)

确定将模型的权重转换为bf16(brain float 16)还是单精度浮点数(float16)数据格式,bf16和双精度浮点数(float32)有相同的动态范围,但是只需要16位的存储空间,可以看做单精度浮点数(float16)的变体。这两种精度都支持自动混合精度训练,可以减少内存占用,提高性能。

python 复制代码
        if config.bf16:
            self.transformer.bfloat16()
    
        if config.fp16:
            self.transformer.half()

prepare_inputs_for_generation函数

python 复制代码
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        if input_ids.size(0) == 1:
            attention_mask = None
        else:
            attention_mask = kwargs.get("attention_mask", None)

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images")
            }
        )
        return model_inputs

整体含义

这段代码主要用于准备模型输入部分,对Input_ids,past_key_values等参数进行预处理,并返回一个包含这些值预处理结果的字典

逐行解读

input_ids: 由分词后的token被映射为词汇表中的唯一数字索引。例如'hello, world'分词后被映射为{111,222}(这里仅举例,不代表真实索引)

past_key_values:过去时间步中处理的序列输入数据的键值对缓存结果,通常为一个元组。

inputs_embeds:input_ids经过word_embedding层处理为的词嵌入向量。

一般来说,只需要传入input_ids和input_embeds中的其中一个即可、

python 复制代码
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):

判断当前是否为推理、训练的第一步(如果启用use_cache的话)。input_ids的size通常为(batch_size,seq_len),这里我们只取每一批次input_ids的最后一个索引,这是因为如果我们有缓存的键值对,那么就无需重复计算先前缓存的键值对。unqueeze(-1)是因为当我们只取一个索引的时候,Input_ids会降维成(batch_size,),因此我们需要重新将其扩充为二维张量,size为(batch_size,1)。

python 复制代码
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)

如果当前批次为1,说明输入只有单个样本,因此不需要使用注意力掩码。否则我们就从关键词参数中获取,如果提供的话,否则置为None。

python 复制代码
        if input_ids.size(0) == 1:
            attention_mask = None
        else:
            attention_mask = kwargs.get("attention_mask", None)

如果我们传入了inputs_embeds并且有缓存的键值对,这代表我们处于推理或训练的中间步骤,我们初始化一个包含input_embeds的字典,否则初始化一个包含input_ids的字典。

python 复制代码
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

我们用update函数更新字典,update函数接受一个字典参数,并且覆盖原字典中键相同元素的值。最后将字典作为返回值返回。

python 复制代码
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images")
            }
        )
        return model_inputs

forward函数

python 复制代码
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        images: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        transformer_outputs = self.transformer(
            input_ids,
            images=images,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

整体含义

这段代码是MQwenLMHeadModel的前向传播函数,包括输入处理、损失计算和输出格式化。

逐行解读

对于传递的参数不在赘述,在前几期的笔记中都有详细记载,除了labels参数,这一参数用于有监督训练,作为标签计算损失。

python 复制代码
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        images: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

确定返回值的输出格式,如果return_dict不为None,则以字典形式返回,否则获取通用配置参数中的默认值。

python 复制代码
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

使用MQwenModel的前向传播函数获取输出,输出通常包括最后一层的隐藏状态,缓存的键值对,注意力分数,每一层的隐藏状态等等,具体可以参考repo中的MQwenModel的前向传播函数,或者前几期的学习笔记。

python 复制代码
        transformer_outputs = self.transformer(
            input_ids,
            images=images,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

获取last_hidden_state,使用语言头来处理并生成预测用的对数概率。这里的self.lm_head继承自QwenLMHeadModel,具体可以参考Qwen模型的源码。

初始化loss为None,以便后续更新loss值。

python 复制代码
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None

如果提供了训练用的标签,首先将它转移到对数概率挂载的设备(gpu或者cpu)上,取对数概率除了最后一个时间步的所有元素,labels取除了第一个标签外的所有标签,这样做是为了让预测结果和实际标签错开,让每个时间步的输出都有对应的下一个词的标签。例如我们的输出'ABCD',标签也为'ABCD',经过处理后输出为'ABC',标签为'BCD',这样A对应B,C对应D,每个时间步的输出结果都有对应的下一个词作为标签。contiguous()函数目的是让变量在内存中连续。

损失函数设定为交叉熵损失函数,shift_logits原本的size为(batch_size, seq_len - 1,vocab_size),shift_labels的size为(batc_size,seq_len - 1),将这两个变量除了最后一个维度,其余维度展平,以便进行损失计算。

python 复制代码
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

如果不设定返回值类型为字典。则初始化output为对数概率加上前向传播输出除了last_hidden_state外的所有输出。

如果损失值不为空,将其与output一起返回,否则只返回output

python 复制代码
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

反之,我们返回一个CausalLMOutputWithPast类型的输出结果,这个类用于封装因果模型的前向传播输出结果

python 复制代码
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

main函数

python 复制代码
def main():
    MQ = MQWenLMHeadModel.from_pretrained("huggingface_model/qwen/Qwen-1_8B/", torch_dtype = torch.bfloat16, trust_remote_code = True)

if __name__ == "__main__":
    main()

逐行解读

使用MQWenLMHeadModel从huggingface加载预训练模型的权重配置,在保留原有模型的基础功能和预训练权重的同时,添加新的功能或改进现有功能。frompretrained方法通常继承自huggingface的pretrainedmodel类

python 复制代码
MQ = MQWenLMHeadModel.from_pretrained("huggingface_model/qwen/Qwen-1_8B/", torch_dtype = torch.bfloat16, trust_remote_code = True)

至此,MQwen.py讲解完毕,后续会讲解repo中的其余部分,并动手训练实现一个多模态大模型

相关推荐
m0_7431064642 分钟前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_7431064644 分钟前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
Coovally AI模型快速验证4 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩4 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控4 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
佛州小李哥6 小时前
Agent群舞,在亚马逊云科技搭建数字营销多代理(Multi-Agent)(下篇)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
bohu837 小时前
OpenCV笔记3-图像修复
笔记·opencv·图像修复·亮度增强·图片磨皮
大丈夫立于天地间8 小时前
ISIS基础知识
网络·网络协议·学习·智能路由器·信息与通信
old_power8 小时前
【PCL】Segmentation 模块—— 基于图割算法的点云分割(Min-Cut Based Segmentation)
c++·算法·计算机视觉·3d
doubt。8 小时前
【BUUCTF】[RCTF2015]EasySQL1
网络·数据库·笔记·mysql·安全·web安全