机器学习/huggingface笔记:Transformer内存占用刨析 和高效训练

1 模型操作的解剖

  • Transformers架构包括3组主要的操作:
    • 张量收缩
      • 线性层和多头注意力的组件都进行批量矩阵-矩阵乘法。
      • 这些操作是训练transformer时计算最密集的部分。
    • 统计归一化
      • Softmax和层归一化比张量收缩的计算密集度要低一些
      • 涉及一个或多个归约操作
      • 其结果通过映射应用。
    • 元素级操作
      • 剩余的操作:偏差、dropout、激活和残差连接。
      • 这些是计算最不密集的操作

2 模型内存刨析

2.1 模型内存

2.1.1 根据模型名称中的数字推断模型大小

  • 默认情况下,Hugging Face 的类如 TextGenerationPipelineAutoModelForCausalLM 会以 float32 精度加载模型。
  • 这意味着每个参数需要 4 字节(32 位),
    • ------>一个具有 80 亿参数的"8B"模型将需要大约 32GB 的内存。
    • 这可能有些浪费!
    • ------>大多数现代语言模型都是在"bfloat16"精度下训练的,这种精度只需要每个参数 2 字节

2.2 训练模型内存

  • 训练模型使用了比将模型放在GPU上多得多的内存
    • ------>在训练过程中有许多组件使用GPU内存
      • 模型权重
      • 优化器状态
      • 梯度
      • 保存的前向激活用于梯度计算
      • 临时缓冲区
      • 特定功能的内存

2.1 模型权重

  • fp32训练:4字节 * 参数数量
  • 混合精度训练:6字节 * 参数数量(在内存中保留一个fp32和一个fp16的模型)

2.2 优化器状态:

  • 普通AdamW:8字节 * 参数数量(保持2个状态)
  • 8位AdamW优化器如bitsandbytes:2字节 * 参数数量
  • 带动量的SGD等优化器:4字节 * 参数数量(保持1个状态)

2.3 梯度:

  • fp32或混合精度训练:4字节 * 参数数量(梯度始终保持在fp32)

2.4 前向激活

  • 大小取决于许多因素,关键因素是序列长度、隐藏大小和批量大小。
  • 有输入和输出通过前向和后向函数传递,并为梯度计算保存前向激活。

3 高效训练总览

|----------------|--------|--------|
| 方法/工具 | 提高训练速度 | 优化内存使用 |
| 批量大小选择 | 是 | 是 |
| 梯度累积 | 否 | 是 |
| 梯度检查点 | 否 | 是 |
| 混合精度训练 | 是 | 否 |
| 优化器选择 | 是 | 是 |
| 数据预加载 | 是 | 否 |
| DeepSpeed Zero | 否 | 是 |
| torch.compile | 是 | 否 |
| 参数高效微调(PEFT) | 否 | 是 |

3.1 选择批量大小

|--------|--------|
| 提高训练速度 | 优化内存使用 |
| 是 | 是 |

  • 为了达到最佳性能,首先要确定适当的批量大小。
    • 推荐使用2^N大小的批量大小和输入/输出神经元数量。
    • 通常是8的倍数,但根据使用的硬件和模型的数据类型,这个数值可以更高。
  • Tensor Core要求根据数据类型和硬件定义乘数。
    • 例如,对于fp16数据类型,推荐使用8的倍数
    • 除非是A100 GPU,在这种情况下使用64的倍数。

3.2 梯度累积

|--------|--------|
| 提高训练速度 | 优化内存使用 |
| 否 | 是 |

  • 梯度累积方法旨在以较小的增量计算梯度,而不是一次性为整个批次计算

    • 这种方法涉及通过对模型进行前向和后向传播,并在此过程中累积梯度,以较小的批次迭代计算梯度
    • 一旦累积了足够数量的梯度,就执行模型的优化步骤
  • 通过使用梯度累积,可以增加有效的批量大小,超出GPU内存容量的限制

  • 然而,重要的是要注意,梯度累积引入的额外前向和后向传播可能会减慢训练过程

  • 通过添加gradient_accumulation_steps参数到TrainingArguments来启用梯度累积:

