QLoRA uses 4-bit quantization to compress a pretrained language model. The base model's parameters are then frozen, and a relatively small number of trainable parameters is added to the model in the form of low-rank adapters. During fine-tuning, QLoRA backpropagates gradients through the frozen, 4-bit-quantized pretrained language model into the low-rank adapters. The LoRA layers' weights are the only parameters updated during training.
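Concretely, for a frozen base weight $W_0 \in \mathbb{R}^{d_{out} \times d_{in}}$ and a rank-$r$ adapter (as defined in the LoRA paper), the layer output becomes

$$h = W_0 x + \frac{\alpha}{r} B A x, \qquad A \in \mathbb{R}^{r \times d_{in}},\; B \in \mathbb{R}^{d_{out} \times r},$$

where only $A$ and $B$ are trained, and $B$ is initialized to zero so that training starts exactly from the base model's behavior.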
Hugging Face official blog: huggingface.co/blog/zh/4bi...
Official code: gist.github.com/younesbelka...
The model quantization and fine-tuning code is as follows:
```python
import tempfile

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    LlamaTokenizer,
    TrainingArguments,
)
from trl import SFTTrainer

# ScriptArguments (the dataclass of CLI options) and get_current_device
# are defined earlier in the gist linked above; they are omitted here.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset(script_args.dataset_name, split="train[:1%]")

# We load the model
if script_args.use_multi_gpu:
    device_map = "auto"
else:
    device_map = {"": get_current_device()}

if script_args.use_8_bit and script_args.use_4_bit:
    raise ValueError("You can't use 8 bit and 4 bit precision at the same time")

if script_args.use_4_bit:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type=script_args.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=script_args.use_bnb_nested_quant,
    )
else:
    bnb_config = None

transformers_class = AutoModelForSeq2SeqLM if script_args.use_seq2seq_lm else AutoModelForCausalLM
model = transformers_class.from_pretrained(
    script_args.model_name,
    load_in_8bit=script_args.use_8_bit,
    load_in_4bit=script_args.use_4_bit,
    device_map=device_map if (script_args.use_8_bit or script_args.use_4_bit) else None,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)

if script_args.use_adapters:
    peft_config = LoraConfig(
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM" if not script_args.use_seq2seq_lm else "SEQ_2_SEQ_LM",
    )
else:
    peft_config = None

    if script_args.use_8_bit:
        raise ValueError("You need to use adapters to use 8 bit precision")

if "llama" in script_args.model_name:
    tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
else:
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)

with tempfile.TemporaryDirectory() as tmp_dir:
    training_arguments = TrainingArguments(
        per_device_train_batch_size=script_args.batch_size,
        max_steps=10,
        gradient_accumulation_steps=4,
        per_device_eval_batch_size=script_args.batch_size,
        output_dir=tmp_dir,
        report_to=["none"],
        optim=script_args.optimizer_name,
        fp16=True,
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="messages",
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        args=training_arguments,
    )
    trainer.train()
```
This post focuses on the PEFT fine-tuning source code; the training code itself is not explained.
Preface: model quantization techniques
There are currently four mainstream quantization schemes for large models:
- bitsandbytes
- GPTQ
- AWQ
- llama.cpp
Among them, llama.cpp targets CPU platforms such as ARM, while GPTQ and AWQ target CUDA platforms.
bitsandbytes is better suited for fine-tuning. Based on this observation, one way to obtain the best merged model is:
- (1) Quantize the base model with bitsandbytes (zero-shot quantization)
- (2) Add and fine-tune the adapters
- (3) Merge the trained adapters into the base model or into the dequantized model
- (4) Quantize the merged model with GPTQ and use it for deployment
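A minimal sketch of steps (2)-(4); paths such as "base-model" and "my-adapter" are placeholders, and the final GPTQ step is only indicated:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# (3) Reload the base model in half precision (the "dequantized" weights),
# attach the adapter trained in step (2), and fold it into the base weights.
base = AutoModelForCausalLM.from_pretrained("base-model", torch_dtype=torch.float16)
merged = PeftModel.from_pretrained(base, "my-adapter").merge_and_unload()
merged.save_pretrained("merged-model")

# (4) The merged fp16 checkpoint can then be quantized with GPTQ
# (e.g. via AutoGPTQ or optimum) for deployment.
```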
Reading the PEFT source code
The PEFT code involves many layers of nested calls, so instead of following the program's call chain, this walkthrough follows the order of model initialization and execution. It mainly covers two methods:
- init()
- forward()
The init method
Model initialization first calls inject_adapter:
```python
def inject_adapter(self, model: nn.Module, adapter_name: str):
    peft_config = self.peft_config[adapter_name]
    # Note: If possible, all checks should be performed *at the start of this method*.
    # This way, we can raise early if something goes wrong, without leaving the model
    # in a bad (half-initialized) state.
    self._check_new_adapter_config(peft_config)

    is_target_modules_in_base_model = False
    key_list = [key for key, _ in model.named_modules()]

    _check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
    _has_modules_to_save = False

    model_config = getattr(model, "config", {"model_type": "custom"})
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()

    peft_config = self._prepare_adapter_config(peft_config, model_config)

    for key in key_list:
        # Check for modules_to_save in case
        if _check_for_modules_to_save and any(
            key.endswith(f"{module_to_save}") for module_to_save in peft_config.modules_to_save
        ):
            # Optionally set the modules to save
            parent, target, target_name = _get_submodules(model, key)

            if not isinstance(target, ModulesToSaveWrapper):
                new_module = ModulesToSaveWrapper(target, adapter_name)
                setattr(parent, target_name, new_module)
            else:
                target.update(adapter_name)

            _has_modules_to_save = True
            continue

        if not self._check_target_module_exists(peft_config, key):
            continue

        is_target_modules_in_base_model = True
        parent, target, target_name = _get_submodules(model, key)

        optional_kwargs = {
            "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(model, "is_loaded_in_4bit", False),
            "current_key": key,
        }
        self._create_and_replace(peft_config, adapter_name, target, target_name, parent, **optional_kwargs)

    if not is_target_modules_in_base_model:
        raise ValueError(
            f"Target modules {peft_config.target_modules} not found in the base model. "
            f"Please check the target modules and try again."
        )

    self._mark_only_adapters_as_trainable()

    if self.peft_config[adapter_name].inference_mode:
        for n, p in self.model.named_parameters():
            if adapter_name in n:
                p.requires_grad = False

    if _has_modules_to_save:
        if not hasattr(model, "modules_to_save"):
            model.modules_to_save = set(peft_config.modules_to_save)
        else:
            model.modules_to_save.update(set(peft_config.modules_to_save))
```
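For reference, here is a minimal sketch of how inject_adapter is reached from user code, assuming `model` and `peft_config` from the training script above:

```python
from peft import get_peft_model

# get_peft_model builds a PeftModel; constructing the inner LoraModel
# ends up calling inject_adapter, which replaces the target modules in place.
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
```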
Two kinds of modules are set up for fine-tuning:
- modules_to_save
- target_modules
modules_to_save is wrapped with the ModulesToSaveWrapper wrapper (see the config sketch below).
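A minimal config sketch showing both kinds of modules; the module names here are hypothetical and depend on the base model's architecture:

```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # get LoRA A/B matrices injected
    modules_to_save=["lm_head"],          # wrapped in ModulesToSaveWrapper, trained in full
    task_type="CAUSAL_LM",
)
```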
target_modules are handled by _create_and_replace:
```python
def _create_and_replace(
    self,
    lora_config,
    adapter_name,
    target,
    target_name,
    parent,
    current_key,
    **optional_kwargs,
):
    if current_key is None:
        raise ValueError("Current Key shouldn't be `None`")

    # Regexp matching - Find key which matches current target_name in patterns provided
    pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
    target_name_key = next(filter(lambda key: re.match(f".*\.{key}$", current_key), pattern_keys), current_key)

    r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
    alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)
    bias = hasattr(target, "bias") and target.bias is not None
    kwargs = {
        "r": r,
        "lora_alpha": alpha,
        "lora_dropout": lora_config.lora_dropout,
        "fan_in_fan_out": lora_config.fan_in_fan_out,
        "init_lora_weights": lora_config.init_lora_weights,
    }
    kwargs["loaded_in_8bit"] = optional_kwargs.pop("loaded_in_8bit", False)
    kwargs["loaded_in_4bit"] = optional_kwargs.pop("loaded_in_4bit", False)
    kwargs["bias"] = bias

    quantization_config = get_quantization_config(self.model, method="gptq")
    if quantization_config is not None:
        kwargs["gptq_quantization_config"] = quantization_config

    # TODO: better deal with that
    if isinstance(target, Conv2d):
        target.update_layer_conv2d(
            adapter_name,
            r,
            alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
    elif isinstance(target, Embedding):
        target.update_layer_embedding(
            adapter_name,
            r,
            alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
    elif isinstance(target, Linear):
        target.update_layer(
            adapter_name,
            r,
            alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
    else:
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
        if adapter_name != self.active_adapter:
            # adding an additional adapter: it is not automatically trainable
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)
```
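The r and alpha resolved at the top of _create_and_replace come from rank_pattern / alpha_pattern, which override the defaults for layers whose names match the given keys. A sketch (module names hypothetical):

```python
from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "down_proj"],
    rank_pattern={"down_proj": 4},   # down_proj layers get r=4 instead of 8
    alpha_pattern={"down_proj": 8},  # and lora_alpha=8 instead of 16
)
```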
When the matched target is not already a LoRA layer, _create_and_replace dispatches to _create_new_module, which picks the right wrapper for the base layer's type:

```python
def _create_new_module(lora_config, adapter_name, target, **kwargs):
    gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

    loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
    loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    megatron_core = None
    if lora_config.megatron_config:
        megatron_core = importlib.import_module(lora_config.megatron_core)

    if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
        eightbit_kwargs = kwargs.copy()
        eightbit_kwargs.update(
            {
                "has_fp16_weights": target.state.has_fp16_weights,
                "memory_efficient_backward": target.state.memory_efficient_backward,
                "threshold": target.state.threshold,
                "index": target.index,
            }
        )
        new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
    elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
        fourbit_kwargs = kwargs.copy()
        fourbit_kwargs.update(
            {
                "compute_dtype": target.compute_dtype,
                "compress_statistics": target.weight.compress_statistics,
                "quant_type": target.weight.quant_type,
            }
        )
        new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
    elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
        new_module = QuantLinear(target, adapter_name, **kwargs)
        target.weight = target.qweight
    elif isinstance(target_base_layer, torch.nn.Embedding):
        embedding_kwargs = kwargs.copy()
        embedding_kwargs.pop("fan_in_fan_out", None)
        embedding_kwargs.update(lora_config.loftq_config)
        new_module = Embedding(target, adapter_name, **embedding_kwargs)
    elif isinstance(target_base_layer, torch.nn.Conv2d):
        kwargs.update(lora_config.loftq_config)
        new_module = Conv2d(target, adapter_name, **kwargs)
    elif isinstance(target_base_layer, torch.nn.Linear):
        if kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        kwargs.update(lora_config.loftq_config)
        new_module = Linear(target, adapter_name, **kwargs)
    elif megatron_core and isinstance(
        target_base_layer,
        (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
    ):
        from .tp_layer import LoraParallelLinear

        megatron_kwargs = kwargs.copy()
        megatron_config = lora_config.megatron_config
        if isinstance(megatron_config, dict):
            transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
            megatron_config = transformer_config_class(**lora_config.megatron_config)
        megatron_kwargs["megatron_config"] = megatron_config
        if megatron_kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` "
                "or `RowParallelLinear`. "
                "Setting fan_in_fan_out to False."
            )
            megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        new_module = LoraParallelLinear(
            base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
        )
    elif isinstance(target_base_layer, Conv1D):
        if not kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                "Setting fan_in_fan_out to True."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
        kwargs.update(lora_config.loftq_config)
        new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
    else:
        raise ValueError(
            f"Target module {target} is not supported. Currently, only the following modules are supported: "
            "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
        )
    return new_module
```
For _create_new_module, I could have shown just the core logic, but the full code is pasted here because we need to know exactly which module types PEFT supports. In practice it is usually the Linear modules that are fine-tuned, and only the following two kinds of quantized models are supported:
- bitsandbytes
- GPTQ
Note that PEFT fine-tuning does not support quantized models in any other format.
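One quick way to confirm which wrapper PEFT chose is to inspect a replaced module after injection. A sketch for a LLaMA-style 4-bit base model; the attribute path and the wrapper's module path vary with the model architecture and peft version:

```python
from peft import get_peft_model

peft_model = get_peft_model(model, peft_config)  # model loaded in 4 bit as above
layer = peft_model.base_model.model.model.layers[0].self_attn.q_proj
print(type(layer))  # expected: the bnb Linear4bit LoRA wrapper shown below
```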
```python
class Linear4bit(torch.nn.Module, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        base_layer: torch.nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        if self.merged:
            warnings.warn(
                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                f"You are now additionally merging {','.join(self.active_adapters)}."
            )
        if adapter_names is None:
            adapter_names = self.active_adapters

        for active_adapter in adapter_names:
            if active_adapter not in self.lora_A.keys():
                continue
            warnings.warn(
                "Merge lora module to 4-bit linear may get different generations due to rounding errors."
            )
            # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
            weight = self.get_base_layer().weight
            kwargs = weight.__dict__
            lora_data = self.get_delta_weight(active_adapter)

            w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + lora_data
            if safe_merge and not torch.isfinite(w_data).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                weight.device
            )
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A.keys():
                continue
            warnings.warn(
                "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
            )
            weight = self.get_base_layer().weight
            kwargs = weight.__dict__
            lora_data = self.get_delta_weight(active_adapter)
            w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
            self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                weight.device
            )

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # The reason is that in some cases, an error can occur that backprop
            # does not work on a manipulated view. This issue may be solved with
            # newer PyTorch versions but this would need extensive testing to be
            # sure.
            result = result.clone()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    x = x.to(lora_A.weight.dtype)

                output = lora_B(lora_A(dropout(x)))
                if requires_conversion:
                    output = output.to(expected_dtype)
                output = output * scaling
                result += output

        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep
```
The key class is Linear4bit. It derives from LoraLayer (together with torch.nn.Module), and its forward method is the concrete implementation of LoRA.
Pay attention to the merge and unmerge methods:
- merge folds the LoRA module into the base model's weights
- unmerge separates the LoRA module from the base model again
The point is that a single base model can act either with or without the fine-tuned adapter. This matters in PPO and DPO training: depending on how it is used, one set of weights can stand in for different models, avoiding a duplicate copy in GPU memory. A sketch of this pattern follows.
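A minimal sketch, assuming `model` is a PeftModel with a trained LoRA adapter and `input_ids` is a prepared batch: with the adapter disabled, the same weights act as the frozen reference model; with it enabled, they act as the policy.

```python
# One copy of the base weights, two behaviours:
with model.disable_adapter():
    ref_logits = model(input_ids).logits  # reference model (base only)
policy_logits = model(input_ids).logits   # policy model (base + LoRA)
```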
Also note the dtype-conversion code inside forward:
```python
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
    expected_dtype = result.dtype
    x = x.to(lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
    output = output.to(expected_dtype)
```
Because the dtype of the lora_A / lora_B parameters is not known in advance, while large-model training typically runs in bf16 or fp16, the input is cast to the LoRA weights' dtype and the output is cast back to the expected dtype. With the default settings this just works; it is advisable not to manually force the LoRA layers' parameters to fp32, int4, or int8. By contrast, you can check that ModulesToSaveWrapper contains no such dtype-conversion code, so customizing its parameter dtype will cause runtime failures.
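If in doubt, inspect the dtypes that are actually in play before overriding anything; a quick sketch:

```python
# Print the dtype and trainability of every adapter-related parameter.
for name, param in model.named_parameters():
    if "lora_" in name or "modules_to_save" in name:
        print(name, param.dtype, param.requires_grad)
```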