PyTorch 中detach 和no_grad的应用:以 Llama 3 冻结参数为例

PyTorch 中 detachno_grad 的应用:以 Llama 3 冻结参数为例

在深度学习中,特别是处理大型预训练模型(如 Hugging Face 的 Llama 3)时,我们经常需要"冻结"某些层的参数,使其在训练中保持不变。这种操作通常用于迁移学习(Transfer Learning),以减少计算开销或保留预训练模型的知识。本文将通过一个实际代码示例,结合 detachno_grad 的使用,详细介绍如何冻结 Hugging Face 大模型的参数,并实现自定义计算。


1. 冻结参数的背景与方法

1.1 冻结参数的意义

冻结模型参数指的是在训练过程中固定一部分参数,使其不随梯度更新。这在以下场景非常常见:

  • 迁移学习:保留预训练模型的知识,仅训练新的下游任务层。
  • 减少计算负担:在大型模型中冻结底层特征提取层,避免重复计算。
1.2 冻结参数的方法

在 PyTorch 中,冻结参数通常有以下两种方式:

  1. 设置 requires_grad=False,让冻结的参数不参与梯度计算。
  2. 使用 torch.no_grad()detach() 分离计算图,防止某些操作对参数产生影响。

2. 示例代码:使用 detach 冻结 Llama 3 模型的参数

2.1 安装 Hugging Face 和依赖

确保安装最新的 transformers 库以支持 Llama 3:

bash 复制代码
pip install transformers torch
2.2 加载 Llama 3 并冻结参数

以下代码展示了如何加载 Llama 3 模型,冻结其参数,并使用 detach() 获取冻结输出。

python 复制代码
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载预训练的 Llama 3 模型和分词器
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 冻结模型的所有参数
for param in model.parameters():
    param.requires_grad = False

# 输入样本文本
input_text = "Hugging Face's Llama 3 is amazing for transfer learning!"
inputs = tokenizer(input_text, return_tensors="pt")

# 使用 torch.no_grad() 和 detach 冻结模型的输出
with torch.no_grad():
    frozen_output = model(**inputs).logits.detach()

# 假设我们添加一个新的操作
# 比如,添加一个偏置
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias

# 打印结果
print("冻结后的输出(添加偏置):", output_with_bias)

3. 代码解析

3.1 冻结参数

在代码中,以下部分将模型的所有参数设置为不可训练:

python 复制代码
for param in model.parameters():
    param.requires_grad = False

这样,参数就不会在反向传播中更新,且显存占用减少。

3.2 使用 detach()

以下代码通过 detach() 分离计算图:

python 复制代码
frozen_output = model(**inputs).logits.detach()
  • 作用 :断开梯度传播路径,使得 frozen_output 不与原始计算图关联。
  • 优点 :相比于 no_grad()detach() 仅作用于特定张量,灵活性更强。
3.3 添加自定义计算

在冻结模型的输出后,我们可以继续对输出执行新的操作。例如:

python 复制代码
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias

这些操作不会影响原始模型的参数或计算图。


4. 为什么选择 detach

4.1 detachno_grad 的区别
功能 detach torch.no_grad()
作用范围 对单个张量生效 作用于整个上下文环境
是否保留计算图 保留原始计算图,但分离当前张量的梯度路径 不会记录任何计算图
适用场景 需要冻结部分张量的梯度传播 推理阶段或临时禁用梯度计算
4.2 detach 的优势

在冻结模型时,我们通常需要对部分输出张量进行自定义计算。使用 detach(),既可以保留原始计算图(用于后续梯度计算),又可以灵活操作冻结的部分张量。


5. 输出示例

运行上述代码时,打印的输出如下(值为随机生成):

c 复制代码
冻结后的输出(添加偏置): tensor([[...]])

frozen_output 是从 Llama 3 模型中提取的冻结输出,并在其基础上添加了自定义偏置。


6. 注意事项

  1. 显存管理

    冻结参数可以显著减少显存使用,但 detach() 只分离了梯度,不会减少前向传播的内存开销。如果需要更高效的推理,应配合 torch.no_grad()

  2. 冻结部分参数

    如果只想冻结部分参数,可以有选择地设置 requires_grad。例如:

    python 复制代码
    for name, param in model.named_parameters():
        if "layer1" in name:
            param.requires_grad = False
  3. 保存与加载

    冻结参数后,保存模型时应注意区分冻结的部分和可训练的部分。


总结

通过本文的介绍,你应该对 PyTorch 中 detachno_grad 的使用有了更深的理解。我们以 Hugging Face 的 Llama 3 为例,演示了如何冻结参数、分离计算图,并在冻结输出的基础上进行自定义计算。这种方法在迁移学习和模型优化中非常实用,是深度学习开发者必须掌握的技能之一。

后记

2024年12月13日10点38分于上海,在GPT4o大模型辅助下完成。

相关推荐
Hoper.J16 分钟前
GPT 系列论文精读:从 GPT-1 到 GPT-4
人工智能·gpt·深度学习·ai·自然语言处理·llm
pzx_00118 分钟前
【集成学习】Stacking算法详解
人工智能·算法·leetcode·机器学习·职场和发展·集成学习
余胜辉19 分钟前
期望最大化算法:机器学习中的隐变量与参数估计的艺术
人工智能·机器学习·高斯混合模型·隐马尔可夫模型·期望最大化算法·em 算法
睿深渊39 分钟前
【2025最新】Poe保姆级订阅指南,Poe订阅看这一篇就够了!最方便使用各类AI!
人工智能
Eric.Lee20211 小时前
数据集-目标检测系列- 电话 测数据集 call_phone >> DataBall
人工智能·计算机视觉
脚踏实地的大梦想家1 小时前
【自然语言处理】P1 自然语言处理概述
人工智能·自然语言处理
香菜烤面包1 小时前
大语言模型LLM推理框架简单总结
人工智能·语言模型·自然语言处理
XianxinMao1 小时前
《语言模型的新型推理范式:基于链式思考与强化学习的突破》
人工智能·语言模型
不二青衣1 小时前
使用gtsam添加OrientedPlane3Factor平面约束因子
人工智能·算法·平面