冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。

python 复制代码
def freeze_LLM_only(model):
    """
    Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
    """
    for name, param in model.language_model.named_parameters():
                param.requires_grad = False
    for i, layer in enumerate(model.language_model.model.layers):
        if i in model.language_model.model.cross_attention_layers:
            for param in layer.parameters():
                param.requires_grad = True

这段代码的作用是:

  1. 冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。
  2. 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。

逐步拆解

1. 函数签名
python 复制代码
def freeze_LLM_only(model):
  • 目的:在多模态模型中,仅冻结语言模型(LLM)的自注意力层(Self-Attention),而保留跨注意力层(Cross-Attention)和其他部分(如视觉模型、投影模块)的可训练性。

2. 冻结语言模型参数
python 复制代码
for name, param in model.language_model.named_parameters():
    param.requires_grad = False
  • 遍历语言模型 model.language_model 中的所有参数。
  • 操作 :将所有参数的 requires_grad 属性设置为 False,使它们在训练中不会被更新。

3. 解冻跨注意力层参数
python 复制代码
for i, layer in enumerate(model.language_model.model.layers):
    if i in model.language_model.model.cross_attention_layers:
        for param in layer.parameters():
            param.requires_grad = True
  • 遍历语言模型中的每一层(假设 model.language_model.model.layers 是存储所有 Transformer 层的列表)。
  • 判断当前层是否属于跨注意力层:
    • 如果当前层的索引 i 属于 cross_attention_layers(一个存储跨注意力层索引的列表),解冻该层的参数。
    • 操作 :设置 requires_grad=True,使这些层在训练中可以更新。

举例说明

假设模型结构:
  • 一个多模态模型 model 包含以下部分:
    1. language_model(语言模型)
    2. vision_model(视觉模型)
    3. multi_modal_projector(多模态投影模块)
  • 语言模型 language_model
    • 有 6 层 Transformer 层存储在 model.language_model.model.layers
    • 跨注意力层的索引存储在 model.language_model.model.cross_attention_layers = [2, 4]

使用示例
python 复制代码
# 模拟一个模型对象
class DummyModel:
    def __init__(self):
        self.language_model = self.LanguageModel()

    class LanguageModel:
        def __init__(self):
            self.model = self.Model()

        class Model:
            def __init__(self):
                # 假设有 6 层 Transformer
                self.layers = [nn.Linear(10, 10) for _ in range(6)]
                self.cross_attention_layers = [2, 4]

# 创建模型实例
model = DummyModel()

# 冻结自注意力层,解冻跨注意力层
freeze_LLM_only(model)

验证参数的状态
python 复制代码
for i, layer in enumerate(model.language_model.model.layers):
    print(f"Layer {i}: requires_grad = {any(param.requires_grad for param in layer.parameters())}")
输出:
复制代码
Layer 0: requires_grad = False
Layer 1: requires_grad = False
Layer 2: requires_grad = True
Layer 3: requires_grad = False
Layer 4: requires_grad = True
Layer 5: requires_grad = False

总结

  1. 目标:冻结语言模型中的自注意力层,仅训练跨注意力层。
  2. 适用场景
    • 在多模态任务中,只需要调整跨注意力层以实现语言与其他模态(如视觉)的交互,而保持语言模型自注意力层的知识不被破坏。
  3. 灵活性 :可以通过调整 cross_attention_layers 的索引选择要解冻的层。

这里是通过索引i判断出是不是 属于 cross_attention_layer,可以对这段代码进行优化。

相关推荐
Antonio9151 小时前
【图像处理】图像的基础几何变换
图像处理·人工智能·计算机视觉
新加坡内哥谈技术2 小时前
Perplexity AI 的 RAG 架构全解析:幕后技术详解
人工智能
武子康2 小时前
AI研究-119 DeepSeek-OCR PyTorch FlashAttn 2.7.3 推理与部署 模型规模与资源详细分析
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr
Sirius Wu3 小时前
深入浅出:Tongyi DeepResearch技术解读
人工智能·语言模型·langchain·aigc
忙碌5443 小时前
AI大模型时代下的全栈技术架构:从深度学习到云原生部署实战
人工智能·深度学习·架构
LZ_Keep_Running3 小时前
智能变电巡检:AI检测新突破
人工智能
InfiSight智睿视界4 小时前
AI 技术助力汽车美容行业实现精细化运营管理
大数据·人工智能
没有钱的钱仔5 小时前
机器学习笔记
人工智能·笔记·机器学习
听风吹等浪起5 小时前
基于改进TransUNet的港口船只图像分割系统研究
人工智能·深度学习·cnn·transformer
化作星辰5 小时前
深度学习_原理和进阶_PyTorch入门(2)后续语法3
人工智能·pytorch·深度学习