文章目录
- [`image_tokens` 和 `image_ids` 的详细获取过程](#
image_tokens和image_ids的详细获取过程) -
- [`image_tokens` 的获取过程](#
image_tokens的获取过程) -
- [`_pack_latents` 方法详解](#
_pack_latents方法详解)
- [`_pack_latents` 方法详解](#
- [`image_ids` 的获取过程](#
image_ids的获取过程) -
- [`_prepare_latent_image_ids` 方法详解](#
_prepare_latent_image_ids方法详解)
- [`_prepare_latent_image_ids` 方法详解](#
- 总结
- [`image_tokens` 的获取过程](#
Flux 中的 RoPE 代码解析
image_tokens 和 image_ids 的详细获取过程
image_tokens 和 image_ids 是在扩散模型中处理图像潜在表示的两个关键组件。下面我将详细解析它们的获取过程。
image_tokens 的获取过程
image_tokens 是通过 _pack_latents 方法从VAE编码后的潜在表示中获得的:
python
images_tokens = pipeline._pack_latents(images, *images.shape)
_pack_latents 方法详解
python
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
这个方法执行以下步骤:
-
重塑潜在表示:
pythonlatents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)- 将形状为
[batch_size, num_channels_latents, height, width]的潜在表示重塑为[batch_size, num_channels_latents, height//2, 2, width//2, 2] - 这实际上是将每个2×2的区域分组在一起
- 将形状为
-
重排维度:
pythonlatents = latents.permute(0, 2, 4, 1, 3, 5)- 将维度从
[batch, channels, h//2, 2, w//2, 2]重排为[batch, h//2, w//2, channels, 2, 2] - 这将空间位置(h//2, w//2)放在前面,将通道和2×2块放在后面
- 将维度从
-
最终重塑:
pythonlatents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)- 将张量重塑为
[batch_size, (h//2)*(w//2), channels*4] - 这将每个空间位置的所有通道和2×2块合并为一个特征向量
- 将张量重塑为
结果是一个形状为 [batch_size, (height//2)*(width//2), num_channels_latents*4] 的张量,其中:
- 第一维是批次大小
- 第二维是空间位置的数量(每个2×2块作为一个位置)
- 第三维是每个位置的特征维度(原始通道数×4)
image_ids 的获取过程
image_ids 是通过 _prepare_latent_image_ids 方法生成的:
python
images_ids = pipeline._prepare_latent_image_ids(
images.shape[0],
images.shape[2],
images.shape[3],
pipeline.device,
pipeline.dtype,
)
_prepare_latent_image_ids 方法详解
python
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
这个方法执行以下步骤:
-
创建基础张量:
pythonlatent_image_ids = torch.zeros(height // 2, width // 2, 3)- 创建一个形状为
[height//2, width//2, 3]的全零张量 - 这个张量将用于存储每个空间位置的坐标信息
- 创建一个形状为
-
填充行坐标:
pythonlatent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]- 在第二个通道(索引1)填充行坐标,范围从0到height//2-1
[:, None]使得行坐标在垂直方向上重复
-
填充列坐标:
pythonlatent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]- 在第三个通道(索引2)填充列坐标,范围从0到width//2-1
[None, :]使得列坐标在水平方向上重复
-
重塑张量:
pythonlatent_image_ids = latent_image_ids.reshape( latent_image_id_height * latent_image_id_width, latent_image_id_channels )- 将张量从
[height//2, width//2, 3]重塑为[(height//2)*(width//2), 3] - 这将二维空间位置展平为一维序列
- 将张量从
-
转换设备和数据类型:
pythonreturn latent_image_ids.to(device=device, dtype=dtype)- 将张量移动到指定设备并转换为指定数据类型
结果是一个形状为 [(height//2)*(width//2), 3] 的张量,其中:
- 第一维是展平后的空间位置数量
- 第二维包含3个值,第一个值始终为0,第二个值是行坐标,第三个值是列坐标
总结
-
image_tokens:- 形状:
[batch_size, (height//2)*(width//2), num_channels_latents*4] - 内容:每个空间位置的特征向量,通过重组原始潜在表示获得
- 作用:包含图像的语义和视觉信息,用于模型的条件生成
- 形状:
-
image_ids:- 形状:
[(height//2)*(width//2), 3] - 内容:每个空间位置的坐标信息(0, 行, 列)
- 作用:帮助模型定位和组织特征,维持空间结构信息
- 形状:
这两个组件共同工作,使模型能够理解和处理图像的潜在表示,同时保持空间结构信息,这对于图像生成和编辑任务至关重要。