【RoPE】Flux 中的 Image Tokenization

文章目录

  • [`image_tokens` 和 `image_ids` 的详细获取过程](#image_tokensimage_ids 的详细获取过程)
    • [`image_tokens` 的获取过程](#image_tokens 的获取过程)
      • [`_pack_latents` 方法详解](#_pack_latents 方法详解)
    • [`image_ids` 的获取过程](#image_ids 的获取过程)
      • [`_prepare_latent_image_ids` 方法详解](#_prepare_latent_image_ids 方法详解)
    • 总结

Flux 中的 RoPE 代码解析

image_tokensimage_ids 的详细获取过程

image_tokensimage_ids 是在扩散模型中处理图像潜在表示的两个关键组件。下面我将详细解析它们的获取过程。

image_tokens 的获取过程

image_tokens 是通过 _pack_latents 方法从VAE编码后的潜在表示中获得的:

python 复制代码
images_tokens = pipeline._pack_latents(images, *images.shape)

_pack_latents 方法详解

python 复制代码
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    return latents

这个方法执行以下步骤:

  1. 重塑潜在表示

    python 复制代码
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    • 将形状为 [batch_size, num_channels_latents, height, width] 的潜在表示重塑为 [batch_size, num_channels_latents, height//2, 2, width//2, 2]
    • 这实际上是将每个2×2的区域分组在一起
  2. 重排维度

    python 复制代码
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    • 将维度从 [batch, channels, h//2, 2, w//2, 2] 重排为 [batch, h//2, w//2, channels, 2, 2]
    • 这将空间位置(h//2, w//2)放在前面,将通道和2×2块放在后面
  3. 最终重塑

    python 复制代码
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
    • 将张量重塑为 [batch_size, (h//2)*(w//2), channels*4]
    • 这将每个空间位置的所有通道和2×2块合并为一个特征向量

结果是一个形状为 [batch_size, (height//2)*(width//2), num_channels_latents*4] 的张量,其中:

  • 第一维是批次大小
  • 第二维是空间位置的数量(每个2×2块作为一个位置)
  • 第三维是每个位置的特征维度(原始通道数×4)

image_ids 的获取过程

image_ids 是通过 _prepare_latent_image_ids 方法生成的:

python 复制代码
images_ids = pipeline._prepare_latent_image_ids(
    images.shape[0],
    images.shape[2],
    images.shape[3],
    pipeline.device,
    pipeline.dtype,
)

_prepare_latent_image_ids 方法详解

python 复制代码
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    return latent_image_ids.to(device=device, dtype=dtype)

这个方法执行以下步骤:

  1. 创建基础张量

    python 复制代码
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    • 创建一个形状为 [height//2, width//2, 3] 的全零张量
    • 这个张量将用于存储每个空间位置的坐标信息
  2. 填充行坐标

    python 复制代码
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    • 在第二个通道(索引1)填充行坐标,范围从0到height//2-1
    • [:, None] 使得行坐标在垂直方向上重复
  3. 填充列坐标

    python 复制代码
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
    • 在第三个通道(索引2)填充列坐标,范围从0到width//2-1
    • [None, :] 使得列坐标在水平方向上重复
  4. 重塑张量

    python 复制代码
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )
    • 将张量从 [height//2, width//2, 3] 重塑为 [(height//2)*(width//2), 3]
    • 这将二维空间位置展平为一维序列
  5. 转换设备和数据类型

    python 复制代码
    return latent_image_ids.to(device=device, dtype=dtype)
    • 将张量移动到指定设备并转换为指定数据类型

结果是一个形状为 [(height//2)*(width//2), 3] 的张量,其中:

  • 第一维是展平后的空间位置数量
  • 第二维包含3个值,第一个值始终为0,第二个值是行坐标,第三个值是列坐标

总结

  1. image_tokens

    • 形状:[batch_size, (height//2)*(width//2), num_channels_latents*4]
    • 内容:每个空间位置的特征向量,通过重组原始潜在表示获得
    • 作用:包含图像的语义和视觉信息,用于模型的条件生成
  2. image_ids

    • 形状:[(height//2)*(width//2), 3]
    • 内容:每个空间位置的坐标信息(0, 行, 列)
    • 作用:帮助模型定位和组织特征,维持空间结构信息

这两个组件共同工作,使模型能够理解和处理图像的潜在表示,同时保持空间结构信息,这对于图像生成和编辑任务至关重要。

相关推荐
callJJ3 小时前
Spring AI ImageModel 完全指南:用 OpenAI DALL-E 生成图像
大数据·人工智能·spring·openai·springai·图像模型
李日灐3 小时前
C++进阶必备:红黑树从 0 到 1: 手撕底层,带你搞懂平衡二叉树的平衡逻辑与黑高检验
开发语言·数据结构·c++·后端·面试·红黑树·自平衡二叉搜索树
铁蛋AI编程实战3 小时前
2026 大模型推理框架测评:vLLM 0.5/TGI 2.0/TensorRT-LLM 1.8/DeepSpeed-MII 0.9 性能与成本防线对比
人工智能·机器学习·vllm
23遇见3 小时前
CANN ops-nn 仓库高效开发指南:从入门到精通
人工智能
SAP工博科技3 小时前
SAP 公有云 ERP 多工厂多生产线数据统一管理技术实现解析
大数据·运维·人工智能
芷栀夏3 小时前
CANN ops-math:异构计算场景下基础数学算子的深度优化与硬件亲和设计解析
人工智能·cann
爱吃泡芙的小白白3 小时前
深入解析CNN中的BN层:从稳定训练到前沿演进
人工智能·神经网络·cnn·梯度爆炸·bn·稳定模型
Risehuxyc3 小时前
备份三个PHP程序
android·开发语言·php
聆风吟º3 小时前
CANN runtime 性能优化:异构计算下运行时组件的效率提升与资源利用策略
人工智能·深度学习·神经网络·cann