【Stable Diffusion 1.5 】在 Unet 中每个 Cross Attention 块中的张量变化过程

系列文章目录


文章目录


前言

特征图 (Latent) 尺寸和注意力图(attention map)尺寸在扩散模型中有差异,是由于模型架构和注意力机制的特性决定的。

特征图和注意力图的尺寸差异原因

  1. 不同的功能目的

    • 特征图(Feature Maps):承载图像的语义和视觉特征,维持空间结构
    • 注意力图(Attention Maps):表示不同位置之间的关联强度,是一种关系矩阵
  2. UNet架构中的特征图尺寸

    在U-Net中,特征图的尺寸在不同层级有变化:

    • 输入图像通常是 512×512 或 256×256
    • 下采样路径(Encoder):尺寸逐渐缩小 (512→256→128→64→32→16...)
    • 上采样路径(Decoder):尺寸逐渐增大 (16→32→64→128→256→512...)

    在Break-a-Scene代码中,我们看到特征图尺寸被下采样到64×64:

    python 复制代码
    downsampled_mask = F.interpolate(input=max_masks, size=(64, 64))
  3. 注意力机制中的尺寸计算

    注意力机制处理的是"token"之间的关系,其中:

    • 自注意力(Self-Attention):特征图中的每个位置视为一个token
    • 交叉注意力(Cross-Attention):文本序列中的token与特征图中的位置建立关联

    如果特征图尺寸是h×w,则自注意力矩阵的尺寸是(hw)×(hw),这是一个平方关系

    在代码中,注意力图通常被下采样到16×16:

    python 复制代码
    GT_masks = F.interpolate(input=batch["instance_masks"][batch_idx], size=(16, 16))
  4. 计算效率考虑

    • 注意力计算的复杂度是O(n²),其中n是token数量
    • 对于64×64的特征图,如果直接计算自注意力,需要处理4096×4096的矩阵
    • 为了降低计算量,通常在较低分辨率(如16×16)的特征图上计算注意力,这样只需处理256×256的矩阵

在Break-a-Scene中的具体实现

在Break-a-Scene中,这些尺寸差异体现在:

  1. 两种不同的损失计算

    a. 掩码损失(Masked Loss):应用在64×64的 Latent 上

    python 复制代码
    max_masks = torch.max(batch["instance_masks"], axis=1).values
    downsampled_mask = F.interpolate(input=max_masks, size=(64, 64))
    model_pred = model_pred * downsampled_mask
    target = target * downsampled_mask

    b. 注意力损失(Attention Loss):应用在16×16的注意力图上

    python 复制代码
    GT_masks = F.interpolate(input=batch["instance_masks"][batch_idx], size=(16, 16))
    agg_attn = self.aggregate_attention(res=16, from_where=("up", "down"), is_cross=True, select=batch_idx)
  2. 注意力存储的筛选

    在存储注意力图时,只保留小尺寸的注意力图:

    python 复制代码
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32**2:  # 只保存小于或等于32×32的注意力图
            self.step_store[key].append(attn)
        return attn
  3. 注意力聚合

    在聚合不同层的注意力时,确保只使用匹配目标分辨率的注意力图:

    python 复制代码
    def aggregate_attention(self, res: int, from_where: List[str], is_cross: bool, select: int):
        # ...
        num_pixels = res**2
        for location in from_where:
            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
                if item.shape[1] == num_pixels:  # 只选择匹配分辨率的注意力图
                    cross_maps = item.reshape(self.args.train_batch_size, -1, res, res, item.shape[-1])[select]
                    out.append(cross_maps)
        # ...

总结

特征图和注意力图尺寸的差异主要是因为:

  1. 它们在模型中的功能不同
  2. 注意力计算的计算复杂度要求在较低分辨率上进行
  3. UNet架构中的不同层级有不同的特征图尺寸
  4. 为了平衡精度和计算效率,Break-a-Scene使用不同分辨率的特征图和注意力图来计算不同类型的损失

这种设计使得Break-a-Scene能够有效地学习token与图像区域之间的对应关系,同时保持计算效率。

相关推荐
love530love15 小时前
【避坑指南】提示词“闹鬼”?Stable Diffusion 自动注入神秘词汇 xiao yi xian 排查全记录
人工智能·windows·stable diffusion·model keyword
世界尽头与你15 小时前
Stable Diffusion web UI 未授权访问漏洞
安全·网络安全·stable diffusion·渗透测试
love530love15 小时前
【故障解析】Stable Diffusion WebUI 更换主题后启动报 JSONDecodeError?可能是“主题加载”惹的祸
人工智能·windows·stable diffusion·大模型·json·stablediffusion·gradio 主题
ai_xiaogui5 天前
Stable Diffusion Web UI 绘世版 v4.6.1 整合包:一键极速部署,深度解决 AI 绘画环境配置与 CUDA 依赖难题
人工智能·stable diffusion·环境零配置·高性能内核优化·全功能插件集成·极速部署体验
微学AI6 天前
金仓数据库的新格局:以多模融合开创文档数据库
人工智能·stable diffusion
我的golang之路果然有问题6 天前
开源绘画大模型简单了解
人工智能·ai作画·stable diffusion·人工智能作画
我的golang之路果然有问题7 天前
comfyUI中的动作提取分享
人工智能·stable diffusion·ai绘画·人工智能作画·comfy
stephen one10 天前
2026 AI深度伪造危机:实测 Midjourney v7 与 Flux 2 Max 识别,谁才是 AI 检测的天花板?
人工智能·ai作画·stable diffusion·aigc·midjourney
长不大的蜡笔小新13 天前
基于Stable Diffusion的多模态图像生成与识别系统
stable diffusion
米汤爱学习14 天前
stable-diffusion-webui【笔记】
笔记·stable diffusion