冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。

python 复制代码
def freeze_LLM_only(model):
    """
    Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
    """
    for name, param in model.language_model.named_parameters():
                param.requires_grad = False
    for i, layer in enumerate(model.language_model.model.layers):
        if i in model.language_model.model.cross_attention_layers:
            for param in layer.parameters():
                param.requires_grad = True

这段代码的作用是:

  1. 冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。
  2. 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。

逐步拆解

1. 函数签名
python 复制代码
def freeze_LLM_only(model):
  • 目的:在多模态模型中,仅冻结语言模型(LLM)的自注意力层(Self-Attention),而保留跨注意力层(Cross-Attention)和其他部分(如视觉模型、投影模块)的可训练性。

2. 冻结语言模型参数
python 复制代码
for name, param in model.language_model.named_parameters():
    param.requires_grad = False
  • 遍历语言模型 model.language_model 中的所有参数。
  • 操作 :将所有参数的 requires_grad 属性设置为 False,使它们在训练中不会被更新。

3. 解冻跨注意力层参数
python 复制代码
for i, layer in enumerate(model.language_model.model.layers):
    if i in model.language_model.model.cross_attention_layers:
        for param in layer.parameters():
            param.requires_grad = True
  • 遍历语言模型中的每一层(假设 model.language_model.model.layers 是存储所有 Transformer 层的列表)。
  • 判断当前层是否属于跨注意力层:
    • 如果当前层的索引 i 属于 cross_attention_layers(一个存储跨注意力层索引的列表),解冻该层的参数。
    • 操作 :设置 requires_grad=True,使这些层在训练中可以更新。

举例说明

假设模型结构:
  • 一个多模态模型 model 包含以下部分:
    1. language_model(语言模型)
    2. vision_model(视觉模型)
    3. multi_modal_projector(多模态投影模块)
  • 语言模型 language_model
    • 有 6 层 Transformer 层存储在 model.language_model.model.layers
    • 跨注意力层的索引存储在 model.language_model.model.cross_attention_layers = [2, 4]

使用示例
python 复制代码
# 模拟一个模型对象
class DummyModel:
    def __init__(self):
        self.language_model = self.LanguageModel()

    class LanguageModel:
        def __init__(self):
            self.model = self.Model()

        class Model:
            def __init__(self):
                # 假设有 6 层 Transformer
                self.layers = [nn.Linear(10, 10) for _ in range(6)]
                self.cross_attention_layers = [2, 4]

# 创建模型实例
model = DummyModel()

# 冻结自注意力层,解冻跨注意力层
freeze_LLM_only(model)

验证参数的状态
python 复制代码
for i, layer in enumerate(model.language_model.model.layers):
    print(f"Layer {i}: requires_grad = {any(param.requires_grad for param in layer.parameters())}")
输出:
复制代码
Layer 0: requires_grad = False
Layer 1: requires_grad = False
Layer 2: requires_grad = True
Layer 3: requires_grad = False
Layer 4: requires_grad = True
Layer 5: requires_grad = False

总结

  1. 目标:冻结语言模型中的自注意力层,仅训练跨注意力层。
  2. 适用场景
    • 在多模态任务中,只需要调整跨注意力层以实现语言与其他模态(如视觉)的交互,而保持语言模型自注意力层的知识不被破坏。
  3. 灵活性 :可以通过调整 cross_attention_layers 的索引选择要解冻的层。

这里是通过索引i判断出是不是 属于 cross_attention_layer,可以对这段代码进行优化。

相关推荐
Cedric1113几秒前
机器学习中的距离总结
人工智能·机器学习
大模型真好玩5 分钟前
深入浅出LangGraph AI Agent智能体开发教程(五)—LangGraph 数据分析助手智能体项目实战
人工智能·python·mcp
IT_陈寒18 分钟前
React性能优化:这5个Hook技巧让我的组件渲染效率提升50%(附代码对比)
前端·人工智能·后端
Captaincc20 分钟前
9 月 20 日,TRAE Meetup@Guangzhou 相聚羊城
人工智能·后端
霍格沃兹软件测试开发35 分钟前
快速掌握Dify+Chrome MCP:打造网页操控AI助手
人工智能·chrome·dify·mcp
张子夜 iiii1 小时前
4步OpenCV-----扫秒身份证号
人工智能·python·opencv·计算机视觉
华新嘉华DTC创新营销3 小时前
华新嘉华:AI搜索优化重塑本地生活行业:智能推荐正取代“关键词匹配”
人工智能·百度·生活
第七序章3 小时前
【C++STL】list的详细用法和底层实现
c语言·c++·自然语言处理·list
SmartBrain4 小时前
DeerFlow 实践:华为IPD流程的评审智能体设计
人工智能·语言模型·架构
l1t5 小时前
利用DeepSeek实现服务器客户端模式的DuckDB原型
服务器·c语言·数据库·人工智能·postgresql·协议·duckdb