【扩散模型(六)】Stable Diffusion 3 diffusers 源码详解1-推理代码-文本处理部分

系列文章目录


文章目录

  • 系列文章目录
  • 前言
  • 一、文本处理的整体流程
  • [二、Text Encoder 1、2(CLIP)](#二、Text Encoder 1、2(CLIP))
    • [1. 模型部分](#1. 模型部分)
    • [2. 两个 Text Encoder 的输入和输出](#2. 两个 Text Encoder 的输入和输出)
  • [三、Text Encoder 3(T5)](#三、Text Encoder 3(T5))
  • 其他

前言

下图为《Scaling Rectified Flow Transformers for High-Resolution Image Synthesis》 (ICML 2024 )中的 SD3 架构图。


一、文本处理的整体流程

下面流程图只对正向提示词进行了梳理,负向提示词的流程并无差异。

本文分析的源代码为 diffusers 包中的 SD3 pipeline (位置在/path/to/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py),文本处理部分主要为 其中 __call__() 函数调用的 self.encode_prompt() 函数,主要涉及了 3 个 text encoder 以及对应的 3 个 tokenizer。

其输入输出如下:

python 复制代码
 (
     prompt_embeds,
     negative_prompt_embeds,
     pooled_prompt_embeds,
     negative_pooled_prompt_embeds,
 ) = self.encode_prompt(
     prompt=prompt,
     prompt_2=prompt_2,
     prompt_3=prompt_3,
     negative_prompt=negative_prompt,
     negative_prompt_2=negative_prompt_2,
     negative_prompt_3=negative_prompt_3,
     do_classifier_free_guidance=self.do_classifier_free_guidance,
     prompt_embeds=prompt_embeds,
     negative_prompt_embeds=negative_prompt_embeds,
     pooled_prompt_embeds=pooled_prompt_embeds,
     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
     device=device,
     clip_skip=self.clip_skip,
     num_images_per_prompt=num_images_per_prompt,
     max_sequence_length=max_sequence_length,
 )

输入:

  • 其中 prompt 和 negative_prompt 为输入的字符串
  • 其他的 prompt_2、 prompt_3、 negative_prompt_2、 negative_prompt_3、prompt_embeds、 negative_prompt_embeds、pooled_prompt_embeds、negative_pooled_prompt_embeds 均为 None
  • do_classifier_free_guidance 一般都是 True
  • max_sequence_length = 256

具体而言是在 encode_prompt 函数中,通过两次 _get_clip_prompt_embeds_get_t5_prompt_embeds 来调用 3 个 Text Encoder。

python 复制代码
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
    prompt=prompt,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
    prompt=prompt_2,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

t5_prompt_embed = self._get_t5_prompt_embeds(
    prompt=prompt_3,
    num_images_per_prompt=num_images_per_prompt,
    max_sequence_length=max_sequence_length,
    device=device,
)

二、Text Encoder 1、2(CLIP)

1. 模型部分

  • 根据输入的 clip_tokenizers、clip_text_encoders 序号分别选择 text_encoder (CLIP L/14^1^) 或者 text_encoder_2 (OpenCLIP bigG/14^2^)。
  • 从下面初始化代码可以看出,二者 text_encodertext_encoder_2 采用的类一致,所以二者的区别主要是模型权重以及 config 不同。
python 复制代码
...
def __init__(...
     text_encoder: CLIPTextModelWithProjection,
     tokenizer: CLIPTokenizer,
     text_encoder_2: CLIPTextModelWithProjection,
     tokenizer_2: CLIPTokenizer,
...

 def _get_clip_prompt_embeds(
     self,
     prompt: Union[str, List[str]],
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
     clip_model_index: int = 0,
 ):
     device = device or self._execution_device

     clip_tokenizers = [self.tokenizer, self.tokenizer_2]
     clip_text_encoders = [self.text_encoder, self.text_encoder_2]

     tokenizer = clip_tokenizers[clip_model_index]
     text_encoder = clip_text_encoders[clip_model_index]

在下载的 SD3 模型权重文件中,/path/to/stable-diffusion-3-medium-diffusers 可以找到 text_encodertext_encoder_2 子目录,对比其中的 config(下图中左边为 text_encoder ,右边为 text_encoder_2 ),可以知道二者更具体的不同之处:

  1. hidden_size 不同:768 vs 1280
  2. hidden_act: quick_gelu vs gelu
  3. intermediate_size 不同:3072 vs 5120
  4. "num_attention_heads" 和 "num_hidden_layers":12/12 vs 20/32
  5. projection_dim 不同:768 vs 1280
  • 从以上 config 中,可以明显看出 text_encoder_2 (OpenCLIP bigG/14) 确实更加 big。
  • 两个 Text Encoder 最终的输出也和上文 "一、文本处理的整体流程 " 中的流程图一致,分别输出 [n, 77, 768 ] 和 [n, 77, 1280]。
    • n 为推理时的 num_images_per_prompt,每个 prompt 的出图数量。

2. 两个 Text Encoder 的输入和输出

  • 二者的输入是相同的 prompt,得到输出为不同的两对 prompt_embed, pooled_prompt_embedprompt_2_embed, pooled_prompt_2_embed
  • 其中,
    • prompt_embed [n, 77, 768 ] 和 prompt_2_embed [n, 77, 1280]为主要的 prompt 特征,并在后续 cat 到一起,得到 clip_prompt_embeds [n, 77, 2048]。
    • pooled_prompt_embed 和 pooled_prompt_2_embed 也一样 cat,
    • 两种特质的区别:prompt_embed(prompt_2_embed)是更主要/细粒度的文本特征 、而 pooled_prompt_embed(pooled_prompt_2_embed)是更粗粒度的文本特征
    • 原文:However, as the pooled text representation retains only coarse-grained information about the text input ^3^, the network also requires information from the sequence representation c t x t c_{txt} ctxt.
python 复制代码
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
    prompt=prompt,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
    prompt=prompt_2,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
...
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

三、Text Encoder 3(T5)

T5EncoderModel 的调用则更简洁一点,输入同样是 prompt,并且只有一个输出。

python 复制代码
def __init__(...
	text_encoder_3: T5EncoderModel,
       tokenizer_3: T5TokenizerFast,
...

t5_prompt_embed = self._get_t5_prompt_embeds(
    prompt=prompt_3,
    num_images_per_prompt=num_images_per_prompt,
    max_sequence_length=max_sequence_length,
    device=device,
)

# 实际为 clip_prompt_embeds = torch.nn.functional.pad(
#    clip_prompt_embeds, (0, 4096-2048)
#),即在后面 2048 个维度上 pad 全 0. 
clip_prompt_embeds = torch.nn.functional.pad(
    clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)

# 在序列长度的维度(-2)上 cat 到一起,得到 77+256 = 333 的长度
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
  • 作用:增强对复杂文本的生成能力。
  • 原文:T5 对于复杂的提示词很重要,例如涉及高度细节或拼写较长的文本(第2行和第3行)。然而,对于大多数提示,作者发现在推理时删除T5仍然可以获得具有竞争力的性能。

其他

强烈安利另外一位博主的文章:

  1. Stable Diffusion1.5网络结构-超详细原创
  2. Stable Diffusion XL网络结构-超详细原创

  1. Learning transferable visual models from natural language supervision, 2021. ↩︎

  2. Reproducible scaling laws for contrastive language-image learning. In 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2023. doi: 10.1109/cvpr52729.2023.00276. URL http://dx.doi.org/10.1109/CVPR52729.2 023.00276. ↩︎

  3. Sdxl: Improving latent diffusion models for high-resolution image synthesis, 2023. ↩︎

相关推荐
亚图跨际1 分钟前
Python和R荧光分光光度法
开发语言·python·r语言·荧光分光光度法
飞凌嵌入式3 分钟前
飞凌嵌入式T113-i开发板RISC-V核的实时应用方案
人工智能·嵌入式硬件·嵌入式·risc-v·飞凌嵌入式
sinovoip5 分钟前
Banana Pi BPI-CanMV-K230D-Zero 采用嘉楠科技 K230D RISC-V芯片设计
人工智能·科技·物联网·开源·risc-v
谢眠18 分钟前
深度学习day3-自动微分
python·深度学习·机器学习
搏博27 分钟前
神经网络问题之一:梯度消失(Vanishing Gradient)
人工智能·机器学习
z千鑫27 分钟前
【人工智能】深入理解PyTorch:从0开始完整教程!全文注解
人工智能·pytorch·python·gpt·深度学习·ai编程
YRr YRr35 分钟前
深度学习:神经网络的搭建
人工智能·深度学习·神经网络
威桑38 分钟前
CMake + mingw + opencv
人工智能·opencv·计算机视觉
爱喝热水的呀哈喽41 分钟前
torch张量与函数表达式写法
人工智能·pytorch·深度学习
MessiGo1 小时前
Python 爬虫 (1)基础 | 基础操作
开发语言·python