【多模态】Flamingo模型技术学习

Flamingo模型技术学习

前言

最近多模态模型特别火,从头开始学习!在前面写的几篇里面学习了MiniCPM-V、ViT、CLIP和BLIP/BLIP-2之后,今天学习一下Flamingo模型,记录学习过程,欢迎批评指正,一起学习~~

Flamingo------支持上下文学习的多模态模型

  • Flamingo出自Deepmind在NeurIPS 2022的论文Flamingo: a Visual Language Model for Few-Shot Learning
  • 相比于之前的方案,之前的方案通常只支持VQA,Flamingo支持图片和文本内容混合输入,支持提供图-文示例进行In-context learning,支持提供多图多轮对话

模型架构

  • Flamingo模型和大部分VLM模型类似,由视觉编码器+图文对齐模块+LLM构成,Flamingo并不更新视觉编码器和LLM的参数
  • 预训练过程中,参数会更新的部分是resampler模块和插入到LLM的cross-attention层

模型架构------Resampler

  • 这个Resampler是一个Q-Former模块,训练过程中会更新query和transformer层,前面的BLIP-2的核心就是这个模块
  • 如果输入很长,例如是一个视频,Q-Former可以实现对信息的压缩【输出大小只和query大小有关】

模型架构------插入到LLM的cross-attention层

  • 预训练过程中,参数会更新的部分是resampler模块和插入到LLM的cross-attention层,这些层的权重参数初始化为0,在Flamingo的训练过程中更新

    *值得注意的是,这些交叉注意力层前面的权重系数,层数越深,权重系数的绝对值越大

代码查看------masked cross-attention

  • 为了让模型能够除了对图片进行描述之外,还能处理图片和文本交替混合输入,Flamingo的做法是在cross-attention层使用掩码
  • 在当前文本token下,只看这个文本token前一个图片的视觉token
  • media_locations = input_ids == self.media_token_id这里面media_token_id是<image>
  • 推理时有图片的时候,use_cache这个为false
python 复制代码
media_locations = input_ids == self.media_token_id

# if there are media already cached and we're generating and there are no media tokens in the input,
# we'll assume that ALL input tokens should attend to the last previous media that is cached.
# this is especially important for HF generate() compatibility, since generate() calls forward()
# repeatedly one token at a time (with no media tokens).
# without this check, the model would not attend to any images when generating (after the first token)
use_cached_media_locations = (
    self._use_cached_vision_x
    and self.is_conditioned()
    and not media_locations.any()
)

for layer in self._get_decoder_layers():
    if not use_cached_media_locations:
        layer.condition_media_locations(media_locations)
    layer.condition_use_cached_media(use_cached_media_locations)

这个掩码机制实现如下,注意输入了图片的时候use_cached=False

python 复制代码
media = rearrange(media, "b t n d -> b (t n) d")
sim = einsum("... i d, ... j d -> ... i j", q, k)

if exists(media_locations):
    media_time = torch.arange(T_img, device=x.device) + 1  # T_img是存图片数量的维度

    if not use_cached_media:
        # at each boolean of True, increment the time counter (relative to media time)
        text_time = media_locations.cumsum(dim=-1)  # 前缀和,如果第1张是图,得到[0,1,1,1]

    # text time must equal media time if only attending to most immediate image
    # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
    mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

    text_to_media_mask = mask_op(
        rearrange(text_time, "b i -> b 1 i 1"),
        repeat(media_time, "j -> 1 1 1 (j n)", n=n),
    )
    sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)

直接看代码比较抽象,看一个例子会比较清晰:假设B = 2,文本长度T_txt = 4,图片数T_img = 3,Resampler的query向量长度n = 2,那么

python 复制代码
media_locations = torch.tensor([
[False, True, False, False],
[False, False, True, False]])
text_time = torch.tensor([
    [0, 1, 1, 1],
    [0, 0, 1, 1]])
media_time = torch.tensor([1, 2, 3])

广播机制实现的结果如下:

最后实现的效果是:得到的mask,每一行对应一个文本token,每n列代表一张图像

note

到这里,MiniCPM-V的前面需要了解的一些内容应该都看完了,过段时间再仔细看一下MiniCPM-V的源代码,偏好对齐那部分也很强,有效降低了幻觉,学习一下DPO是如何实现的

相关推荐
woshihonghonga1 天前
Deepseek在它擅长的AI数据处理领域还有是有低级错误【k折交叉验证中每折样本数计算】
人工智能·python·深度学习·机器学习
乌恩大侠1 天前
以 NVIDIA Sionna Research Kit 赋能 AI 原生 6G 科研
人工智能·usrp
三掌柜6661 天前
借助 Kiro:实现《晚间手机免打扰》应用,破解深夜刷屏困境
人工智能·aws
飞雁科技1 天前
CRM客户管理系统定制开发:如何精准满足企业需求并提升效率?
大数据·运维·人工智能·devops·驻场开发
飞雁科技1 天前
上位机软件定制开发技巧:如何打造专属工业解决方案?
大数据·人工智能·软件开发·devops·驻场开发
这张生成的图像能检测吗1 天前
SAMWISE:为文本驱动的视频分割注入SAM2的智慧
人工智能·图像分割·视频·时序
哥布林学者1 天前
吴恩达深度学习课程二: 改善深层神经网络 第一周:深度学习的实践 课后作业和代码实践
深度学习·ai
antonytyler1 天前
机器学习实践项目(二)- 房价预测增强篇 - 特征工程一
人工智能·机器学习
N 年 后1 天前
cursor和传统idea的区别是什么?
java·人工智能·intellij-idea
AI Echoes1 天前
LangChain 使用语义路由选择不同的Prompt模板
人工智能·python·langchain·prompt·agent