【扩散模型(十)】IP-Adapter 源码详解 4 - 训练细节、具体训了哪些层?

系列文章目录


文章目录


通过前面的系列文章,很清楚要训练的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。

而 image_proj_model 这块比较简单,原码如下所示

python 复制代码
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    
    #ip-adapter
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=4,
    )

adapter_modules 分为两类

  1. AttnProcessor 对应 self attention
  2. IPAttnProcessor 对应 cross attention

按理说 self attention 对应的 AttnProcessor 应该不会被训练,但是 training = True,便让人非常费解。

进一步查看 AttnProcessor2_0 和 IPAttnProcessor2_0 后,就清楚了,因为从 AttnProcessor2_0 的构造函数(init)中并没有参数,就算是 trianing = True 也并不影响训练,实际训练的模块还是 IPAttnProcessor2_0 构造函数中的 to_k_ip 和 to_v_ip 两层 linear!

python 复制代码
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
...

class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(

总结

  1. IP-Adapter 训的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。
  2. 在 adapter_modules 中,实际只训了 IPAttnProcessor2_0 的 to_k_ip 和 to_v_ip。
  3. adapter_modules 是在每个有含有 cross attention 的 unet block 里进行的替换,如下图所示。
相关推荐
checkcheckck15 分钟前
spring ai 适配 流式回答、mcp、milvus向量数据库、rag、聊天会话记忆
人工智能
Microvision维视智造16 分钟前
从“人工眼”到‘智能眼’:EZ-Vision视觉系统如何重构生产线视觉检测精度?
图像处理·人工智能·重构·视觉检测
巫婆理发22226 分钟前
神经网络(多层感知机)(第二课第二周)
人工智能·深度学习·神经网络
lxmyzzs30 分钟前
【打怪升级 - 03】YOLO11/YOLO12/YOLOv10/YOLOv8 完全指南:从理论到代码实战,新手入门必看教程
人工智能·神经网络·yolo·目标检测·计算机视觉
SEO_juper31 分钟前
企业级 AI 工具选型报告:9 个技术平台的 ROI 对比与部署策略
人工智能·搜索引擎·百度·llm·工具·geo·数字营销
Coovally AI模型快速验证43 分钟前
数据集分享 | 智慧农业实战数据集精选
人工智能·算法·目标检测·机器学习·计算机视觉·目标跟踪·无人机
xw337340956444 分钟前
彩色转灰度的核心逻辑:三种经典方法及原理对比
人工智能·python·深度学习·opencv·计算机视觉
蓝桉8021 小时前
opencv学习(图像金字塔)
人工智能·opencv·学习
倔强青铜三1 小时前
为什么 self 与 super() 成了 Python 的永恒痛点?
人工智能·python·面试
墨尘游子1 小时前
目标导向的强化学习:问题定义与 HER 算法详解—强化学习(19)
人工智能·python·算法