python
def freeze_LLM_only(model):
"""
Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
"""
for name, param in model.language_model.named_parameters():
param.requires_grad = False
for i, layer in enumerate(model.language_model.model.layers):
if i in model.language_model.model.cross_attention_layers:
for param in layer.parameters():
param.requires_grad = True
这段代码的作用是:
- 冻结语言模型中的 自注意力层,使其参数不参与训练(梯度不会更新)。
- 对于跨注意力层,则解冻参数,使这些层可以进行梯度更新,从而参与训练。
逐步拆解
1. 函数签名
python
def freeze_LLM_only(model):
- 目的:在多模态模型中,仅冻结语言模型(LLM)的自注意力层(Self-Attention),而保留跨注意力层(Cross-Attention)和其他部分(如视觉模型、投影模块)的可训练性。
2. 冻结语言模型参数
python
for name, param in model.language_model.named_parameters():
param.requires_grad = False
- 遍历语言模型
model.language_model
中的所有参数。 - 操作 :将所有参数的
requires_grad
属性设置为False
,使它们在训练中不会被更新。
3. 解冻跨注意力层参数
python
for i, layer in enumerate(model.language_model.model.layers):
if i in model.language_model.model.cross_attention_layers:
for param in layer.parameters():
param.requires_grad = True
- 遍历语言模型中的每一层(假设
model.language_model.model.layers
是存储所有 Transformer 层的列表)。 - 判断当前层是否属于跨注意力层:
- 如果当前层的索引
i
属于cross_attention_layers
(一个存储跨注意力层索引的列表),解冻该层的参数。 - 操作 :设置
requires_grad=True
,使这些层在训练中可以更新。
- 如果当前层的索引
举例说明
假设模型结构:
- 一个多模态模型
model
包含以下部分:language_model
(语言模型)vision_model
(视觉模型)multi_modal_projector
(多模态投影模块)
- 语言模型
language_model
:- 有 6 层
Transformer
层存储在model.language_model.model.layers
。 - 跨注意力层的索引存储在
model.language_model.model.cross_attention_layers = [2, 4]
。
- 有 6 层
使用示例
python
# 模拟一个模型对象
class DummyModel:
def __init__(self):
self.language_model = self.LanguageModel()
class LanguageModel:
def __init__(self):
self.model = self.Model()
class Model:
def __init__(self):
# 假设有 6 层 Transformer
self.layers = [nn.Linear(10, 10) for _ in range(6)]
self.cross_attention_layers = [2, 4]
# 创建模型实例
model = DummyModel()
# 冻结自注意力层,解冻跨注意力层
freeze_LLM_only(model)
验证参数的状态
python
for i, layer in enumerate(model.language_model.model.layers):
print(f"Layer {i}: requires_grad = {any(param.requires_grad for param in layer.parameters())}")
输出:
Layer 0: requires_grad = False
Layer 1: requires_grad = False
Layer 2: requires_grad = True
Layer 3: requires_grad = False
Layer 4: requires_grad = True
Layer 5: requires_grad = False
总结
- 目标:冻结语言模型中的自注意力层,仅训练跨注意力层。
- 适用场景 :
- 在多模态任务中,只需要调整跨注意力层以实现语言与其他模态(如视觉)的交互,而保持语言模型自注意力层的知识不被破坏。
- 灵活性 :可以通过调整
cross_attention_layers
的索引选择要解冻的层。