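Many techniques already exist for reducing memory use in large-model training, such as gradient checkpointing and mixed-precision training (BF16, FP16). This post introduces GaLore, a memory-efficient training method released just a few days ago, and looks at how it performs in practice.
I have also collected my LLM-related blog posts and their companion code in the llm-action repository on GitHub; help yourself if you find them useful.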
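The core idea of GaLore
Gradient Low-Rank Projection (GaLore) is a training strategy that learns all parameters (full-parameter learning) yet uses less memory than common low-rank adaptation methods such as LoRA. Its key idea is to exploit the slowly changing low-rank structure of the gradient $G \in \mathbb{R}^{m \times n}$ of the weight matrix $W$, rather than trying to approximate the weight matrix itself as low-rank.
As a gradient projection method, GaLore is independent of the choice of optimizer and can be plugged into an existing optimizer with just two lines of code, as the algorithm listing in the GaLore paper shows. GaLore currently implements three optimizers: GaLoreAdamW, GaLoreAdamW8bit, and GaLoreAdafactor.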
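To make the "two lines" concrete, here is a minimal NumPy sketch of one GaLore-style Adam step for a single weight matrix. It is not the code from the GaLore repository: the function name `galore_adam_step` and the `state` dictionary are made up for illustration, only a left projection is used, and weight decay is omitted. The parameter names `rank`, `scale`, and `update_proj_gap` simply mirror the command-line flags used later in this post.

```python
import numpy as np


def galore_adam_step(W, G, state, rank=4, lr=0.01, scale=0.25,
                     update_proj_gap=200, betas=(0.9, 0.999), eps=1e-8):
    """One simplified GaLore-style Adam step for a single m x n weight matrix.

    Illustrative sketch only: left projection, no weight decay.
    """
    step = state.get("step", 0)

    # Refresh the projector P every `update_proj_gap` steps from the top-`rank`
    # left singular vectors of the current gradient.
    if step % update_proj_gap == 0:
        U, _, _ = np.linalg.svd(G, full_matrices=False)
        state["P"] = U[:, :rank]                      # shape (m, rank)
    P = state["P"]

    # First extra line: project the gradient into the low-rank space.
    R = P.T @ G                                       # shape (rank, n)

    # Ordinary Adam, except that its moments live in the (rank, n) space.
    exp_avg = betas[0] * state.get("exp_avg", np.zeros_like(R)) + (1 - betas[0]) * R
    exp_avg_sq = betas[1] * state.get("exp_avg_sq", np.zeros_like(R)) + (1 - betas[1]) * R ** 2
    step += 1
    m_hat = exp_avg / (1 - betas[0] ** step)
    v_hat = exp_avg_sq / (1 - betas[1] ** step)
    N = m_hat / (np.sqrt(v_hat) + eps)
    state.update(step=step, exp_avg=exp_avg, exp_avg_sq=exp_avg_sq)

    # Second extra line: project the update back to full size and apply it.
    return W - lr * scale * (P @ N)


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 6))
G = rng.standard_normal((8, 6))
state = {}
W_new = galore_adam_step(W, G, state, rank=2)
print(state["exp_avg"].shape)  # (2, 6): moments are (rank, n), not (m, n)
```

The memory saving comes from the optimizer state: the weights stay full-size, while the Adam moments shrink from `(m, n)` to `(rank, n)` plus one `(m, rank)` projector.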
Preparation
Downloading the code
git clone git@github.com:jiaweizzhao/GaLore.git
cd GaLore
git checkout a6bc1650
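- The configs directory holds the model configuration files
- The scripts/benchmark_c4 directory holds the launch scripts for the different model sizes
Environment setup
The base environment is configured as follows:
- Operating system: CentOS 7
- CPUs: a single node with Intel CPUs and 1 TB of memory; 64 physical CPUs with 16 cores each
- GPUs: 8 × A800 80GB
- Python: 3.10 (OpenSSL must first be upgraded to 1.1.1t, after which Python is compiled and installed from source)
- NVIDIA driver version: 515.125.06; choose the driver that matches your GPU model
- CUDA Toolkit: 11.8
Install the dependencies: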
pip install -r requirements.txt
The requirements.txt file contains:
torch==2.1.0
transformers==4.31.0
tokenizers
datasets==2.14.6
peft
wandb
loguru
nvitop
lion-pytorch
matplotlib
bitsandbytes
scipy
scikit-learn
evaluate
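Note: PyTorch must be 2.1.0 or later; otherwise an error is raised.
Enabling wandb offline mode
With offline mode enabled, wandb does not upload any data but still records the data and results of each run locally.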
wandb offline
# W&B offline. Running your script from this directory will only write metadata locally. Use wandb disabled to completely turn off W&B.
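If you would rather not change the directory-level setting, the same effect can usually be had from inside Python through the `WANDB_MODE` environment variable. This is a small sketch and assumes it runs before `wandb.init()` is called in the same process:

```python
import os

# Assumption: executed before wandb.init(); wandb reads this variable at start-up.
os.environ["WANDB_MODE"] = "offline"
print(os.environ["WANDB_MODE"])
```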
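Preparing the dataset
This post trains on the C4 dataset. C4 (the Colossal Clean Crawled Corpus) is a large pre-training dataset released by Google for training language models. It is built from web pages collected by Common Crawl, which were processed and cleaned into a corpus suitable for training large language models, and it can be used for a range of natural language processing tasks such as text generation, text classification, and language modeling.
Because the full dataset is very large, only one file was downloaded here, containing 356,317 examples.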
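The post does not show how the downloaded file is read, so the following is only a sketch under one assumption: that the file is a gzip-compressed JSON-lines shard with `text`, `timestamp`, and `url` fields, which matches the features printed in the training log. The helper name `iter_c4_shard` and the demo file are hypothetical; the demo writes a tiny synthetic shard so the snippet runs without the real data.

```python
import gzip
import json
import os
import tempfile


def iter_c4_shard(path, limit=None):
    """Yield records from one gzip-compressed JSON-lines shard (assumed layout)."""
    with gzip.open(path, "rt", encoding="utf-8") as fh:
        for i, line in enumerate(fh):
            if limit is not None and i >= limit:
                break
            yield json.loads(line)


# Build a tiny synthetic shard so the example is self-contained.
demo_path = os.path.join(tempfile.mkdtemp(), "demo-shard.json.gz")
with gzip.open(demo_path, "wt", encoding="utf-8") as fh:
    for i in range(3):
        record = {"text": f"doc {i}", "timestamp": "2019-04-25T12:00:00Z",
                  "url": f"https://example.com/{i}"}
        fh.write(json.dumps(record) + "\n")

records = list(iter_c4_shard(demo_path, limit=2))
print(len(records), sorted(records[0]))
```

Streaming the shard line by line keeps memory flat, which is handy for peeking at a few examples before starting a long run.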
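Downloading the T5 tokenizer
Because we are pre-training from scratch, the choice of tokenizer matters little. The T5 tokenizer was trained on C4, and we are training on C4 as well, so it is a sensible choice.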
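Model training
To keep this post readable, the full code lives in torchrun_main.py under the galore directory of the llm-action project on GitHub.
A first look at GaLore: pre-training LLaMA-60M
First, train the 60M model on the C4 dataset.
Run the script: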
# LLaMA-60M, GaLore-Adam, 1 A100, 1 Node
CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node 1 --master_port=24033 torchrun_main.py \
--model_config configs/llama_60m.json \
--lr 0.01 \
--galore_scale 0.25 \
--rank 128 \
--update_proj_gap 200 \
--batch_size 256 \
--total_batch_size 512 \
--num_training_steps 10000 \
--warmup_steps 1000 \
--weight_decay 0 \
--dtype bfloat16 \
--eval_every 1000 \
--optimizer galore_adamw
Output:
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please run
python -m bitsandbytes
and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
================================================================================
bin /usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/libbitsandbytes_cuda118.so
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/cuda_setup/main.py:149: UserWarning: /usr/local/conda/envs/torch201-venv did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...
warn(msg)
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/nvidia/lib64'), PosixPath('/usr/local/cuda/extras/CUPTI/lib64'), PosixPath('/opt/rh/devtoolset-9/root/usr/lib/dyninst'), PosixPath('/usr/local/nvidia/lib')}
warn(msg)
CUDA SETUP: CUDA runtime path found: /usr/local/cuda-11.8/lib64/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/bitsandbytes-0.39.1-py3.10.egg/bitsandbytes/libbitsandbytes_cuda118.so...
[2024-03-11 10:07:17,776] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Starting script
2024-03-11 10:07:18.599 | INFO | __main__:main:145 - Global rank 0, local rank 0, device: 0
2024-03-11 10:07:18.600 | INFO | __main__:main:149 - Process group initialized
wandb: Tracking run with wandb version 0.16.4
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
2024-03-11 10:07:19.597 | INFO | __main__:main:168 - Using dist with rank 0 (only rank 0 will log)
2024-03-11 10:07:19.598 | INFO | __main__:main:169 - ****************************************
2024-03-11 10:07:19.598 | INFO | __main__:main:170 - Starting training with the arguments
2024-03-11 10:07:19.598 | INFO | __main__:main:172 - model_config configs/llama_60m.json
2024-03-11 10:07:19.598 | INFO | __main__:main:172 - use_hf_model False
2024-03-11 10:07:19.598 | INFO | __main__:main:172 - continue_from None
2024-03-11 10:07:19.598 | INFO | __main__:main:172 - batch_size 256
2024-03-11 10:07:19.598 | INFO | __main__:main:172 - gradient_accumulation 2
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - total_batch_size 512
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - max_length 256
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - optimizer galore_adamw
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - lr 0.01
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - scheduler cosine
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - min_lr_ratio 0.1
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - activation_checkpointing False
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - weight_decay 0.0
2024-03-11 10:07:19.599 | INFO | __main__:main:172 - warmup_steps 1000
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - eval_every 1000
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - num_training_steps 10000
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - max_train_tokens None
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - save_every 10000
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - save_dir checkpoints/llama_60m-2024-03-11-10-07-18
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - tags None
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - dtype bfloat16
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - workers 8
2024-03-11 10:07:19.600 | INFO | __main__:main:172 - seed 0
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - name test
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - grad_clipping 0.0
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - beta1 0.0
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - rank 128
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - update_proj_gap 200
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - galore_scale 0.25
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - proj_type std
2024-03-11 10:07:19.601 | INFO | __main__:main:172 - single_gpu False
2024-03-11 10:07:19.601 | INFO | __main__:main:173 - ****************************************
Downloading data files: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3708.49it/s]
Extracting data files: 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 919.60it/s]
Generating train split: 356317 examples [00:02, 125738.07 examples/s]
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
table = cls._concat_blocks(blocks, axis=0)
DatasetDict({
train: Dataset({
features: ['text', 'timestamp', 'url'],
num_rows: 356317
})
})
2024-03-11 10:08:42.623 | INFO | __main__:main:182 - Shuffling data with seed 42
wandb: WARNING Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save("/mnt/folder/file.h5", base_path="/mnt")
Update steps: 0%| | 0/10000 [00:00<?, ?it/s]
enable GaLore for weights in module: model.layers.0.self_attn.q_proj
enable GaLore for weights in module: model.layers.0.self_attn.k_proj
enable GaLore for weights in module: model.layers.0.self_attn.v_proj
enable GaLore for weights in module: model.layers.0.self_attn.o_proj
enable GaLore for weights in module: model.layers.0.mlp.gate_proj
enable GaLore for weights in module: model.layers.0.mlp.down_proj
enable GaLore for weights in module: model.layers.0.mlp.up_proj
...
enable GaLore for weights in module: model.layers.7.self_attn.q_proj
enable GaLore for weights in module: model.layers.7.self_attn.k_proj
enable GaLore for weights in module: model.layers.7.self_attn.v_proj
enable GaLore for weights in module: model.layers.7.self_attn.o_proj
enable GaLore for weights in module: model.layers.7.mlp.gate_proj
enable GaLore for weights in module: model.layers.7.mlp.down_proj
enable GaLore for weights in module: model.layers.7.mlp.up_proj
2024-03-11 10:08:45.614 | INFO | __main__:main:294 -
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 512, padding_idx=31999)
(layers): ModuleList(
(0-7): 8 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=512, out_features=512, bias=False)
(k_proj): Linear(in_features=512, out_features=512, bias=False)
(v_proj): Linear(in_features=512, out_features=512, bias=False)
(o_proj): Linear(in_features=512, out_features=512, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=512, out_features=1376, bias=False)
(down_proj): Linear(in_features=1376, out_features=512, bias=False)
(up_proj): Linear(in_features=512, out_features=1376, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=512, out_features=32000, bias=False)
)
2024-03-11 10:08:45.614 | INFO | __main__:main:295 - Total params: 58.07M
2024-03-11 10:08:45.615 | INFO | __main__:main:296 - Trainable params: 58.07M
2024-03-11 10:08:45.615 | INFO | __main__:main:297 - Total params with GaLore enabled: 25.30M
2024-03-11 10:08:45.615 | INFO | __main__:main:298 - Saving model to checkpoints/llama_60m-2024-03-11-10-07-18 every 10000 update steps
/home/guodong.li/GaLore/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
Update steps: 7%|█▌ | 696/10000 [06:19<1:21:46, 1.90it/s]2024-03-11 10:15:05.670 | INFO | __main__:main:523 - Training finished
Update steps: 7%|█▌ | 696/10000 [06:20<1:24:40, 1.83it/s]
2024-03-11 10:15:05.672 | INFO | __main__:main:528 - Saving model and optimizer to checkpoints/llama_60m-2024-03-11-10-07-18/model_696, update step 696
2024-03-11 10:15:05.952 | INFO | __main__:main:554 - Running final evaluation
/usr/local/conda/envs/torch201-venv/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
table = cls._concat_blocks(blocks, axis=0)
DatasetDict({
train: Dataset({
features: ['text', 'timestamp', 'url'],
num_rows: 356317
})
})
2024-03-11 10:16:26.307 | INFO | __main__:evaluate_model:93 - Loaded validation dataset in 80.13 seconds
Map: 100%|██████████████████████████████████████████████████████████████████| 356317/356317 [02:10<00:00, 2731.25 examples/s]
2024-03-11 10:18:36.788 | INFO | __main__:evaluate_model:109 - Eval set prepared in 210.61 seconds
2024-03-11 10:19:07.928 | INFO | __main__:main:571 - Final eval loss: 4.445544719696045
2024-03-11 10:19:07.929 | INFO | __main__:main:573 - Script finished successfully
Rank 0 finished successfully
wandb:
wandb: Run history:
wandb: final_eval_loss ▁
wandb: final_eval_tokens ▁
wandb: loss █▇▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: lr ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb: throughput_batches ▄▄▅▆▃▁▅▃▂▃▆▂▂▃▄▂▃▂▃▃█▃▃▃▃▂▂▃▃▃▃▂▁▃▃▂▂▂▂▆
wandb: throughput_examples ▄▄▅▆▃▁▅▃▂▃▆▂▂▃▄▂▃▂▃▃█▃▃▃▃▂▂▃▃▃▃▂▁▃▃▂▂▂▂▆
wandb: throughput_tokens ▅██▄▅▆▆▅▆▄▇▅▇▆▃▇▆▄▄▅█▄▅▄▅▄▆▆▅▅▅▇▁▇▆▇▇▆▄▇
wandb: tokens_seen ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb: update_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:
wandb: Run summary:
wandb: final_eval_loss 4.44554
wandb: final_eval_tokens 10032866
wandb: loss 4.53125
wandb: lr 0.00696
wandb: throughput_batches 3.8724
wandb: throughput_examples 991.33511
wandb: throughput_tokens 192032.45359
wandb: tokens_seen 69593341
wandb: update_step 696
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /home/guodong.li/GaLore/wandb/offline-run-20240311_100719-qlkq3zcd
wandb: Find logs at: ./wandb/offline-run-20240311_100719-qlkq3zcd/logs
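Several numbers in this log can be reproduced by hand, which is a useful sanity check. The sketch below recomputes the gradient-accumulation factor, the number of update steps one pass over the 356,317 downloaded examples allows, the learning rate at that step, and the two parameter counts. Two things are assumptions on my part: that the warmup is linear from 0 to `--lr`, and that the run stopped at step 696 because the single shard was exhausted (the log only says "Training finished").

```python
import math

# Values taken from the command line and the log above.
batch_size, total_batch_size, world_size = 256, 512, 1
num_examples, lr, warmup_steps = 356_317, 0.01, 1000

# Micro-batches accumulated per optimizer update.
grad_accum = total_batch_size // (batch_size * world_size)

# Update steps available from one pass over the single downloaded shard.
steps_per_pass = math.ceil(num_examples / total_batch_size)

# Learning rate at that step, assuming linear warmup from 0 to `lr`.
lr_at_end = lr * steps_per_pass / warmup_steps

# Parameter counts from the printed architecture (no biases anywhere).
hidden, inter, vocab, layers = 512, 1376, 32_000, 8
attn = 4 * hidden * hidden          # q_proj, k_proj, v_proj, o_proj
mlp = 3 * hidden * inter            # gate_proj, down_proj, up_proj
galore_params = layers * (attn + mlp)
norms = (2 * layers + 1) * hidden   # two RMSNorms per layer plus the final one
total_params = galore_params + 2 * vocab * hidden + norms  # + embed_tokens, lm_head

print(grad_accum, steps_per_pass, round(lr_at_end, 5))
print(f"{galore_params / 1e6:.2f}M of {total_params / 1e6:.2f}M")
```

These agree with the logged `gradient_accumulation 2`, `update_step 696`, `lr 0.00696`, "Total params with GaLore enabled: 25.30M", and "Total params: 58.07M". In other words, GaLore is applied only to the attention and MLP projection matrices; the embeddings, `lm_head`, and norms are updated as usual.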
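Pre-training LLaMA-7B on a single consumer-grade RTX 4090
Next, train the 7B model on a single GPU (for example, an NVIDIA RTX 4090). All you need to do is specify --optimizer=galore_adamw8bit_per_layer, which enables GaLoreAdamW8bit with per-layer weight updates. With activation (gradient) checkpointing, a batch size of 16 can be sustained, as tested on an NVIDIA RTX 4090.
I didn't bother tracking down a 4090 machine, so I made do with a test on an A800.
Run the command: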
CUDA_VISIBLE_DEVICES=7 torchrun --standalone --nproc_per_node 1 torchrun_main.py \
--model_config configs/llama_7b.json \
--lr 0.005 \
--galore_scale 0.25 \
--rank 1024 \
--update_proj_gap 500 \
--batch_size 16 \
--total_batch_size 512 \
--activation_checkpointing \
--num_training_steps 150000 \
--warmup_steps 15000 \
--weight_decay 0 \
--grad_clipping 1.0 \
--dtype bfloat16 \
--eval_every 1000 \
--single_gpu \
--optimizer galore_adamw8bit_per_layer
Training time:
Update steps: 0%| | 31/150000 [37:14<1842:28:16, 44.23s/it]
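To put that progress bar into perspective, the remaining time can be worked out from the step count and the per-step time. This is only a back-of-the-envelope sketch: the helper name `remaining_hours` is made up, and the 44.23 s figure is an instantaneous reading taken very early in the run.

```python
def remaining_hours(total_steps, done_steps, seconds_per_step):
    """Estimated time left, in hours, at a constant step time."""
    return (total_steps - done_steps) * seconds_per_step / 3600


hours = remaining_hours(150_000, 31, 44.23)

# Each update step accumulates total_batch_size / batch_size micro-batches.
micro_batches = 512 // 16

print(f"{hours:.1f} h, about {hours / 24:.1f} days")
print(f"{44.23 / micro_batches:.2f} s per micro-batch of 16 sequences")
```

At this pace the full 150,000 steps would take roughly 1,842 hours, which is about 77 days on one GPU.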
GPU memory usage:
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    7   N/A  N/A     13343      C   .../torch210-venv/bin/python    24192MiB |
+-----------------------------------------------------------------------------+
At present, per-layer weight updates support only single-GPU training (--single_gpu) and cannot be used with nn.parallel.DistributedDataParallel.
Next, set --optimizer=galore_adamw and enable gradient checkpointing for multi-GPU training (the command below uses two GPUs on a single node).
CUDA_VISIBLE_DEVICES=4,7 torchrun --standalone --nproc_per_node 2 --master_port=24033 torchrun_main.py \
--model_config configs/llama_7b.json \
--lr 0.01 \
--galore_scale 0.25 \
--rank 1024 \
--update_proj_gap 200 \
--batch_size 2 \
--total_batch_size 8 \
--activation_checkpointing \
--num_training_steps 10000 \
--warmup_steps 1000 \
--weight_decay 0 \
--dtype bfloat16 \
--eval_every 1000 \
--optimizer galore_adamw
Training time:
Update steps: 1%|▏ | 64/10000 [10:31<4:50:05, 1.75s/it]
GPU memory usage:
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    4   N/A  N/A     33406      C   .../torch210-venv/bin/python    76368MiB |
|    7   N/A  N/A     33407      C   .../torch210-venv/bin/python    76960MiB |
+-----------------------------------------------------------------------------+
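The two 7B runs use very different total batch sizes (512 versus 8), so their per-step times are not directly comparable. A rough way to put them on the same footing is sequences processed per second. The sketch below uses only the figures reported above; the helper name `samples_per_second` is made up, and both step times are early, instantaneous readings, so treat the result as indicative at best.

```python
def samples_per_second(total_batch_size, seconds_per_step):
    """Sequences consumed per second of wall-clock time."""
    return total_batch_size / seconds_per_step


runs = {
    "galore_adamw8bit_per_layer, 1 GPU": {"total_batch": 512, "sec_per_step": 44.23,
                                          "mib_per_gpu": [24192]},
    "galore_adamw, 2 GPUs": {"total_batch": 8, "sec_per_step": 1.75,
                             "mib_per_gpu": [76368, 76960]},
}

for name, run in runs.items():
    rate = samples_per_second(run["total_batch"], run["sec_per_step"])
    peak_gib = max(run["mib_per_gpu"]) / 1024
    print(f"{name}: {rate:.2f} sequences/s, peak {peak_gib:.1f} GiB per GPU")
```

On these numbers the single-GPU per-layer 8-bit run moves more sequences per second while using about a third of the memory per GPU, most likely because it can afford a micro-batch of 16 instead of 2.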
Fine-tuning RoBERTa
Run the script:
python run_glue.py \
--model_name_or_path roberta-base \
--task_name mrpc \
--enable_galore \
--lora_all_modules \
--max_length 512 \
--seed=1234 \
--lora_r 4 \
--galore_scale 4 \
--per_device_train_batch_size 16 \
--update_proj_gap 500 \
--learning_rate 3e-5 \
--num_train_epochs 30 \
--output_dir results/ft/roberta_base/mrpc
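Conclusion
This post briefly introduced GaLore and demonstrated pre-training with it. By combining per-layer weight updates with activation checkpointing, GaLore can pre-train LLaMA-7B on a single consumer-grade GPU. However, per-layer weight updates support only single-GPU training, and at the roughly 44 s per update step measured above a full pre-training run would take far too long, so this is feasible only in theory. If your dataset is small, though, using GaLore for full-parameter fine-tuning of LLaMA-7B is a good option: for those who only have consumer 3090/4090 cards, it is one more choice besides ZeRO-3 Offload. It is worth keeping an eye on follow-up work on GaLore.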