【扩散模型(七)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

系列文章目录


文章目录

  • 系列文章目录
  • [一、Image Encoder 的使用区别?](#一、Image Encoder 的使用区别?)
    • [1.1 Image Encoder 组成](#1.1 Image Encoder 组成)
    • [1.2 .hidden_states[-2] 表示什么?](#1.2 .hidden_states[-2] 表示什么?)
  • [二、ImageProjModel 和 Resampler 的区别?](#二、ImageProjModel 和 Resampler 的区别?)
    • [2.1 ImageProjModel 代码](#2.1 ImageProjModel 代码)
    • [2.2 Resampler 代码](#2.2 Resampler 代码)

从下图中可以很直观地看出有两处不同,第一是使用 image encoder 的方式不同、得到了不同的图像特征,第二是将原有的简单 ImageProjModel 替换成了更加复杂的 Resampler 以提取更多的图像信息。

一、Image Encoder 的使用区别?

1.1 Image Encoder 组成

Image Encoder 是 CLIPVisionModelWithProjection 类(位于 /path/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py)

根据其构造函数,可见分为两块(1)vision_model【CLIPVisionTransformer】 和 (2)visual_projection【Linear】

python 复制代码
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

1.视觉模型(vision_model

  • 这通常是一个处理输入图像的视觉转换器(ViT)。
  • 它从图像中提取特征并输出表示,通常包括总结整个图像的"合并"输出。
  • 视觉模型处理了理解图像内容的繁重任务。

2.视觉投影(visual_projection

  • 这是一个线性层,将视觉模型的高维输出映射到低维空间。
  • 在 CLIP 这样的多模态模型中,投影会将图像表示与文本表示对齐
  • 它确保图像嵌入与文本嵌入位于同一空间,便于比较或组合。

1.2 .hidden_states[-2] 表示什么?

我们仔细对比 IP-Adapter 和 IP-Adapter Plus 的细节,会发现采用 Image Encoder 的方式不一样

python 复制代码
# IP-Adapter
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds

# IP-Adapter Plus
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]

关键区别在于:

(1) .image_embeds属性来获取图像的嵌入表示,是 经过编码器最后一层(visual_projection)处理后的结果。

(2) .hidden_states[-2]:这行代码调用self.image_encoder时,通过 output_hidden_states=True 参数指示编码器输出除了最终输出之外的所有隐藏状态

  • self.image_encoder返回一个对象,其中hidden_states是一个包含所有隐藏层输出的列表。
  • 然后,通过索引 [-2] 访问这个列表的倒数第二个元素,即倒数第二个隐藏层的输出。
  • 未经过编码器最后一层(visual_projection)处理后的结果。

二、ImageProjModel 和 Resampler 的区别?

  • ImageProjModelResampler 都是用于将图像嵌入(image_embeds)投影到一个更高维度的空间,以便作为后续的生成引导。通过对 2.1 和 2.2 的两段代码,可以总结出差异:
  1. 网络结构

    • ImageProjModel:包含一个线性层self.proj用于投影,以及一个层归一化self.norm
    • Resampler:包含位置嵌入(如果apply_pos_embTrue)、输入投影self.proj_in、输出投影self.proj_out和层归一化self.norm_out。此外,它还包含一个由多个注意力和前馈网络层组成的模块列表self.layers,这些层用于处理输入数据。
  2. 注意力机制

    • ImageProjModel:没有使用注意力机制。
    • Resampler:使用自定义的PerceiverAttention模块进行注意力计算。
  3. 前馈网络

    • ImageProjModel:没有前馈网络。
    • Resampler:使用FeedForward模块,这是一个标准的前馈网络,通常用于Transformer架构中。
  4. 序列处理

    • ImageProjModel:没有特别处理序列数据。
    • Resampler:设计用于序列数据,包括可选的通过self.to_latents_from_mean_pooled_seq从平均池化序列生成额外的潜在表示。
  5. 可学习的参数

    • ImageProjModel:主要参数是线性层的权重。
    • Resampler:除了线性层的权重外,还包括可学习的潜在表示self.latents
  6. 输出

    • ImageProjModel:输出经过投影和归一化的图像嵌入。
    • Resampler:输出经过多层处理和归一化的序列特征。
  7. 特殊函数

    • Resampler中使用了masked_mean函数,这表明它可能用于处理带有掩码的序列数据,例如在处理变长序列时。

总结来说,ImageProjModel是一个简单的投影模型,可能用于将图像特征投影到一个多维空间以便于与其他类型的数据结合。而Resampler是一个更复杂的模型 (主要来源于论文^1^),设计用于处理序列数据,并通过注意力和前馈网络层进行特征转换。

2.1 ImageProjModel 代码

python 复制代码
class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

2.2 Resampler 代码

Flamingo 论文中的图像,可与代码对照理解。

python 复制代码
class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        
        print(embedding_dim, dim)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

  1. Flamingo: a Visual Language Model for Few-Shot Learning ↩︎
相关推荐
DisonTangor13 分钟前
阿里通义千问开源Qwen2.5系列模型:Qwen2-VL-72B媲美GPT-4
人工智能·计算机视觉
豆浩宇13 分钟前
Halcon OCR检测 免训练版
c++·人工智能·opencv·算法·计算机视觉·ocr
Narutolxy15 分钟前
Python 单元测试:深入理解与实战应用20240919
python·单元测试·log4j
LLSU1318 分钟前
聚星文社AI软件小说推文软件
人工智能
JackieZhengChina20 分钟前
吴泳铭:AI最大的想象力不在手机屏幕,而是改变物理世界
人工智能·智能手机
ShuQiHere21 分钟前
【ShuQiHere】 探索数据挖掘的世界:从概念到应用
人工智能·数据挖掘
嵌入式杂谈22 分钟前
OpenCV计算机视觉:探索图片处理的多种操作
人工智能·opencv·计算机视觉
时光追逐者23 分钟前
分享6个.NET开源的AI和LLM相关项目框架
人工智能·microsoft·ai·c#·.net·.netcore
东隆科技23 分钟前
PicoQuant公司:探索铜铟镓硒(CIGS)太阳能电池技术,引领绿色能源革新
人工智能·能源
红米煮粥24 分钟前
图像处理-掩码
图像处理·opencv·计算机视觉