冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。

python 复制代码
def freeze_LLM_only(model):
    """
    Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
    """
    for name, param in model.language_model.named_parameters():
                param.requires_grad = False
    for i, layer in enumerate(model.language_model.model.layers):
        if i in model.language_model.model.cross_attention_layers:
            for param in layer.parameters():
                param.requires_grad = True

这段代码的作用是:

  1. 冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。
  2. 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。

逐步拆解

1. 函数签名
python 复制代码
def freeze_LLM_only(model):
  • 目的:在多模态模型中,仅冻结语言模型(LLM)的自注意力层(Self-Attention),而保留跨注意力层(Cross-Attention)和其他部分(如视觉模型、投影模块)的可训练性。

2. 冻结语言模型参数
python 复制代码
for name, param in model.language_model.named_parameters():
    param.requires_grad = False
  • 遍历语言模型 model.language_model 中的所有参数。
  • 操作 :将所有参数的 requires_grad 属性设置为 False,使它们在训练中不会被更新。

3. 解冻跨注意力层参数
python 复制代码
for i, layer in enumerate(model.language_model.model.layers):
    if i in model.language_model.model.cross_attention_layers:
        for param in layer.parameters():
            param.requires_grad = True
  • 遍历语言模型中的每一层(假设 model.language_model.model.layers 是存储所有 Transformer 层的列表)。
  • 判断当前层是否属于跨注意力层:
    • 如果当前层的索引 i 属于 cross_attention_layers(一个存储跨注意力层索引的列表),解冻该层的参数。
    • 操作 :设置 requires_grad=True,使这些层在训练中可以更新。

举例说明

假设模型结构:
  • 一个多模态模型 model 包含以下部分:
    1. language_model(语言模型)
    2. vision_model(视觉模型)
    3. multi_modal_projector(多模态投影模块)
  • 语言模型 language_model
    • 有 6 层 Transformer 层存储在 model.language_model.model.layers
    • 跨注意力层的索引存储在 model.language_model.model.cross_attention_layers = [2, 4]

使用示例
python 复制代码
# 模拟一个模型对象
class DummyModel:
    def __init__(self):
        self.language_model = self.LanguageModel()

    class LanguageModel:
        def __init__(self):
            self.model = self.Model()

        class Model:
            def __init__(self):
                # 假设有 6 层 Transformer
                self.layers = [nn.Linear(10, 10) for _ in range(6)]
                self.cross_attention_layers = [2, 4]

# 创建模型实例
model = DummyModel()

# 冻结自注意力层,解冻跨注意力层
freeze_LLM_only(model)

验证参数的状态
python 复制代码
for i, layer in enumerate(model.language_model.model.layers):
    print(f"Layer {i}: requires_grad = {any(param.requires_grad for param in layer.parameters())}")
输出:
复制代码
Layer 0: requires_grad = False
Layer 1: requires_grad = False
Layer 2: requires_grad = True
Layer 3: requires_grad = False
Layer 4: requires_grad = True
Layer 5: requires_grad = False

总结

  1. 目标:冻结语言模型中的自注意力层,仅训练跨注意力层。
  2. 适用场景
    • 在多模态任务中,只需要调整跨注意力层以实现语言与其他模态(如视觉)的交互,而保持语言模型自注意力层的知识不被破坏。
  3. 灵活性 :可以通过调整 cross_attention_layers 的索引选择要解冻的层。

这里是通过索引i判断出是不是 属于 cross_attention_layer,可以对这段代码进行优化。

相关推荐
焦耳热科技前沿4 小时前
北京科技大学/理化所ACS Nano:混合价态Cu₂Sb金属间化合物实现高效尿素电合成
大数据·人工智能·自动化·能源·材料工程
C+-C资深大佬4 小时前
Creo 11.0 全功能解析:多体设计 + 仿真制造,机械设计效率翻倍下载安装
人工智能
浔川python社4 小时前
【维护期间重要提醒】请勿使用浔川 AI 翻译 v6.0 翻译违规内容
人工智能
CS创新实验室5 小时前
AI 与编程
人工智能·编程·编程语言
min1811234565 小时前
深度伪造内容的检测与溯源技术
大数据·网络·人工智能
_codemonster5 小时前
高斯卷积的可加性定理
人工智能·计算机视觉
数据智研6 小时前
【数据分享】(2005–2016年)基于水资源承载力的华北地区降水与地下水要素数据
大数据·人工智能·信息可视化·数据分析
likuolei6 小时前
Spring AI框架完整指南
人工智能·python·spring
梵得儿SHI6 小时前
(第四篇)Spring AI 核心技术攻坚:多轮对话与记忆机制,打造有上下文的 AI
java·人工智能·spring·springai生态·上下文丢失问题·三类记忆·智能客服实战案
二哈喇子!6 小时前
PyTorch生态与昇腾平台适配:环境搭建与详细安装指南
人工智能·pytorch·python