python 复制代码
training_args = TrainingArguments(per_device_train_batch_size=1, 
                                  gradient_accumulation_steps=4, 
                                  **default_args)


#------------------>有效批量大小变为4
  • 较高数量的梯度累积步骤可能导致更明显的训练减慢

    • 假设,不使用梯度累积时per_device_train_batch_size=4已达到GPU的限制
    • 如果希望以大小为64的批次进行训练,不要将per_device_train_batch_size设置为1并且gradient_accumulation_steps设置为64。相反,保持per_device_train_batch_size=4并设置gradient_accumulation_steps=16
      • ------>这样可以在更好地利用可用GPU资源的同时,获得相同的有效批量大小。
  • 或者,使用Accelerate完全控制训练循环

3.3 梯度检查点

|--------|--------|
| 提高训练速度 | 优化内存使用 |
| 否 | 是 |

  • 即使将批量大小设置为1并使用梯度累积,一些大型模型仍可能面临内存问题
    • 这是因为还有其他组件也需要内存存储。
  • 一种方法是:在向后传递过程中计算梯度时保存前向传递的所有激活
    • ------>会导致显著的内存开销
  • 另一种方法是:丢弃激活,并在向后传递时需要时重新计算它们
    • ------>会引入相当大的计算开销并减慢训练过程
  • 梯度检查点提供了这两种方法之间的折中方案
    • 在计算图中策略性地保存选定的激活
    • ------>只需要重新计算一小部分激活以获取梯度

要在Trainer中启用梯度检查点,请向TrainingArguments传递相应的标志:

python 复制代码
training_args = TrainingArguments(
    per_device_train_batch_size=1, 
    gradient_accumulation_steps=4, 
    gradient_checkpointing=True, 
    **default_args
)

3.4 混合精度训练

|--------|--------|
| 提高训练速度 | 优化内存使用 |
| 是 | 否 |

  • 混合精度训练是一种旨在通过使用较低精度的数值格式来优化模型训练的计算效率的技术。
  • 传统上,大多数模型使用32位浮点精度(fp32或float32)来表示和处理变量。
  • 然而,并非所有变量都需要这种高精度级别来实现准确的结果。
  • 通过将某些变量的精度降低到如16位浮点(fp16或float16)等较低的数值格式,我们可以加速计算。
  • 由于这种方法中的一些计算以半精度进行,而有些仍以全精度进行,因此这种方法被称为混合精度训练。

3.4.1 fp16

  • 尽管梯度也以半精度计算,但它们在优化步骤中转换回全精度,因此在这里不会节省内存
  • 虽然混合精度训练结果计算更快,但也可能导致使用更多GPU内存,尤其是在小批量大小的情况下。
  • 这是因为模型现在以16位和32位精度同时存在于GPU上(GPU上的原始模型大小的1.5倍)。

要启用混合精度训练,将fp16标志设置为True

python 复制代码
training_args = TrainingArguments(per_device_train_batch_size=4, 
                                  fp16=True, 
                                  **default_args)

3.5 优化器选择

|--------|--------|
| 提高训练速度 | 优化内存使用 |
| 是 | 是 |

  • 对于Transformer,最常用的优化器是Adam或AdamW(带权重衰减的Adam)。

  • Adam通过存储之前梯度的滚动平均来实现良好的收敛性

  • 然而,它会增加与模型参数数量相当的额外内存占用。

    • 例如,对于一个具有30亿参数的模型,如"google-t5/t5-3b":

      • 标准的AdamW优化器将需要24GB的GPU内存,因为它为每个参数使用8字节(8*3 => 24GB)

      • Adafactor优化器将需要超过12GB。它每个参数略多于4字节,所以是4*3再加上一些额外的。

      • 8位BNB量化优化器只会使用6GB(2*3),如果所有优化器状态都被量化。

  • ------>为了解决这个问题,可以使用其他优化器。

