【扩散模型(七)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

系列文章目录


文章目录

  • 系列文章目录
  • [一、Image Encoder 的使用区别?](#一、Image Encoder 的使用区别?)
    • [1.1 Image Encoder 组成](#1.1 Image Encoder 组成)
    • [1.2 .hidden_states[-2] 表示什么?](#1.2 .hidden_states[-2] 表示什么?)
  • [二、ImageProjModel 和 Resampler 的区别?](#二、ImageProjModel 和 Resampler 的区别?)
    • [2.1 ImageProjModel 代码](#2.1 ImageProjModel 代码)
    • [2.2 Resampler 代码](#2.2 Resampler 代码)

从下图中可以很直观地看出有两处不同,第一是使用 image encoder 的方式不同、得到了不同的图像特征,第二是将原有的简单 ImageProjModel 替换成了更加复杂的 Resampler 以提取更多的图像信息。

一、Image Encoder 的使用区别?

1.1 Image Encoder 组成

Image Encoder 是 CLIPVisionModelWithProjection 类(位于 /path/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py)

根据其构造函数,可见分为两块(1)vision_model【CLIPVisionTransformer】 和 (2)visual_projection【Linear】

python 复制代码
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

1.视觉模型(vision_model

  • 这通常是一个处理输入图像的视觉转换器(ViT)。
  • 它从图像中提取特征并输出表示,通常包括总结整个图像的"合并"输出。
  • 视觉模型处理了理解图像内容的繁重任务。

2.视觉投影(visual_projection

  • 这是一个线性层,将视觉模型的高维输出映射到低维空间。
  • 在 CLIP 这样的多模态模型中,投影会将图像表示与文本表示对齐
  • 它确保图像嵌入与文本嵌入位于同一空间,便于比较或组合。

1.2 .hidden_states[-2] 表示什么?

我们仔细对比 IP-Adapter 和 IP-Adapter Plus 的细节,会发现采用 Image Encoder 的方式不一样

python 复制代码
# IP-Adapter
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds

# IP-Adapter Plus
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]

关键区别在于:

(1) .image_embeds属性来获取图像的嵌入表示,是 经过编码器最后一层(visual_projection)处理后的结果。

(2) .hidden_states[-2]:这行代码调用self.image_encoder时,通过 output_hidden_states=True 参数指示编码器输出除了最终输出之外的所有隐藏状态

  • self.image_encoder返回一个对象,其中hidden_states是一个包含所有隐藏层输出的列表。
  • 然后,通过索引 [-2] 访问这个列表的倒数第二个元素,即倒数第二个隐藏层的输出。
  • 未经过编码器最后一层(visual_projection)处理后的结果。

二、ImageProjModel 和 Resampler 的区别?

  • ImageProjModelResampler 都是用于将图像嵌入(image_embeds)投影到一个更高维度的空间,以便作为后续的生成引导。通过对 2.1 和 2.2 的两段代码,可以总结出差异:
  1. 网络结构

    • ImageProjModel:包含一个线性层self.proj用于投影,以及一个层归一化self.norm
    • Resampler:包含位置嵌入(如果apply_pos_embTrue)、输入投影self.proj_in、输出投影self.proj_out和层归一化self.norm_out。此外,它还包含一个由多个注意力和前馈网络层组成的模块列表self.layers,这些层用于处理输入数据。
  2. 注意力机制

    • ImageProjModel:没有使用注意力机制。
    • Resampler:使用自定义的PerceiverAttention模块进行注意力计算。
  3. 前馈网络

    • ImageProjModel:没有前馈网络。
    • Resampler:使用FeedForward模块,这是一个标准的前馈网络,通常用于Transformer架构中。
  4. 序列处理

    • ImageProjModel:没有特别处理序列数据。
    • Resampler:设计用于序列数据,包括可选的通过self.to_latents_from_mean_pooled_seq从平均池化序列生成额外的潜在表示。
  5. 可学习的参数

    • ImageProjModel:主要参数是线性层的权重。
    • Resampler:除了线性层的权重外,还包括可学习的潜在表示self.latents
  6. 输出

    • ImageProjModel:输出经过投影和归一化的图像嵌入。
    • Resampler:输出经过多层处理和归一化的序列特征。
  7. 特殊函数

    • Resampler中使用了masked_mean函数,这表明它可能用于处理带有掩码的序列数据,例如在处理变长序列时。

总结来说,ImageProjModel是一个简单的投影模型,可能用于将图像特征投影到一个多维空间以便于与其他类型的数据结合。而Resampler是一个更复杂的模型 (主要来源于论文^1^),设计用于处理序列数据,并通过注意力和前馈网络层进行特征转换。

2.1 ImageProjModel 代码

python 复制代码
class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

2.2 Resampler 代码

Flamingo 论文中的图像,可与代码对照理解。

python 复制代码
class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        
        print(embedding_dim, dim)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

  1. Flamingo: a Visual Language Model for Few-Shot Learning ↩︎
相关推荐
普密斯科技6 分钟前
手机外观边框缺陷视觉检测智慧方案
人工智能·计算机视觉·智能手机·自动化·视觉检测·集成测试
四口鲸鱼爱吃盐19 分钟前
Pytorch | 利用AI-FGTM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
lishanlu13621 分钟前
Pytorch分布式训练
人工智能·ddp·pytorch并行训练
是娜个二叉树!26 分钟前
图像处理基础 | 格式转换.rgb转.jpg 灰度图 python
开发语言·python
互联网杂货铺30 分钟前
Postman接口测试:全局变量/接口关联/加密/解密
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·postman
日出等日落34 分钟前
从零开始使用MaxKB打造本地大语言模型智能问答系统与远程交互
人工智能·语言模型·自然语言处理
三木吧43 分钟前
开发微信小程序的过程与心得
人工智能·微信小程序·小程序
云边有个稻草人1 小时前
AIGC与娱乐产业:颠覆创意与生产的新力量
aigc·娱乐
whaosoft-1431 小时前
w~视觉~3D~合集5
人工智能