PyTorch 中detach 和no_grad的应用:以 Llama 3 冻结参数为例

PyTorch 中 detachno_grad 的应用:以 Llama 3 冻结参数为例

在深度学习中,特别是处理大型预训练模型(如 Hugging Face 的 Llama 3)时,我们经常需要"冻结"某些层的参数,使其在训练中保持不变。这种操作通常用于迁移学习(Transfer Learning),以减少计算开销或保留预训练模型的知识。本文将通过一个实际代码示例,结合 detachno_grad 的使用,详细介绍如何冻结 Hugging Face 大模型的参数,并实现自定义计算。


1. 冻结参数的背景与方法

1.1 冻结参数的意义

冻结模型参数指的是在训练过程中固定一部分参数,使其不随梯度更新。这在以下场景非常常见:

  • 迁移学习:保留预训练模型的知识,仅训练新的下游任务层。
  • 减少计算负担:在大型模型中冻结底层特征提取层,避免重复计算。
1.2 冻结参数的方法

在 PyTorch 中,冻结参数通常有以下两种方式:

  1. 设置 requires_grad=False,让冻结的参数不参与梯度计算。
  2. 使用 torch.no_grad()detach() 分离计算图,防止某些操作对参数产生影响。

2. 示例代码:使用 detach 冻结 Llama 3 模型的参数

2.1 安装 Hugging Face 和依赖

确保安装最新的 transformers 库以支持 Llama 3:

bash 复制代码
pip install transformers torch
2.2 加载 Llama 3 并冻结参数

以下代码展示了如何加载 Llama 3 模型,冻结其参数,并使用 detach() 获取冻结输出。

python 复制代码
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载预训练的 Llama 3 模型和分词器
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 冻结模型的所有参数
for param in model.parameters():
    param.requires_grad = False

# 输入样本文本
input_text = "Hugging Face's Llama 3 is amazing for transfer learning!"
inputs = tokenizer(input_text, return_tensors="pt")

# 使用 torch.no_grad() 和 detach 冻结模型的输出
with torch.no_grad():
    frozen_output = model(**inputs).logits.detach()

# 假设我们添加一个新的操作
# 比如,添加一个偏置
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias

# 打印结果
print("冻结后的输出(添加偏置):", output_with_bias)

3. 代码解析

3.1 冻结参数

在代码中,以下部分将模型的所有参数设置为不可训练:

python 复制代码
for param in model.parameters():
    param.requires_grad = False

这样,参数就不会在反向传播中更新,且显存占用减少。

3.2 使用 detach()

以下代码通过 detach() 分离计算图:

python 复制代码
frozen_output = model(**inputs).logits.detach()
  • 作用 :断开梯度传播路径,使得 frozen_output 不与原始计算图关联。
  • 优点 :相比于 no_grad()detach() 仅作用于特定张量,灵活性更强。
3.3 添加自定义计算

在冻结模型的输出后,我们可以继续对输出执行新的操作。例如:

python 复制代码
custom_bias = torch.ones_like(frozen_output)
output_with_bias = frozen_output + custom_bias

这些操作不会影响原始模型的参数或计算图。


4. 为什么选择 detach

4.1 detachno_grad 的区别
功能 detach torch.no_grad()
作用范围 对单个张量生效 作用于整个上下文环境
是否保留计算图 保留原始计算图,但分离当前张量的梯度路径 不会记录任何计算图
适用场景 需要冻结部分张量的梯度传播 推理阶段或临时禁用梯度计算
4.2 detach 的优势

在冻结模型时,我们通常需要对部分输出张量进行自定义计算。使用 detach(),既可以保留原始计算图(用于后续梯度计算),又可以灵活操作冻结的部分张量。


5. 输出示例

运行上述代码时,打印的输出如下(值为随机生成):

c 复制代码
冻结后的输出(添加偏置): tensor([[...]])

frozen_output 是从 Llama 3 模型中提取的冻结输出,并在其基础上添加了自定义偏置。


6. 注意事项

  1. 显存管理

    冻结参数可以显著减少显存使用,但 detach() 只分离了梯度,不会减少前向传播的内存开销。如果需要更高效的推理,应配合 torch.no_grad()

  2. 冻结部分参数

    如果只想冻结部分参数,可以有选择地设置 requires_grad。例如:

    python 复制代码
    for name, param in model.named_parameters():
        if "layer1" in name:
            param.requires_grad = False
  3. 保存与加载

    冻结参数后,保存模型时应注意区分冻结的部分和可训练的部分。


总结

通过本文的介绍,你应该对 PyTorch 中 detachno_grad 的使用有了更深的理解。我们以 Hugging Face 的 Llama 3 为例,演示了如何冻结参数、分离计算图,并在冻结输出的基础上进行自定义计算。这种方法在迁移学习和模型优化中非常实用,是深度学习开发者必须掌握的技能之一。

后记

2024年12月13日10点38分于上海,在GPT4o大模型辅助下完成。

相关推荐
美狐美颜sdk2 小时前
跨平台直播美颜SDK集成实录:Android/iOS如何适配贴纸功能
android·人工智能·ios·架构·音视频·美颜sdk·第三方美颜sdk
DeepSeek-大模型系统教程2 小时前
推荐 7 个本周 yyds 的 GitHub 项目。
人工智能·ai·语言模型·大模型·github·ai大模型·大模型学习
郭庆汝2 小时前
pytorch、torchvision与python版本对应关系
人工智能·pytorch·python
小雷FansUnion4 小时前
深入理解MCP架构:智能服务编排、上下文管理与动态路由实战
人工智能·架构·大模型·mcp
资讯分享周4 小时前
扣子空间PPT生产力升级:AI智能生成与多模态创作新时代
人工智能·powerpoint
叶子爱分享5 小时前
计算机视觉与图像处理的关系
图像处理·人工智能·计算机视觉
鱼摆摆拜拜5 小时前
第 3 章:神经网络如何学习
人工智能·神经网络·学习
一只鹿鹿鹿5 小时前
信息化项目验收,软件工程评审和检查表单
大数据·人工智能·后端·智慧城市·软件工程
张较瘦_6 小时前
[论文阅读] 人工智能 | 深度学习系统崩溃恢复新方案:DaiFu框架的原位修复技术
论文阅读·人工智能·深度学习
cver1236 小时前
野生动物检测数据集介绍-5,138张图片 野生动物保护监测 智能狩猎相机系统 生态研究与调查
人工智能·pytorch·深度学习·目标检测·计算机视觉·目标跟踪