【扩散模型(三)】IP-Adapter 源码详解1-输入篇

系列文章目录

文章目录


整体结构图+代码中的变量名

IP-Adapter 源码:https://github.com/tencent-ailab/IP-Adapter

本文就基于 SD1.5 的 IP-Adapter 训练代码 tutorial_train.py 为例,进行代码和结构图的解释。


一、IP-Adapter 做了什么?

如上图所示,插入了图中的最上面一条分支(图像输入条件分支):

  1. 蓝色的(无需训练的) Image Encoder
  2. 红色的(需训练的)Linear + LN(LayerNorm)
  3. 红色的(需训练的)、针对图像(Image Prompt)的 Cross Attention。

在论文中也提到,具体分别是:

  1. Image Encoder 是 pretrained CLIP image encoder
  2. 线性层和层归一化 Linear + LN(LayerNorm^1^):
    • 为了有效地分解全局图像嵌入,作者使用一个小的可训练投影网络(projection network)将图像嵌入投影到长度为N的特征序列中(在本研究中使用N=4),图像特征的维数与预训练的扩散模型中文本特征的维数相同。使用的投影网络由线性层和层归一化组成。
  3. Decoupled Cross-Attention 中,做法是在原来的 UNet 的 Cross-Attention 中加了一层 Cross-Attention。
    • 如原文提到 "we add a new cross-attention layer for each cross-attention layer in the original UNet model to insert image features."

二、对应的代码实现

1.模型输入

先简单看下模型的训练时的输入,即 /path/IP-Adapter/tutorial_train.py 中 main() 函数内的 dataloader 部分,下面代码通过调用 MyDataset 类来实现了 train_dataloader 的构建。

python 复制代码
    # dataloader
    train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

对于实际训练使用的数据则为从 train_dataloader 中取的:

  1. batch["images"]
    • 用来得到形状后,生成随机噪声。
    • 具体如下代码所示,通过 vae.encoder 得到 latents后
    • 通过 torch.randn_like(latents) 按照 latents 张量的形状生成一个随机的噪声张量 noise
  2. batch["clip_images"]
    • 通过 image_encoder 得到 image_embeds 图像特征
  3. batch["drop_image_embeds"]
    • 文中有提到会随机通过随机丢弃条件信息(如文本或图像嵌入),使得模型会学会在有条件和无条件的情况下进行预测(生成图像)
  4. batch["text_input_ids"] 是文本输入,通过一个 text_encoder 后得到文本特征 encoder_hidden_states
python 复制代码
  for step, batch in enumerate(train_dataloader):
      load_data_time = time.perf_counter() - begin

        with torch.no_grad():
            latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    
        with torch.no_grad():
            image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
        image_embeds_ = []
        for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
            if drop_image_embed == 1:
                image_embeds_.append(torch.zeros_like(image_embed))
            else:
                image_embeds_.append(image_embed)
        image_embeds = torch.stack(image_embeds_)
    
        with torch.no_grad():
            encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] # pooled_prompt_embeds?

2.Linear 和 LN(LayerNorm)

以 SD1.5 + IP-Adapter 的训练代码为例:

下方代码为 /path/IP-Adapter/tutorial_train.py 中 main() 函数内,调用了定义好的 ImageProjModel 类

python 复制代码
#ip-adapter
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=4,
    )

下方代码为 /path/IP-Adapter/ip_adapter/ip_adapter.py 被调用的 ImageProjModel 类,在构造函数 __init__ 中可以看到有前文提到的 Linear 和 LayerNorm。

python 复制代码
class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

总结

本文详解了IP-Adapter 训练源码中的输入部分,下篇则详解核心部分,针对图像输入的 Cross-Attention。


  1. Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016 ↩︎
相关推荐
十三画者16 分钟前
【工具】survex一个解释机器学习生存模型的R包
人工智能·机器学习·数据挖掘·数据分析·r语言·数据可视化
GIS数据转换器21 分钟前
智能化水利监管:无人机视频在违章行为识别中的应用
大数据·人工智能·物联网·无人机·智慧城市
努力进修26 分钟前
通义万相 2.1 × 蓝耘智算:AIGC 界的「黄金搭档」如何重塑创作未来?
人工智能·aigc·deepseek·蓝耘·通义万相2.1
测试者家园31 分钟前
AI在网络安全中的新角色:智能检测与预测防御
软件测试·人工智能·安全·web安全·网络安全·质量效能
伊织code32 分钟前
python-docx - 读写更新 .docx 文件(Microsoft Word 2007+ )
python·microsoft·word·docx·python-docx·openxml
徐礼昭|商派软件市场负责人35 分钟前
如何搭建一套工业品跨境出海B2B商城平台?|商派B2B系统解决方案
人工智能·跨境出海·工业品b2b·b2b平台
离开地球表面_9938 分钟前
LLM、Prompt、AI Agent、RAG... 一网打尽大模型热门概念
人工智能·llm
羊小猪~~44 分钟前
深度学习项目--基于DenseNet网络的“乳腺癌图像识别”,准确率090%+,pytorch复现
网络·人工智能·pytorch·python·深度学习·机器学习·分类
Eagle_Clark1 小时前
提示词工程(Prompt Engineering)
人工智能·aigc·openai
蚝油菜花1 小时前
Deep Research Web UI:开源版Deep Research!接入DeepSeek一键生成深度研究报告,可视化检索过程
人工智能·开源