系列文章目录
- 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
- 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
- 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
- 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
- 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
- 【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)
文章目录
- 系列文章目录
- [一、Image Encoder 的使用区别?](#一、Image Encoder 的使用区别?)
-
- [1.1 Image Encoder 组成](#1.1 Image Encoder 组成)
- [1.2 .hidden_states[-2] 表示什么?](#1.2 .hidden_states[-2] 表示什么?)
- [二、ImageProjModel 和 Resampler 的区别?](#二、ImageProjModel 和 Resampler 的区别?)
-
- [2.1 ImageProjModel 代码](#2.1 ImageProjModel 代码)
- [2.2 Resampler 代码](#2.2 Resampler 代码)
从下图中可以很直观地看出有两处不同,第一是使用 image encoder 的方式不同、得到了不同的图像特征,第二是将原有的简单 ImageProjModel 替换成了更加复杂的 Resampler 以提取更多的图像信息。
一、Image Encoder 的使用区别?
1.1 Image Encoder 组成
Image Encoder 是 CLIPVisionModelWithProjection 类(位于 /path/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py)
根据其构造函数,可见分为两块(1)vision_model【CLIPVisionTransformer】 和 (2)visual_projection【Linear】
python
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
self.vision_model = CLIPVisionTransformer(config)
self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
# Initialize weights and apply final processing
self.post_init()
1.视觉模型(vision_model
):
- 这通常是一个处理输入图像的视觉转换器(ViT)。
- 它从图像中提取特征并输出表示,通常包括总结整个图像的"合并"输出。
- 视觉模型处理了理解图像内容的繁重任务。
2.视觉投影(visual_projection
):
- 这是一个线性层,将视觉模型的高维输出映射到低维空间。
- 在 CLIP 这样的多模态模型中,
投影会将图像表示与文本表示对齐
。 - 它确保图像嵌入与文本嵌入位于同一空间,便于比较或组合。
1.2 .hidden_states[-2] 表示什么?
我们仔细对比 IP-Adapter 和 IP-Adapter Plus 的细节,会发现采用 Image Encoder 的方式不一样
python
# IP-Adapter
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
# IP-Adapter Plus
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
关键区别在于:
(1) .image_embeds属性来获取图像的嵌入表示,是 经过编码器最后一层(visual_projection
)处理后的结果。
(2) .hidden_states[-2]:这行代码调用self.image_encoder时,通过 output_hidden_states=True 参数指示编码器输出除了最终输出之外的所有隐藏状态。
- self.image_encoder返回一个对象,其中hidden_states是一个包含所有隐藏层输出的列表。
- 然后,通过索引 [-2] 访问这个列表的倒数第二个元素,即倒数第二个隐藏层的输出。
- 即未经过编码器最后一层(
visual_projection
)处理后的结果。
二、ImageProjModel 和 Resampler 的区别?
ImageProjModel
和Resampler
都是用于将图像嵌入(image_embeds
)投影到一个更高维度的空间,以便作为后续的生成引导。通过对 2.1 和 2.2 的两段代码,可以总结出差异:
-
网络结构:
ImageProjModel
:包含一个线性层self.proj
用于投影,以及一个层归一化self.norm
。Resampler
:包含位置嵌入(如果apply_pos_emb
为True
)、输入投影self.proj_in
、输出投影self.proj_out
和层归一化self.norm_out
。此外,它还包含一个由多个注意力和前馈网络层组成的模块列表self.layers
,这些层用于处理输入数据。
-
注意力机制:
ImageProjModel
:没有使用注意力机制。Resampler
:使用自定义的PerceiverAttention
模块进行注意力计算。
-
前馈网络:
ImageProjModel
:没有前馈网络。Resampler
:使用FeedForward
模块,这是一个标准的前馈网络,通常用于Transformer架构中。
-
序列处理:
ImageProjModel
:没有特别处理序列数据。Resampler
:设计用于序列数据,包括可选的通过self.to_latents_from_mean_pooled_seq
从平均池化序列生成额外的潜在表示。
-
可学习的参数:
ImageProjModel
:主要参数是线性层的权重。Resampler
:除了线性层的权重外,还包括可学习的潜在表示self.latents
。
-
输出:
ImageProjModel
:输出经过投影和归一化的图像嵌入。Resampler
:输出经过多层处理和归一化的序列特征。
-
特殊函数:
Resampler
中使用了masked_mean
函数,这表明它可能用于处理带有掩码的序列数据,例如在处理变长序列时。
总结来说,ImageProjModel
是一个简单的投影模型,可能用于将图像特征投影到一个多维空间以便于与其他类型的数据结合。而Resampler
是一个更复杂的模型 (主要来源于论文^1^),设计用于处理序列数据,并通过注意力和前馈网络层进行特征转换。
2.1 ImageProjModel 代码
python
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
2.2 Resampler 代码
Flamingo 论文中的图像,可与代码对照理解。
python
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
print(embedding_dim, dim)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
- Flamingo: a Visual Language Model for Few-Shot Learning ↩︎