PyTorch 中detach 和no_grad的应用:以 Llama 3 冻结参数为例

PyTorch 中 detachno_grad 的应用:以 Llama 3 冻结参数为例

在深度学习中,特别是处理大型预训练模型(如 Hugging Face 的 Llama 3)时,我们经常需要"冻结"某些层的参数,使其在训练中保持不变。这种操作通常用于迁移学习(Transfer Learning),以减少计算开销或保留预训练模型的知识。本文将通过一个实际代码示例,结合 detachno_grad 的使用,详细介绍如何冻结 Hugging Face 大模型的参数,并实现自定义计算。


1. 冻结参数的背景与方法

1.1 冻结参数的意义

冻结模型参数指的是在训练过程中固定一部分参数,使其不随梯度更新。这在以下场景非常常见:

  • 迁移学习:保留预训练模型的知识,仅训练新的下游任务层。
  • 减少计算负担:在大型模型中冻结底层特征提取层,避免重复计算。
1.2 冻结参数的方法

在 PyTorch 中,冻结参数通常有以下两种方式:

  1. 设置 requires_grad=False,让冻结的参数不参与梯度计算。
  2. 使用 torch.no_grad()detach() 分离计算图,防止某些操作对参数产生影响。

2. 示例代码:使用 detach 冻结 Llama 3 模型的参数

2.1 安装 Hugging Face 和依赖

确保安装最新的 transformers 库以支持 Llama 3:

bash 复制代码
pip install transformers torch
2.2 加载 Llama 3 并冻结参数

以下代码展示了如何加载 Llama 3 模型,冻结其参数,并使用 detach() 获取冻结输出。

python 复制代码
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载预训练的 Llama 3 模型和分词器
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 冻结模型的所有参数
for param in model.parameters():
    param.requires_grad = False

# 输入样本文本
input_text = "Hugging Face's Llama 3 is amazing for transfer learning!"
inputs = tokenizer(input_text, return_tensors="pt")

# 使用 torch.no_grad() 和 detach 冻结模型的输出
with torch.no_grad():
    frozen_output = model(**inputs).logits.detach()

# 假设我们添加一个新的操作
# 比如,添加一个偏置
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias

# 打印结果
print("冻结后的输出(添加偏置):", output_with_bias)

3. 代码解析

3.1 冻结参数

在代码中,以下部分将模型的所有参数设置为不可训练:

python 复制代码
for param in model.parameters():
    param.requires_grad = False

这样,参数就不会在反向传播中更新,且显存占用减少。

3.2 使用 detach()

以下代码通过 detach() 分离计算图:

python 复制代码
frozen_output = model(**inputs).logits.detach()
  • 作用 :断开梯度传播路径,使得 frozen_output 不与原始计算图关联。
  • 优点 :相比于 no_grad()detach() 仅作用于特定张量,灵活性更强。
3.3 添加自定义计算

在冻结模型的输出后,我们可以继续对输出执行新的操作。例如:

python 复制代码
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias

这些操作不会影响原始模型的参数或计算图。


4. 为什么选择 detach

4.1 detachno_grad 的区别
功能 detach torch.no_grad()
作用范围 对单个张量生效 作用于整个上下文环境
是否保留计算图 保留原始计算图,但分离当前张量的梯度路径 不会记录任何计算图
适用场景 需要冻结部分张量的梯度传播 推理阶段或临时禁用梯度计算
4.2 detach 的优势

在冻结模型时,我们通常需要对部分输出张量进行自定义计算。使用 detach(),既可以保留原始计算图(用于后续梯度计算),又可以灵活操作冻结的部分张量。


5. 输出示例

运行上述代码时,打印的输出如下(值为随机生成):

c 复制代码
冻结后的输出(添加偏置): tensor([[...]])

frozen_output 是从 Llama 3 模型中提取的冻结输出,并在其基础上添加了自定义偏置。


6. 注意事项

  1. 显存管理

    冻结参数可以显著减少显存使用,但 detach() 只分离了梯度,不会减少前向传播的内存开销。如果需要更高效的推理,应配合 torch.no_grad()

  2. 冻结部分参数

    如果只想冻结部分参数,可以有选择地设置 requires_grad。例如:

    python 复制代码
    for name, param in model.named_parameters():
        if "layer1" in name:
            param.requires_grad = False
  3. 保存与加载

    冻结参数后,保存模型时应注意区分冻结的部分和可训练的部分。


总结

通过本文的介绍,你应该对 PyTorch 中 detachno_grad 的使用有了更深的理解。我们以 Hugging Face 的 Llama 3 为例,演示了如何冻结参数、分离计算图,并在冻结输出的基础上进行自定义计算。这种方法在迁移学习和模型优化中非常实用,是深度学习开发者必须掌握的技能之一。

后记

2024年12月13日10点38分于上海,在GPT4o大模型辅助下完成。

相关推荐
人机与认知实验室32 分钟前
生物神经网络与人工神经网络都有自组织临界
人工智能·深度学习·神经网络·机器学习
微臣愚钝1 小时前
【实验16】基于双向LSTM模型完成文本分类任务
人工智能·rnn·lstm
Funny_AI_LAB1 小时前
超越DFINE最新目标检测SOTA模型DEIM
人工智能·目标检测·计算机视觉·目标跟踪
小众AI1 小时前
supervision - 好用的计算机视觉 AI 工具库
人工智能·计算机视觉
WeeJot嵌入式1 小时前
深度学习中的多通道卷积与偏置过程详解
人工智能·深度学习
独泪了无痕2 小时前
【IntelliJ IDEA 集成工具】TalkX - AI编程助手
人工智能·个人开发·intellij idea
z千鑫2 小时前
【人工智能】ChatGPT 4的潜力:AI文案、绘画、视频与GPTs平台详解
人工智能·chatgpt·音视频
小熊科研路(同名GZH)2 小时前
【电力负荷预测实例】采用新英格兰2024年最新电力负荷数据的BPNN神经网络电力负荷预测模型
人工智能·神经网络·机器学习
安全方案2 小时前
免费下载 | 2024算网融合技术与产业白皮书
人工智能
星夜Zn2 小时前
斯坦福大学发布最新AI形势报告(2024)第七章:Policy and Governance
论文阅读·人工智能·形势报告