【扩散模型(七)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

系列文章目录


文章目录

  • 系列文章目录
  • [一、Image Encoder 的使用区别?](#一、Image Encoder 的使用区别?)
    • [1.1 Image Encoder 组成](#1.1 Image Encoder 组成)
    • [1.2 .hidden_states[-2] 表示什么?](#1.2 .hidden_states[-2] 表示什么?)
  • [二、ImageProjModel 和 Resampler 的区别?](#二、ImageProjModel 和 Resampler 的区别?)
    • [2.1 ImageProjModel 代码](#2.1 ImageProjModel 代码)
    • [2.2 Resampler 代码](#2.2 Resampler 代码)

从下图中可以很直观地看出有两处不同,第一是使用 image encoder 的方式不同、得到了不同的图像特征,第二是将原有的简单 ImageProjModel 替换成了更加复杂的 Resampler 以提取更多的图像信息。

一、Image Encoder 的使用区别?

1.1 Image Encoder 组成

Image Encoder 是 CLIPVisionModelWithProjection 类(位于 /path/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py)

根据其构造函数,可见分为两块(1)vision_model【CLIPVisionTransformer】 和 (2)visual_projection【Linear】

python 复制代码
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

1.视觉模型(vision_model

  • 这通常是一个处理输入图像的视觉转换器(ViT)。
  • 它从图像中提取特征并输出表示,通常包括总结整个图像的"合并"输出。
  • 视觉模型处理了理解图像内容的繁重任务。

2.视觉投影(visual_projection

  • 这是一个线性层,将视觉模型的高维输出映射到低维空间。
  • 在 CLIP 这样的多模态模型中,投影会将图像表示与文本表示对齐
  • 它确保图像嵌入与文本嵌入位于同一空间,便于比较或组合。

1.2 .hidden_states[-2] 表示什么?

我们仔细对比 IP-Adapter 和 IP-Adapter Plus 的细节,会发现采用 Image Encoder 的方式不一样

python 复制代码
# IP-Adapter
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds

# IP-Adapter Plus
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]

关键区别在于:

(1) .image_embeds属性来获取图像的嵌入表示,是 经过编码器最后一层(visual_projection)处理后的结果。

(2) .hidden_states[-2]:这行代码调用self.image_encoder时,通过 output_hidden_states=True 参数指示编码器输出除了最终输出之外的所有隐藏状态

  • self.image_encoder返回一个对象,其中hidden_states是一个包含所有隐藏层输出的列表。
  • 然后,通过索引 [-2] 访问这个列表的倒数第二个元素,即倒数第二个隐藏层的输出。
  • 未经过编码器最后一层(visual_projection)处理后的结果。

二、ImageProjModel 和 Resampler 的区别?

  • ImageProjModelResampler 都是用于将图像嵌入(image_embeds)投影到一个更高维度的空间,以便作为后续的生成引导。通过对 2.1 和 2.2 的两段代码,可以总结出差异:
  1. 网络结构

    • ImageProjModel:包含一个线性层self.proj用于投影,以及一个层归一化self.norm
    • Resampler:包含位置嵌入(如果apply_pos_embTrue)、输入投影self.proj_in、输出投影self.proj_out和层归一化self.norm_out。此外,它还包含一个由多个注意力和前馈网络层组成的模块列表self.layers,这些层用于处理输入数据。
  2. 注意力机制

    • ImageProjModel:没有使用注意力机制。
    • Resampler:使用自定义的PerceiverAttention模块进行注意力计算。
  3. 前馈网络

    • ImageProjModel:没有前馈网络。
    • Resampler:使用FeedForward模块,这是一个标准的前馈网络,通常用于Transformer架构中。
  4. 序列处理

    • ImageProjModel:没有特别处理序列数据。
    • Resampler:设计用于序列数据,包括可选的通过self.to_latents_from_mean_pooled_seq从平均池化序列生成额外的潜在表示。
  5. 可学习的参数

    • ImageProjModel:主要参数是线性层的权重。
    • Resampler:除了线性层的权重外,还包括可学习的潜在表示self.latents
  6. 输出

    • ImageProjModel:输出经过投影和归一化的图像嵌入。
    • Resampler:输出经过多层处理和归一化的序列特征。
  7. 特殊函数

    • Resampler中使用了masked_mean函数,这表明它可能用于处理带有掩码的序列数据,例如在处理变长序列时。

总结来说,ImageProjModel是一个简单的投影模型,可能用于将图像特征投影到一个多维空间以便于与其他类型的数据结合。而Resampler是一个更复杂的模型 (主要来源于论文^1^),设计用于处理序列数据,并通过注意力和前馈网络层进行特征转换。

2.1 ImageProjModel 代码

python 复制代码
class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

2.2 Resampler 代码

Flamingo 论文中的图像,可与代码对照理解。

python 复制代码
class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        
        print(embedding_dim, dim)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

  1. Flamingo: a Visual Language Model for Few-Shot Learning ↩︎
相关推荐
shansjqun3 分钟前
教学内容全覆盖:航拍杂草检测与分类
人工智能·分类·数据挖掘
狸克先生5 分钟前
如何用AI写小说(二):Gradio 超简单的网页前端交互
前端·人工智能·chatgpt·交互
baiduopenmap20 分钟前
百度世界2024精选公开课:基于地图智能体的导航出行AI应用创新实践
前端·人工智能·百度地图
小任同学Alex23 分钟前
浦语提示词工程实践(LangGPT版,服务器上部署internlm2-chat-1_8b,踩坑很多才完成的详细教程,)
人工智能·自然语言处理·大模型
新加坡内哥谈技术29 分钟前
微软 Ignite 2024 大会
人工智能
nuclear201142 分钟前
使用Python 在Excel中创建和取消数据分组 - 详解
python·excel数据分组·创建excel分组·excel分类汇总·excel嵌套分组·excel大纲级别·取消excel分组
江瀚视野1 小时前
Q3净利增长超预期,文心大模型调用量大增,百度未来如何分析?
人工智能
Lucky小小吴1 小时前
有关django、python版本、sqlite3版本冲突问题
python·django·sqlite
陪学1 小时前
百度遭初创企业指控抄袭,维权还是碰瓷?
人工智能·百度·面试·职场和发展·产品运营
QCN_1 小时前
湘潭大学人工智能考试复习1(软件工程)
人工智能