Breaking Through the Memory Bottleneck: Pretraining LLaMA-7B on a Single Consumer RTX 4090 with GaLore

There are already quite a few techniques for reducing memory usage when training large models, such as gradient checkpointing and mixed-precision training (BF16, FP16). This article introduces GaLore, a memory-efficient training method released just a few days ago, and takes a look at how it performs in practice.

In addition, my blog posts on large models and the accompanying code are all organized in the GitHub repo llm-action; feel free to help yourself if you need them.

The Core Idea of GaLore

Gradient Low-Rank Projection (GaLore) is a training strategy that learns the full set of parameters, yet is more memory-efficient than common low-rank adaptation methods such as LoRA. Its key idea is to exploit the slowly changing low-rank structure of the gradient $G \in \mathbb{R}^{m \times n}$ of the weight matrix $W$, rather than trying to approximate the weight matrix itself as low-rank.
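
Concretely, instead of keeping full-rank optimizer states, GaLore projects the gradient into a rank-r subspace, runs the usual entry-wise optimizer there, and projects the result back. Below is a rough sketch of one update step in the paper's notation; the projector P_t, the inner update rule ρ_t, and the scale α correspond to the --rank, optimizer, and --galore_scale settings used in the scripts later in this article.

latex
% Sketch of one GaLore step for a weight matrix W of size m x n.
% P_t (m x r) comes from an SVD of the gradient and is refreshed every
% --update_proj_gap steps; rho_t is the entry-wise stateful update (e.g. Adam);
% alpha is --galore_scale and eta the learning rate.
\begin{aligned}
G_t     &= -\nabla_W \varphi(W_t)        && \text{full-rank gradient} \\
R_t     &= P_t^{\top} G_t                && \text{project into the rank-}r\text{ subspace} \\
N_t     &= \rho_t(R_t)                   && \text{Adam-style update on the small matrix} \\
W_{t+1} &= W_t + \eta \,\alpha\, P_t N_t && \text{project back and apply}
\end{aligned}

Because the Adam moments are kept only for the projected r x n matrix (plus the projector itself), optimizer memory per projected weight drops from roughly 2mn to mr + 2rn.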

As a gradient projection method, GaLore is agnostic to the choice of optimizer and can be plugged into an existing optimizer with only two lines of code, as the sketch below shows. GaLore currently provides three optimizers: GaLoreAdamW, GaLoreAdamW8bit, and GaLoreAdafactor.
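
To make the "two lines" concrete, here is a minimal usage sketch based on the repo's README and the command-line flags that appear later in this article (--rank, --update_proj_gap, --galore_scale, --proj_type); the tiny two-layer model is just a stand-in for the real LLaMA modules.

python
import torch
from galore_torch import GaLoreAdamW

# Toy stand-in for the 2D projection weights that GaLore targets in a real model.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 1376, bias=False),
    torch.nn.Linear(1376, 512, bias=False),
)

# Line 1: split the parameters -- 2D weight matrices get the low-rank gradient
# projection, everything else (biases, norms, embeddings, ...) is updated as usual.
galore_params = [p for p in model.parameters() if p.dim() == 2]
regular_params = [p for p in model.parameters() if p.dim() != 2]

# Line 2: pass both groups to a GaLore optimizer; the extra keys mirror the
# --rank / --update_proj_gap / --galore_scale / --proj_type flags used below.
param_groups = [
    {"params": regular_params},
    {"params": galore_params, "rank": 128, "update_proj_gap": 200,
     "scale": 0.25, "proj_type": "std"},
]
optimizer = GaLoreAdamW(param_groups, lr=0.01)

The training loop itself is unchanged afterwards: loss.backward(), optimizer.step(), optimizer.zero_grad().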

Preparation

Downloading the Code

bash
git clone git@github.com:jiaweizzhao/GaLore.git
cd GaLore
git checkout a6bc1650
  • The configs directory holds the various model configurations
  • The scripts/benchmark_c4 directory holds launch scripts for the different model sizes

Environment Setup

The base environment is configured as follows:

  • Operating system: CentOS 7
  • CPUs: a single node with 1 TB of RAM; Intel CPUs, 64 physical CPUs with 16 cores each
  • GPUs: 8x A800 80GB
  • Python: 3.10 (OpenSSL must first be upgraded to 1.1.1t, then Python compiled and installed from source)
  • NVIDIA driver version: 515.125.06 (choose the driver that matches your GPU model)
  • CUDA Toolkit: 11.8

Install the dependency packages:

pip install -r requirements.txt

The requirements.txt file contains:

text
torch==2.1.0
transformers==4.31.0
tokenizers
datasets==2.14.6
peft
wandb
loguru
nvitop
lion-pytorch
matplotlib
bitsandbytes
scipy
scikit-learn
evaluate

Note: make sure PyTorch is 2.1.0 or newer, otherwise the script will error out.

Enabling wandb Offline Mode

With offline mode enabled, wandb will not upload anything, but it still records the data and results produced during the experiment.

bash
wandb offline
# W&B offline. Running your script from this directory will only write metadata locally. Use wandb disabled to completely turn off W&B.

Preparing the Dataset

This article uses the C4 dataset for training. C4 is a large pretraining dataset provided by Google for training language models. It was built by processing and cleaning content drawn from billions of crawled web pages into a corpus suitable for training large language models, and it can be used for a wide range of NLP tasks such as text generation, text classification, and language modeling.

Because the full dataset is very large, only a single file was downloaded here, roughly 356,317 examples.
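
For reference, loading one downloaded shard with the datasets library looks roughly like this; the file name is hypothetical, but allenai/c4 splits the English training set into 1024 compressed JSON shards of roughly this size.

python
from datasets import load_dataset

# Hypothetical local path to a single C4 shard downloaded from allenai/c4.
data_files = {"train": "data/c4-train.00000-of-01024.json.gz"}
dataset = load_dataset("json", data_files=data_files, split="train")

# Each example has the fields 'text', 'timestamp' and 'url', matching the log below.
print(dataset)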

Downloading the T5 Tokenizer

Since we are pretraining from scratch, which tokenizer we use is not critical. The T5 tokenizer was trained on C4, and we are training on C4 as well, so it is a good fit.
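
The training script loads the tokenizer through transformers; a minimal sketch, assuming the t5-base checkpoint and the max_length of 256 used in the runs below:

python
from transformers import AutoTokenizer

# Any reasonable tokenizer works for from-scratch pretraining; t5-base's
# SentencePiece vocabulary was itself built on C4, so it matches the data well.
tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=256)
print(tokenizer("GaLore projects gradients into a low-rank subspace.")["input_ids"])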

Model Training

To keep the article readable, the detailed code is placed in the torchrun_main.py file under the galore directory of the GitHub llm-action project.

A First Taste of GaLore: Pretraining LLaMA-60M

First, train a 60M-parameter model on the C4 dataset.

Launch script:

bash
# LLaMA-60M, GaLore-Adam, 1 A100, 1 Node
CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node 1 --master_port=24033 torchrun_main.py \
    --model_config configs/llama_60m.json \
    --lr 0.01 \
    --galore_scale 0.25 \
    --rank 128 \
    --update_proj_gap 200 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --optimizer galore_adamw 

Run log:

text
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
================================================================================
bin /usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/libbitsandbytes_cuda118.so
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/cuda_setup/main.py:149: UserWarning: /usr/local/conda/envs/torch201-venv did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...
  warn(msg)
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/nvidia/lib64'), PosixPath('/usr/local/cuda/extras/CUPTI/lib64'), PosixPath('/opt/rh/devtoolset-9/root/usr/lib/dyninst'), PosixPath('/usr/local/nvidia/lib')}
  warn(msg)
CUDA SETUP: CUDA runtime path found: /usr/local/cuda-11.8/lib64/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/libbitsandbytes_cuda118.so...
[2024-03-11 10:07:17,776] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Starting script
2024-03-11 10:07:18.599 | INFO     | __main__:main:145 - Global rank 0, local rank 0, device: 0
2024-03-11 10:07:18.600 | INFO     | __main__:main:149 - Process group initialized
wandb: Tracking run with wandb version 0.16.4
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
2024-03-11 10:07:19.597 | INFO     | __main__:main:168 - Using dist with rank 0 (only rank 0 will log)
2024-03-11 10:07:19.598 | INFO     | __main__:main:169 - ****************************************
2024-03-11 10:07:19.598 | INFO     | __main__:main:170 - Starting training with the arguments
2024-03-11 10:07:19.598 | INFO     | __main__:main:172 - model_config                   configs/llama_60m.json
2024-03-11 10:07:19.598 | INFO     | __main__:main:172 - use_hf_model                   False
2024-03-11 10:07:19.598 | INFO     | __main__:main:172 - continue_from                  None
2024-03-11 10:07:19.598 | INFO     | __main__:main:172 - batch_size                     256
2024-03-11 10:07:19.598 | INFO     | __main__:main:172 - gradient_accumulation          2
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - total_batch_size               512
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - max_length                     256
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - optimizer                      galore_adamw
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - lr                             0.01
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - scheduler                      cosine
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - min_lr_ratio                   0.1
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - activation_checkpointing       False
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - weight_decay                   0.0
2024-03-11 10:07:19.599 | INFO     | __main__:main:172 - warmup_steps                   1000
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - eval_every                     1000
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - num_training_steps             10000
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - max_train_tokens               None
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - save_every                     10000
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - save_dir                       checkpoints/llama_60m-2024-03-11-10-07-18
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - tags                           None
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - dtype                          bfloat16
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - workers                        8
2024-03-11 10:07:19.600 | INFO     | __main__:main:172 - seed                           0
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - name                           test
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - grad_clipping                  0.0
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - beta1                          0.0
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - rank                           128
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - update_proj_gap                200
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - galore_scale                   0.25
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - proj_type                      std
2024-03-11 10:07:19.601 | INFO     | __main__:main:172 - single_gpu                     False
2024-03-11 10:07:19.601 | INFO     | __main__:main:173 - ****************************************
Downloading data files: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3708.49it/s]
Extracting data files: 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 919.60it/s]
Generating train split: 356317 examples [00:02, 125738.07 examples/s]
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
DatasetDict({
    train: Dataset({
        features: ['text', 'timestamp', 'url'],
        num_rows: 356317
    })
})
2024-03-11 10:08:42.623 | INFO     | __main__:main:182 - Shuffling data with seed 42
wandb: WARNING Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save("/mnt/folder/file.h5", base_path="/mnt")
Update steps:   0%|                                   | 0/10000 [00:00<?, ?it/s]
enable GaLore for weights in module:  model.layers.0.self_attn.q_proj         
enable GaLore for weights in module:  model.layers.0.self_attn.k_proj
enable GaLore for weights in module:  model.layers.0.self_attn.v_proj
enable GaLore for weights in module:  model.layers.0.self_attn.o_proj
enable GaLore for weights in module:  model.layers.0.mlp.gate_proj
enable GaLore for weights in module:  model.layers.0.mlp.down_proj
enable GaLore for weights in module:  model.layers.0.mlp.up_proj
...
enable GaLore for weights in module:  model.layers.7.self_attn.q_proj
enable GaLore for weights in module:  model.layers.7.self_attn.k_proj
enable GaLore for weights in module:  model.layers.7.self_attn.v_proj
enable GaLore for weights in module:  model.layers.7.self_attn.o_proj
enable GaLore for weights in module:  model.layers.7.mlp.gate_proj
enable GaLore for weights in module:  model.layers.7.mlp.down_proj
enable GaLore for weights in module:  model.layers.7.mlp.up_proj
2024-03-11 10:08:45.614 | INFO     | __main__:main:294 -
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 512, padding_idx=31999)
    (layers): ModuleList(
      (0-7): 8 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=512, bias=False)
          (v_proj): Linear(in_features=512, out_features=512, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=512, out_features=1376, bias=False)
          (down_proj): Linear(in_features=1376, out_features=512, bias=False)
          (up_proj): Linear(in_features=512, out_features=1376, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=512, out_features=32000, bias=False)
)

2024-03-11 10:08:45.614 | INFO     | __main__:main:295 - Total params: 58.07M
2024-03-11 10:08:45.615 | INFO     | __main__:main:296 - Trainable params: 58.07M
2024-03-11 10:08:45.615 | INFO     | __main__:main:297 - Total params with GaLore enabled: 25.30M
2024-03-11 10:08:45.615 | INFO     | __main__:main:298 - Saving model to checkpoints/llama_60m-2024-03-11-10-07-18 every 10000 update steps
/home/guodong.li/GaLore/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
Update steps:   7%|█▌                     | 696/10000 [06:19<1:21:46,  1.90it/s]2024-03-11 10:15:05.670 | INFO     | __main__:main:523 - Training finished
Update steps:   7%|█▌                     | 696/10000 [06:20<1:24:40,  1.83it/s]
2024-03-11 10:15:05.672 | INFO     | __main__:main:528 - Saving model and optimizer to checkpoints/llama_60m-2024-03-11-10-07-18/model_696, update step 696
2024-03-11 10:15:05.952 | INFO     | __main__:main:554 - Running final evaluation
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
DatasetDict({
    train: Dataset({
        features: ['text', 'timestamp', 'url'],
        num_rows: 356317
    })
})
2024-03-11 10:16:26.307 | INFO     | __main__:evaluate_model:93 - Loaded validation dataset in 80.13 seconds
Map: 100%|██████████████████████████████████████████████████████████████████| 356317/356317 [02:10<00:00, 2731.25 examples/s]
2024-03-11 10:18:36.788 | INFO     | __main__:evaluate_model:109 - Eval set prepared in 210.61 seconds
2024-03-11 10:19:07.928 | INFO     | __main__:main:571 - Final eval loss: 4.445544719696045
2024-03-11 10:19:07.929 | INFO     | __main__:main:573 - Script finished successfully
Rank 0 finished successfully
wandb:
wandb: Run history:
wandb:     final_eval_loss ▁
wandb:   final_eval_tokens ▁
wandb:                loss █▇▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                  lr ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:  throughput_batches ▄▄▅▆▃▁▅▃▂▃▆▂▂▃▄▂▃▂▃▃█▃▃▃▃▂▂▃▃▃▃▂▁▃▃▂▂▂▂▆
wandb: throughput_examples ▄▄▅▆▃▁▅▃▂▃▆▂▂▃▄▂▃▂▃▃█▃▃▃▃▂▂▃▃▃▃▂▁▃▃▂▂▂▂▆
wandb:   throughput_tokens ▅██▄▅▆▆▅▆▄▇▅▇▆▃▇▆▄▄▅█▄▅▄▅▄▆▆▅▅▅▇▁▇▆▇▇▆▄▇
wandb:         tokens_seen ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:         update_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:
wandb: Run summary:
wandb:     final_eval_loss 4.44554
wandb:   final_eval_tokens 10032866
wandb:                loss 4.53125
wandb:                  lr 0.00696
wandb:  throughput_batches 3.8724
wandb: throughput_examples 991.33511
wandb:   throughput_tokens 192032.45359
wandb:         tokens_seen 69593341
wandb:         update_step 696
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /home/guodong.li/GaLore/wandb/offline-run-20240311_100719-qlkq3zcd
wandb: Find logs at: ./wandb/offline-run-20240311_100719-qlkq3zcd/logs

Pretraining LLaMA-7B on a Single Consumer RTX 4090

Next, to train the 7B model on a single GPU (e.g., an NVIDIA RTX 4090), all you need to do is specify --optimizer=galore_adamw8bit_per_layer, which enables GaLoreAdamW8bit with per-layer weight updates. With activation (gradient) checkpointing turned on, you can keep the batch size at 16 as tested on an NVIDIA RTX 4090. A sketch of how the per-layer update works follows.
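
The per-layer trick is what keeps the footprint this small: instead of one global optimizer step that needs every gradient resident at once, each weight gets its own small 8-bit GaLore optimizer that is stepped as soon as that weight's gradient has been accumulated, after which the gradient is freed. The repo implements this with PyTorch's register_post_accumulate_grad_hook (hence the PyTorch >= 2.1.0 requirement noted earlier); the helper below is a simplified sketch of the same idea, not the repo's exact code.

python
import torch
from galore_torch import GaLoreAdamW8bit

def attach_per_layer_galore(model, lr=0.005, rank=1024, update_proj_gap=500, scale=0.25):
    """Hypothetical helper: one 8-bit GaLore optimizer per 2D weight, stepped from a hook."""
    optimizers = {}
    for param in model.parameters():
        if param.requires_grad and param.dim() == 2:
            optimizers[param] = GaLoreAdamW8bit(
                [{"params": [param], "rank": rank, "update_proj_gap": update_proj_gap,
                  "scale": scale, "proj_type": "std"}],
                lr=lr,
            )

    def step_and_free(param):
        # Runs right after this parameter's gradient is accumulated in backward().
        optimizers[param].step()
        optimizers[param].zero_grad()  # set_to_none=True by default, so param.grad is released

    for param in optimizers:
        param.register_post_accumulate_grad_hook(step_and_free)
    return optimizers  # keep a reference so the per-parameter optimizers stay alive

Because every parameter is updated inside backward() and its gradient is freed immediately, there is no global optimizer.step() left to call, and the gradients that DistributedDataParallel would need to all-reduce are already gone, which is presumably why this mode is limited to a single GPU, as noted further below.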

I couldn't be bothered to hunt down a 4090, so I just ran a quick-and-dirty test on an A800 instead.

Command:

bash
CUDA_VISIBLE_DEVICES=7 torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_7b.json \
    --lr 0.005 \
    --galore_scale 0.25 \
    --rank 1024 \
    --update_proj_gap 500 \
    --batch_size 16 \
    --total_batch_size 512 \
    --activation_checkpointing \
    --num_training_steps 150000 \
    --warmup_steps 15000 \
    --weight_decay 0 \
    --grad_clipping 1.0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --single_gpu \
    --optimizer galore_adamw8bit_per_layer

Training time:

text
Update steps:   0%|                    | 31/150000 [37:14<1842:28:16, 44.23s/it]

GPU memory usage:

text
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    7   N/A  N/A     13343      C   .../torch210-venv/bin/python    24192MiB |
+-----------------------------------------------------------------------------+

At present, the per-layer weight update technique only supports single-GPU training (--single_gpu); it cannot be combined with nn.parallel.DistributedDataParallel.

Next, specify --optimizer=galore_adamw and enable gradient checkpointing for multi-GPU data-parallel training (here, two GPUs on a single node).

bash
CUDA_VISIBLE_DEVICES=4,7  torchrun --standalone --nproc_per_node 2 --master_port=24033 torchrun_main.py \
    --model_config configs/llama_7b.json \
    --lr 0.01 \
    --galore_scale 0.25 \
    --rank 1024 \
    --update_proj_gap 200 \
    --batch_size 2 \
    --total_batch_size 8 \
    --activation_checkpointing \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --optimizer galore_adamw

Training time:

text
Update steps:   1%|▏                       | 64/10000 [10:31<4:50:05,  1.75s/it]

GPU memory usage:

text
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    4   N/A  N/A     33406      C   .../torch210-venv/bin/python    76368MiB |
|    7   N/A  N/A     33407      C   .../torch210-venv/bin/python    76960MiB |
+-----------------------------------------------------------------------------+

Fine-tuning RoBERTa

Run the script:

bash
python run_glue.py \
    --model_name_or_path roberta-base \
    --task_name mrpc \
    --enable_galore \
    --lora_all_modules \
    --max_length 512 \
    --seed=1234 \
    --lora_r 4 \
    --galore_scale 4 \
    --per_device_train_batch_size 16 \
    --update_proj_gap 500 \
    --learning_rate 3e-5 \
    --num_train_epochs 30 \
    --output_dir results/ft/roberta_base/mrpc

Conclusion

This article gave a brief introduction to GaLore and demonstrated pretraining with it. With per-layer weight updates and activation checkpointing, GaLore can indeed pretrain LLaMA-7B on a single consumer GPU, but because the per-layer weight update technique only supports single-GPU training, doing so is feasible mostly in theory. That said, if your data volume is modest and you want to do full-parameter fine-tuning of LLaMA-7B, GaLore is a solid option: for those who only have a 3090/4090 consumer card, it adds another choice besides ZeRO-3 Offload. It is also worth keeping an eye on GaLore's follow-up work.

Writing takes effort, so if this article has been helpful to you, a like, a bookmark, and a follow would be much appreciated~~
