PyTorch 中 detach
和 no_grad
的应用:以 Llama 3 冻结参数为例
在深度学习中,特别是处理大型预训练模型(如 Hugging Face 的 Llama 3)时,我们经常需要"冻结"某些层的参数,使其在训练中保持不变。这种操作通常用于迁移学习(Transfer Learning),以减少计算开销或保留预训练模型的知识。本文将通过一个实际代码示例,结合 detach
和 no_grad
的使用,详细介绍如何冻结 Hugging Face 大模型的参数,并实现自定义计算。
1. 冻结参数的背景与方法
1.1 冻结参数的意义
冻结模型参数指的是在训练过程中固定一部分参数,使其不随梯度更新。这在以下场景非常常见:
- 迁移学习:保留预训练模型的知识,仅训练新的下游任务层。
- 减少计算负担:在大型模型中冻结底层特征提取层,避免重复计算。
1.2 冻结参数的方法
在 PyTorch 中,冻结参数通常有以下两种方式:
- 设置
requires_grad=False
,让冻结的参数不参与梯度计算。 - 使用
torch.no_grad()
或detach()
分离计算图,防止某些操作对参数产生影响。
2. 示例代码:使用 detach
冻结 Llama 3 模型的参数
2.1 安装 Hugging Face 和依赖
确保安装最新的 transformers
库以支持 Llama 3:
bash
pip install transformers torch
2.2 加载 Llama 3 并冻结参数
以下代码展示了如何加载 Llama 3 模型,冻结其参数,并使用 detach()
获取冻结输出。
python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载预训练的 Llama 3 模型和分词器
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 冻结模型的所有参数
for param in model.parameters():
param.requires_grad = False
# 输入样本文本
input_text = "Hugging Face's Llama 3 is amazing for transfer learning!"
inputs = tokenizer(input_text, return_tensors="pt")
# 使用 torch.no_grad() 和 detach 冻结模型的输出
with torch.no_grad():
frozen_output = model(**inputs).logits.detach()
# 假设我们添加一个新的操作
# 比如,添加一个偏置
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias
# 打印结果
print("冻结后的输出(添加偏置):", output_with_bias)
3. 代码解析
3.1 冻结参数
在代码中,以下部分将模型的所有参数设置为不可训练:
python
for param in model.parameters():
param.requires_grad = False
这样,参数就不会在反向传播中更新,且显存占用减少。
3.2 使用 detach()
以下代码通过 detach()
分离计算图:
python
frozen_output = model(**inputs).logits.detach()
- 作用 :断开梯度传播路径,使得
frozen_output
不与原始计算图关联。 - 优点 :相比于
no_grad()
,detach()
仅作用于特定张量,灵活性更强。
3.3 添加自定义计算
在冻结模型的输出后,我们可以继续对输出执行新的操作。例如:
python
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias
这些操作不会影响原始模型的参数或计算图。
4. 为什么选择 detach
?
4.1 detach
与 no_grad
的区别
功能 | detach |
torch.no_grad() |
---|---|---|
作用范围 | 对单个张量生效 | 作用于整个上下文环境 |
是否保留计算图 | 保留原始计算图,但分离当前张量的梯度路径 | 不会记录任何计算图 |
适用场景 | 需要冻结部分张量的梯度传播 | 推理阶段或临时禁用梯度计算 |
4.2 detach
的优势
在冻结模型时,我们通常需要对部分输出张量进行自定义计算。使用 detach()
,既可以保留原始计算图(用于后续梯度计算),又可以灵活操作冻结的部分张量。
5. 输出示例
运行上述代码时,打印的输出如下(值为随机生成):
c
冻结后的输出(添加偏置): tensor([[...]])
frozen_output
是从 Llama 3 模型中提取的冻结输出,并在其基础上添加了自定义偏置。
6. 注意事项
-
显存管理
冻结参数可以显著减少显存使用,但
detach()
只分离了梯度,不会减少前向传播的内存开销。如果需要更高效的推理,应配合torch.no_grad()
。 -
冻结部分参数
如果只想冻结部分参数,可以有选择地设置
requires_grad
。例如:pythonfor name, param in model.named_parameters(): if "layer1" in name: param.requires_grad = False
-
保存与加载
冻结参数后,保存模型时应注意区分冻结的部分和可训练的部分。
总结
通过本文的介绍,你应该对 PyTorch 中 detach
和 no_grad
的使用有了更深的理解。我们以 Hugging Face 的 Llama 3 为例,演示了如何冻结参数、分离计算图,并在冻结输出的基础上进行自定义计算。这种方法在迁移学习和模型优化中非常实用,是深度学习开发者必须掌握的技能之一。
后记
2024年12月13日10点38分于上海,在GPT4o大模型辅助下完成。