【RoPE】Flux 中的 Image Tokenization

文章目录

  • [`image_tokens` 和 `image_ids` 的详细获取过程](#image_tokensimage_ids 的详细获取过程)
    • [`image_tokens` 的获取过程](#image_tokens 的获取过程)
      • [`_pack_latents` 方法详解](#_pack_latents 方法详解)
    • [`image_ids` 的获取过程](#image_ids 的获取过程)
      • [`_prepare_latent_image_ids` 方法详解](#_prepare_latent_image_ids 方法详解)
    • 总结

Flux 中的 RoPE 代码解析

image_tokensimage_ids 的详细获取过程

image_tokensimage_ids 是在扩散模型中处理图像潜在表示的两个关键组件。下面我将详细解析它们的获取过程。

image_tokens 的获取过程

image_tokens 是通过 _pack_latents 方法从VAE编码后的潜在表示中获得的:

python 复制代码
images_tokens = pipeline._pack_latents(images, *images.shape)

_pack_latents 方法详解

python 复制代码
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    return latents

这个方法执行以下步骤:

  1. 重塑潜在表示

    python 复制代码
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    • 将形状为 [batch_size, num_channels_latents, height, width] 的潜在表示重塑为 [batch_size, num_channels_latents, height//2, 2, width//2, 2]
    • 这实际上是将每个2×2的区域分组在一起
  2. 重排维度

    python 复制代码
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    • 将维度从 [batch, channels, h//2, 2, w//2, 2] 重排为 [batch, h//2, w//2, channels, 2, 2]
    • 这将空间位置(h//2, w//2)放在前面,将通道和2×2块放在后面
  3. 最终重塑

    python 复制代码
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
    • 将张量重塑为 [batch_size, (h//2)*(w//2), channels*4]
    • 这将每个空间位置的所有通道和2×2块合并为一个特征向量

结果是一个形状为 [batch_size, (height//2)*(width//2), num_channels_latents*4] 的张量,其中:

  • 第一维是批次大小
  • 第二维是空间位置的数量(每个2×2块作为一个位置)
  • 第三维是每个位置的特征维度(原始通道数×4)

image_ids 的获取过程

image_ids 是通过 _prepare_latent_image_ids 方法生成的:

python 复制代码
images_ids = pipeline._prepare_latent_image_ids(
    images.shape[0],
    images.shape[2],
    images.shape[3],
    pipeline.device,
    pipeline.dtype,
)

_prepare_latent_image_ids 方法详解

python 复制代码
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    return latent_image_ids.to(device=device, dtype=dtype)

这个方法执行以下步骤:

  1. 创建基础张量

    python 复制代码
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    • 创建一个形状为 [height//2, width//2, 3] 的全零张量
    • 这个张量将用于存储每个空间位置的坐标信息
  2. 填充行坐标

    python 复制代码
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    • 在第二个通道(索引1)填充行坐标,范围从0到height//2-1
    • [:, None] 使得行坐标在垂直方向上重复
  3. 填充列坐标

    python 复制代码
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
    • 在第三个通道(索引2)填充列坐标,范围从0到width//2-1
    • [None, :] 使得列坐标在水平方向上重复
  4. 重塑张量

    python 复制代码
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )
    • 将张量从 [height//2, width//2, 3] 重塑为 [(height//2)*(width//2), 3]
    • 这将二维空间位置展平为一维序列
  5. 转换设备和数据类型

    python 复制代码
    return latent_image_ids.to(device=device, dtype=dtype)
    • 将张量移动到指定设备并转换为指定数据类型

结果是一个形状为 [(height//2)*(width//2), 3] 的张量,其中:

  • 第一维是展平后的空间位置数量
  • 第二维包含3个值,第一个值始终为0,第二个值是行坐标,第三个值是列坐标

总结

  1. image_tokens

    • 形状:[batch_size, (height//2)*(width//2), num_channels_latents*4]
    • 内容:每个空间位置的特征向量,通过重组原始潜在表示获得
    • 作用:包含图像的语义和视觉信息,用于模型的条件生成
  2. image_ids

    • 形状:[(height//2)*(width//2), 3]
    • 内容:每个空间位置的坐标信息(0, 行, 列)
    • 作用:帮助模型定位和组织特征,维持空间结构信息

这两个组件共同工作,使模型能够理解和处理图像的潜在表示,同时保持空间结构信息,这对于图像生成和编辑任务至关重要。

相关推荐
NAGNIP1 天前
轻松搞懂全连接神经网络结构!
人工智能·算法·面试
moshuying1 天前
别让AI焦虑,偷走你本该有的底气
前端·人工智能
董董灿是个攻城狮1 天前
零基础带你用 AI 搞定命令行
人工智能
喝拿铁写前端1 天前
Dify 构建 FE 工作流:前端团队可复用 AI 工作流实战
前端·人工智能
阿里云大数据AI技术1 天前
阿里云 EMR Serverless Spark + DataWorks 技术实践:引领企业 Data+AI 一体化转型
人工智能
billhan20161 天前
MCP 深入理解:协议原理与自定义开发
人工智能
用户8356290780511 天前
无需 Office:Python 批量转换 PPT 为图片
后端·python
Jahzo1 天前
openclaw桌面端体验--ClawX
人工智能·github
billhan20161 天前
Agent 开发全流程:从概念到生产
人工智能