【扩散模型(七)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

系列文章目录


文章目录

  • 系列文章目录
  • [一、Image Encoder 的使用区别?](#一、Image Encoder 的使用区别?)
    • [1.1 Image Encoder 组成](#1.1 Image Encoder 组成)
    • [1.2 .hidden_states[-2] 表示什么?](#1.2 .hidden_states[-2] 表示什么?)
  • [二、ImageProjModel 和 Resampler 的区别?](#二、ImageProjModel 和 Resampler 的区别?)
    • [2.1 ImageProjModel 代码](#2.1 ImageProjModel 代码)
    • [2.2 Resampler 代码](#2.2 Resampler 代码)

从下图中可以很直观地看出有两处不同,第一是使用 image encoder 的方式不同、得到了不同的图像特征,第二是将原有的简单 ImageProjModel 替换成了更加复杂的 Resampler 以提取更多的图像信息。

一、Image Encoder 的使用区别?

1.1 Image Encoder 组成

Image Encoder 是 CLIPVisionModelWithProjection 类(位于 /path/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py)

根据其构造函数,可见分为两块(1)vision_model【CLIPVisionTransformer】 和 (2)visual_projection【Linear】

python 复制代码
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

1.视觉模型(vision_model

  • 这通常是一个处理输入图像的视觉转换器(ViT)。
  • 它从图像中提取特征并输出表示,通常包括总结整个图像的"合并"输出。
  • 视觉模型处理了理解图像内容的繁重任务。

2.视觉投影(visual_projection

  • 这是一个线性层,将视觉模型的高维输出映射到低维空间。
  • 在 CLIP 这样的多模态模型中,投影会将图像表示与文本表示对齐
  • 它确保图像嵌入与文本嵌入位于同一空间,便于比较或组合。

1.2 .hidden_states[-2] 表示什么?

我们仔细对比 IP-Adapter 和 IP-Adapter Plus 的细节,会发现采用 Image Encoder 的方式不一样

python 复制代码
# IP-Adapter
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds

# IP-Adapter Plus
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]

关键区别在于:

(1) .image_embeds属性来获取图像的嵌入表示,是 经过编码器最后一层(visual_projection)处理后的结果。

(2) .hidden_states[-2]:这行代码调用self.image_encoder时,通过 output_hidden_states=True 参数指示编码器输出除了最终输出之外的所有隐藏状态

  • self.image_encoder返回一个对象,其中hidden_states是一个包含所有隐藏层输出的列表。
  • 然后,通过索引 [-2] 访问这个列表的倒数第二个元素,即倒数第二个隐藏层的输出。
  • 未经过编码器最后一层(visual_projection)处理后的结果。

二、ImageProjModel 和 Resampler 的区别?

  • ImageProjModelResampler 都是用于将图像嵌入(image_embeds)投影到一个更高维度的空间,以便作为后续的生成引导。通过对 2.1 和 2.2 的两段代码,可以总结出差异:
  1. 网络结构

    • ImageProjModel:包含一个线性层self.proj用于投影,以及一个层归一化self.norm
    • Resampler:包含位置嵌入(如果apply_pos_embTrue)、输入投影self.proj_in、输出投影self.proj_out和层归一化self.norm_out。此外,它还包含一个由多个注意力和前馈网络层组成的模块列表self.layers,这些层用于处理输入数据。
  2. 注意力机制

    • ImageProjModel:没有使用注意力机制。
    • Resampler:使用自定义的PerceiverAttention模块进行注意力计算。
  3. 前馈网络

    • ImageProjModel:没有前馈网络。
    • Resampler:使用FeedForward模块,这是一个标准的前馈网络,通常用于Transformer架构中。
  4. 序列处理

    • ImageProjModel:没有特别处理序列数据。
    • Resampler:设计用于序列数据,包括可选的通过self.to_latents_from_mean_pooled_seq从平均池化序列生成额外的潜在表示。
  5. 可学习的参数

    • ImageProjModel:主要参数是线性层的权重。
    • Resampler:除了线性层的权重外,还包括可学习的潜在表示self.latents
  6. 输出

    • ImageProjModel:输出经过投影和归一化的图像嵌入。
    • Resampler:输出经过多层处理和归一化的序列特征。
  7. 特殊函数

    • Resampler中使用了masked_mean函数,这表明它可能用于处理带有掩码的序列数据,例如在处理变长序列时。

总结来说,ImageProjModel是一个简单的投影模型,可能用于将图像特征投影到一个多维空间以便于与其他类型的数据结合。而Resampler是一个更复杂的模型 (主要来源于论文^1^),设计用于处理序列数据,并通过注意力和前馈网络层进行特征转换。

2.1 ImageProjModel 代码

python 复制代码
class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

2.2 Resampler 代码

Flamingo 论文中的图像,可与代码对照理解。

python 复制代码
class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        
        print(embedding_dim, dim)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

  1. Flamingo: a Visual Language Model for Few-Shot Learning ↩︎
相关推荐
q5673152322 分钟前
在 Bash 中获取 Python 模块变量列
开发语言·python·bash
是萝卜干呀23 分钟前
Backend - Python 爬取网页数据并保存在Excel文件中
python·excel·table·xlwt·爬取网页数据
代码欢乐豆24 分钟前
数据采集之selenium模拟登录
python·selenium·测试工具
喵~来学编程啦31 分钟前
【论文精读】LPT: Long-tailed prompt tuning for image classification
人工智能·深度学习·机器学习·计算机视觉·论文笔记
深圳市青牛科技实业有限公司44 分钟前
【青牛科技】应用方案|D2587A高压大电流DC-DC
人工智能·科技·单片机·嵌入式硬件·机器人·安防监控
狂奔solar1 小时前
yelp数据集上识别潜在的热门商家
开发语言·python
Tassel_YUE1 小时前
网络自动化04:python实现ACL匹配信息(主机与主机信息)
网络·python·自动化
水豚AI课代表1 小时前
分析报告、调研报告、工作方案等的提示词
大数据·人工智能·学习·chatgpt·aigc
几两春秋梦_1 小时前
符号回归概念
人工智能·数据挖掘·回归
聪明的墨菲特i1 小时前
Python爬虫学习
爬虫·python·学习