3.5.1 Adafactor

  • Adafactor不会为权重矩阵中的每个元素存储滚动平均,而是保留汇总信息(按行和按列的滚动平均的总和),显著减少了其占用空间。
  • 然而,与Adam相比,Adafactor在某些情况下可能收敛速度较慢。
python 复制代码
training_args = TrainingArguments(per_device_train_batch_size=4, 
                                  optim="adafactor", 
                                  **default_args)

3.5.2 8位Adam

  • 与Adafactor聚合优化器状态不同,8位Adam保留完整状态并对其进行量化。
  • 量化意味着它以较低的精度存储状态,并且只在优化时解量化。
  • 这类似于混合精度训练背后的思想。
python 复制代码
training_args = TrainingArguments(per_device_train_batch_size=4, 
                                  optim="adamw_bnb_8bit", 
                                  **default_args)

3.6 数据预加载

|--------|--------|
| 提高训练速度 | 优化内存使用 |
| 是 | 否 |

  • 为了达到极高的训练速度,能够以GPU能够处理的最大速度喂入数据是一个重要的要求。

  • 默认情况下,所有操作都发生在主进程中,这可能无法足够快速地从磁盘读取数据,从而创建瓶颈,导致GPU未充分利用

  • 配置以下参数以减少瓶颈:

    • DataLoader(pin_memory=True, ...) - 确保数据被预加载到CPU上的固定内存中,通常会导致从CPU到GPU内存的传输速度大大加快。

    • DataLoader(num_workers=4, ...) - 启动多个工作器以更快地预加载数据。

      • 在训练过程中,观察GPU使用情况;如果远低于100%,试验增加工作器的数量。

      • 当然,问题可能在其他地方,所以增加工作器数量并不一定能带来更好的性能。

3.7 参数高效微调(PEFT)

|--------|--------|
| 提高训练速度 | 优化内存使用 |
| 否 | 是 |

  • 在微调过程中冻结预训练模型参数,并在其上添加少量可训练参数(适配器)
  • 结果是优化器状态和梯度相关的内存大大减少
    • 例如,使用普通的AdamW,优化器状态的内存需求为:

      • fp32参数副本:4字节/参数

      • Momentum:4字节/参数

      • 方差:4字节/参数

    • 假设一个模型有70亿参数,并注入了2亿参数的低秩适配器。

      • 纯模型的优化器状态的内存需求将是12 * 7 = 84 GB(假设有70亿可训练参数)。

      • 添加Lora会稍微增加模型权重相关的内存,并显著减少优化器状态的内存需求至12 * 0.2 = 2.4 GB

参考内容:

Model training anatomy (huggingface.co)

相关推荐
Yawesh_best4 小时前
告别系统壁垒!WSL+cpolar 让跨平台开发效率翻倍
运维·服务器·数据库·笔记·web安全
Ccjf酷儿6 小时前
操作系统 蒋炎岩 3.硬件视角的操作系统
笔记
习习.y7 小时前
python笔记梳理以及一些题目整理
开发语言·笔记·python
在逃热干面7 小时前
(笔记)自定义 systemd 服务
笔记
DKPT9 小时前
ZGC和G1收集器相比哪个更好?
java·jvm·笔记·学习·spring
QT 小鲜肉10 小时前
【孙子兵法之上篇】001. 孙子兵法·计篇
笔记·读书·孙子兵法
星轨初途11 小时前
数据结构排序算法详解(5)——非比较函数:计数排序(鸽巢原理)及排序算法复杂度和稳定性分析
c语言·开发语言·数据结构·经验分享·笔记·算法·排序算法
QT 小鲜肉11 小时前
【孙子兵法之上篇】001. 孙子兵法·计篇深度解析与现代应用
笔记·读书·孙子兵法
love530love14 小时前
【笔记】ComfUI RIFEInterpolation 节点缺失问题(cupy CUDA 安装)解决方案
人工智能·windows·笔记·python·插件·comfyui
愚戏师14 小时前
MySQL 数据导出
数据库·笔记·mysql