Training LLMs with bitsandbytes, 4-bit Quantization, and QLoRA, with a peft Source-Code Walkthrough

QLoRA uses 4-bit quantization to compress a pretrained language model. The base model's parameters are then frozen, and a relatively small number of trainable parameters are added to the model in the form of low-rank adapters. During fine-tuning, QLoRA backpropagates gradients through the frozen 4-bit quantized pretrained language model into the low-rank adapters. The weights of the LoRA layers are the only parameters updated during training.
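Conceptually, every adapted linear layer computes the frozen (quantized) base projection plus a trainable low-rank update scaled by alpha / r. The following is a minimal, illustrative sketch of that idea in plain PyTorch; it is not peft's actual implementation (the real 4-bit version, Linear4bit, is walked through later), and the class and argument names are made up for illustration:

```python
import torch
import torch.nn as nn


class LowRankAdaptedLinear(nn.Module):
    """Toy LoRA layer: y = W_frozen(x) + (alpha / r) * B(A(dropout(x)))."""

    def __init__(self, base_layer: nn.Linear, r: int = 8, alpha: int = 32, dropout: float = 0.05):
        super().__init__()
        self.base_layer = base_layer
        for p in self.base_layer.parameters():  # the (quantized) base weights stay frozen
            p.requires_grad = False
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)   # trainable, rank r
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)  # trainable, rank r
        nn.init.zeros_(self.lora_B.weight)  # B starts at zero, so the adapter is initially a no-op
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
```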

Official Hugging Face blog post: huggingface.co/blog/zh/4bi...

Official code: gist.github.com/younesbelka...

The model quantization and fine-tuning code is as follows:

```python
import tempfile

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    LlamaTokenizer,
    TrainingArguments,
)
from trl import SFTTrainer

# ScriptArguments (the dataclass of CLI options) and get_current_device are defined
# earlier in the original script and omitted here.

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset(script_args.dataset_name, split="train[:1%]")

# We load the model
if script_args.use_multi_gpu:
    device_map = "auto"
else:
    device_map = {"":get_current_device()}

if script_args.use_8_bit and script_args.use_4_bit:
    raise ValueError(
        "You can't use 8 bit and 4 bit precision at the same time"
    )

if script_args.use_4_bit:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type=script_args.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=script_args.use_bnb_nested_quant,
    )   
else:
    bnb_config = None

transformers_class = AutoModelForSeq2SeqLM if script_args.use_seq2seq_lm else AutoModelForCausalLM

model = transformers_class.from_pretrained(
    script_args.model_name, 
    load_in_8bit=script_args.use_8_bit, 
    load_in_4bit=script_args.use_4_bit,
    device_map=device_map if (script_args.use_8_bit or script_args.use_4_bit) else None,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)

if script_args.use_adapters:
    peft_config = LoraConfig(
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM" if not script_args.use_seq2seq_lm else "SEQ_2_SEQ_LM",
    )
else:
    peft_config = None
    if script_args.use_8_bit:
        raise ValueError(
            "You need to use adapters to use 8 bit precision"
        )

if "llama" in script_args.model_name:
    tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
else:
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)

with tempfile.TemporaryDirectory() as tmp_dir:
    training_arguments = TrainingArguments(
        per_device_train_batch_size=script_args.batch_size,
        max_steps=10,
        gradient_accumulation_steps=4,
        per_device_eval_batch_size=script_args.batch_size,
        output_dir=tmp_dir,
        report_to=["none"],
        optim=script_args.optimizer_name,
        fp16=True,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="messages",
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        args=training_arguments,
    )

    trainer.train()
```

This article focuses on the peft fine-tuning source code; the training code itself is not explained.

Preface: Model Quantization Techniques

There are currently four mainstream quantization schemes for large models:

  1. bitsandbytes

  2. GPTQ

  3. AWQ

  4. llama.cpp

Among them, llama.cpp targets CPU platforms such as ARM, while GPTQ and AWQ target CUDA platforms.

bitsandbytes is better suited for fine-tuning. Based on this observation, one way to obtain the best merged model (sketched in code after this list) is to:

  • (1) Quantize the base model with bitsandbytes (zero-shot quantization)
  • (2) Add and fine-tune the adapters
  • (3) Merge the trained adapters into the base model or the dequantized model
  • (4) Quantize the merged model with GPTQ and use it for deployment
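Below is a hedged sketch of steps (3) and (4), assuming the trained adapter was saved to ./qlora-adapter (an illustrative path; the training script above only writes to a temporary directory) and that a recent peft/transformers stack is installed. merge_and_unload and GPTQConfig are real APIs, but their exact arguments can vary across versions, and GPTQ quantization additionally requires optimum and auto-gptq:

```python
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# (3) Reload the base model in fp16 (i.e. dequantized) and fold the LoRA weights into it.
model = AutoPeftModelForCausalLM.from_pretrained("./qlora-adapter", torch_dtype=torch.float16)
merged = model.merge_and_unload()  # returns a plain fp16 model with no adapter modules left
merged.save_pretrained("./merged-fp16")

# (4) Quantize the merged model with GPTQ for deployment.
base_model_name = "facebook/opt-350m"  # illustrative; use the base model you fine-tuned
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
quantized = AutoModelForCausalLM.from_pretrained(
    "./merged-fp16", quantization_config=gptq_config, device_map="auto"
)
quantized.save_pretrained("./merged-gptq-4bit")
```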

peft Source Code Walkthrough

The peft code calls through several layers, so rather than following the program's call order, this walkthrough follows the order of model initialization and then execution. That is, it mainly covers the following two methods (a short orientation sketch follows the list):

  1. init()
  2. forward()
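In the training script above, SFTTrainer performs the peft wrapping internally when it receives peft_config. Doing it by hand makes the entry point explicit; a sketch, assuming a recent peft version and the quantized model / peft_config built earlier:

```python
from peft import get_peft_model, prepare_model_for_kbit_training

# Common preparation for k-bit training: freezes the base parameters, upcasts a few layers
# for numerical stability, and enables gradient checkpointing / input grads by default.
model = prepare_model_for_kbit_training(model)

# get_peft_model builds a LoraModel around the base model; the LoraModel constructor calls
# inject_adapter(), which walks the module tree and swaps the targeted layers for LoRA layers.
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
```

Once the adapters are injected, every forward pass of the wrapped model goes through the forward() method of each LoRA layer, which is covered in the second half of this walkthrough.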

The init() method

During model initialization, the first method called is inject_adapter:

```python
def inject_adapter(self, model: nn.Module, adapter_name: str):

    peft_config = self.peft_config[adapter_name]
    # Note: If possible, all checks should be performed *at the start of this method*.
    # This way, we can raise early if something goes wrong, without leaving the model
    # in a bad (half-initialized) state.
    self._check_new_adapter_config(peft_config)

    is_target_modules_in_base_model = False
    key_list = [key for key, _ in model.named_modules()]

    _check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
    _has_modules_to_save = False

    model_config = getattr(model, "config", {"model_type": "custom"})
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()

    peft_config = self._prepare_adapter_config(peft_config, model_config)

    for key in key_list:
        # Check for modules_to_save in case
        if _check_for_modules_to_save and any(
            key.endswith(f"{module_to_save}") for module_to_save in peft_config.modules_to_save
        ):
            # Optionally set the modules to save
            parent, target, target_name = _get_submodules(model, key)

            if not isinstance(target, ModulesToSaveWrapper):
                new_module = ModulesToSaveWrapper(target, adapter_name)
                setattr(parent, target_name, new_module)
            else:
                target.update(adapter_name)

            _has_modules_to_save = True
            continue

        if not self._check_target_module_exists(peft_config, key):
            continue

        is_target_modules_in_base_model = True
        parent, target, target_name = _get_submodules(model, key)

        optional_kwargs = {
            "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(model, "is_loaded_in_4bit", False),
            "current_key": key,
        }
        self._create_and_replace(peft_config, adapter_name, target, target_name, parent, **optional_kwargs)

    if not is_target_modules_in_base_model:
        raise ValueError(
            f"Target modules {peft_config.target_modules} not found in the base model. "
            f"Please check the target modules and try again."
        )

    self._mark_only_adapters_as_trainable()

    if self.peft_config[adapter_name].inference_mode:
        for n, p in self.model.named_parameters():
            if adapter_name in n:
                p.requires_grad = False

    if _has_modules_to_save:
        if not hasattr(model, "modules_to_save"):
            model.modules_to_save = set(peft_config.modules_to_save)
        else:
            model.modules_to_save.update(set(peft_config.modules_to_save))

Two kinds of modules are set up for fine-tuning:

  1. modules_to_save
  2. target_modules

modules_to_save entries are wrapped with the ModulesToSaveWrapper wrapper.
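A hedged illustration of when modules_to_save is useful: layers that must stay fully trainable (for example a resized lm_head or a newly initialized classification head) are listed there, and inject_adapter wraps them in ModulesToSaveWrapper, which keeps a trainable copy next to the frozen original. The module names below are illustrative:

```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # these layers get LoRA adapters
    modules_to_save=["lm_head"],          # this layer gets a full trainable copy (ModulesToSaveWrapper)
    task_type="CAUSAL_LM",
)
```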

target_modules entries are handled as follows:

```python
def _create_and_replace(
    self,
    lora_config,
    adapter_name,
    target,
    target_name,
    parent,
    current_key,
    **optional_kwargs,
):
    if current_key is None:
        raise ValueError("Current Key shouldn't be `None`")
    # Regexp matching - Find key which matches current target_name in patterns provided
    pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
    target_name_key = next(filter(lambda key: re.match(f".*\.{key}$", current_key), pattern_keys), current_key)

    r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
    alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)
    bias = hasattr(target, "bias") and target.bias is not None
    kwargs = {
        "r": r,
        "lora_alpha": alpha,
        "lora_dropout": lora_config.lora_dropout,
        "fan_in_fan_out": lora_config.fan_in_fan_out,
        "init_lora_weights": lora_config.init_lora_weights,
    }
    kwargs["loaded_in_8bit"] = optional_kwargs.pop("loaded_in_8bit", False)
    kwargs["loaded_in_4bit"] = optional_kwargs.pop("loaded_in_4bit", False)
    kwargs["bias"] = bias

    quantization_config = get_quantization_config(self.model, method="gptq")
    if quantization_config is not None:
        kwargs["gptq_quantization_config"] = quantization_config

    # TODO: better deal with that
    if isinstance(target, Conv2d):
        target.update_layer_conv2d(
            adapter_name,
            r,
            alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
    elif isinstance(target, Embedding):
        target.update_layer_embedding(
            adapter_name,
            r,
            alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
    elif isinstance(target, Linear):
        target.update_layer(
            adapter_name,
            r,
            alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
    else:
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
        if adapter_name != self.active_adapter:
            # adding an additional adapter: it is not automatically trainable
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)
```

```python
def _create_new_module(lora_config, adapter_name, target, **kwargs):
    gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

    loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
    loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    megatron_core = None
    if lora_config.megatron_config:
        megatron_core = importlib.import_module(lora_config.megatron_core)

    if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
        eightbit_kwargs = kwargs.copy()
        eightbit_kwargs.update(
            {
                "has_fp16_weights": target.state.has_fp16_weights,
                "memory_efficient_backward": target.state.memory_efficient_backward,
                "threshold": target.state.threshold,
                "index": target.index,
            }
        )
        new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
    elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
        fourbit_kwargs = kwargs.copy()
        fourbit_kwargs.update(
            {
                "compute_dtype": target.compute_dtype,
                "compress_statistics": target.weight.compress_statistics,
                "quant_type": target.weight.quant_type,
            }
        )
        new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
    elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
        new_module = QuantLinear(target, adapter_name, **kwargs)
        target.weight = target.qweight
    elif isinstance(target_base_layer, torch.nn.Embedding):
        embedding_kwargs = kwargs.copy()
        embedding_kwargs.pop("fan_in_fan_out", None)
        embedding_kwargs.update(lora_config.loftq_config)
        new_module = Embedding(target, adapter_name, **embedding_kwargs)
    elif isinstance(target_base_layer, torch.nn.Conv2d):
        kwargs.update(lora_config.loftq_config)
        new_module = Conv2d(target, adapter_name, **kwargs)
    elif isinstance(target_base_layer, torch.nn.Linear):
        if kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        kwargs.update(lora_config.loftq_config)
        new_module = Linear(target, adapter_name, **kwargs)
    elif megatron_core and isinstance(
        target_base_layer,
        (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
    ):
        from .tp_layer import LoraParallelLinear

        megatron_kwargs = kwargs.copy()
        megatron_config = lora_config.megatron_config
        if isinstance(megatron_config, dict):
            transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
            megatron_config = transformer_config_class(**lora_config.megatron_config)
        megatron_kwargs["megatron_config"] = megatron_config
        if megatron_kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` "
                "or `RowParallelLinear`. "
                "Setting fan_in_fan_out to False."
            )
            megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        new_module = LoraParallelLinear(
            base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
        )
    elif isinstance(target_base_layer, Conv1D):
        if not kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                "Setting fan_in_fan_out to True."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
        kwargs.update(lora_config.loftq_config)
        new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
    else:
        raise ValueError(
            f"Target module {target} is not supported. Currently, only the following modules are supported: "
            "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
        )

    return new_module
```

For _create_new_module, only the core lines could have been shown, but the full function is pasted here because we need to see exactly which module types peft supports. In practice it is usually the Linear modules that are fine-tuned, and only the following two kinds of quantized models are supported:

  1. bitsandbytes
  2. gptq

Note that quantized models in any other format are not supported for peft fine-tuning.
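To see which concrete LoRA classes were actually injected for a given quantization backend, you can simply inspect the wrapped model; a small hedged sketch, assuming the peft_model built earlier on a bitsandbytes 4-bit base:

```python
# On a bnb 4-bit base the targeted projections show up as lora.Linear4bit; on an 8-bit base
# they would be lora.Linear8bitLt, and on a GPTQ base lora.QuantLinear.
for name, module in peft_model.named_modules():
    if hasattr(module, "lora_A"):  # only LoRA-wrapped layers carry the lora_A ModuleDict
        print(name, type(module).__name__)
```

The bitsandbytes 4-bit branch is handled by the Linear4bit class shown next.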

```python
class Linear4bit(torch.nn.Module, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        base_layer: torch.nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:

        if self.merged:
            warnings.warn(
                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                f"You are now additionally merging {','.join(self.active_adapters)}."
            )

        if adapter_names is None:
            adapter_names = self.active_adapters

        for active_adapter in adapter_names:
            if active_adapter not in self.lora_A.keys():
                continue
            warnings.warn(
                "Merge lora module to 4-bit linear may get different generations due to rounding errors."
            )
            # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
            weight = self.get_base_layer().weight
            kwargs = weight.__dict__
            lora_data = self.get_delta_weight(active_adapter)

            w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + lora_data
            if safe_merge and not torch.isfinite(w_data).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                weight.device
            )
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A.keys():
                continue
            warnings.warn(
                "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
            )
            weight = self.get_base_layer().weight
            kwargs = weight.__dict__
            lora_data = self.get_delta_weight(active_adapter)
            w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
            self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                weight.device
            )

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # The reason is that in some cases, an error can occur that backprop
            # does not work on a manipulated view. This issue may be solved with
            # newer PyTorch versions but this would need extensive testing to be
            # sure.
            result = result.clone()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    x = x.to(lora_A.weight.dtype)

                output = lora_B(lora_A(dropout(x)))
                if requires_conversion:
                    output = output.to(expected_dtype)
                output = output * scaling
                result += output

        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep
```

The key class is Linear4bit, which builds on LoraLayer (together with torch.nn.Module); its forward method is the concrete implementation of LoRA.

Pay attention to the merge and unmerge methods:

  • merge fuses the LoRA module into the base model
  • unmerge detaches the LoRA module from the base model again

The key point is that a single base model can run either with or without the fine-tuned adapter. This matters a lot in PPO and DPO training: depending on how it is used, one model can stand in for several different models, avoiding duplicate copies in GPU memory.
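A hedged sketch of how this is typically exploited, for example in DPO, where the reference model is just the base model with the adapter switched off. disable_adapter() is a real peft context manager; the prompt and variable names below are illustrative:

```python
import torch

device = next(peft_model.parameters()).device
inputs = tokenizer("Hello", return_tensors="pt").to(device)

# Policy model: frozen 4-bit base + active LoRA adapter.
with torch.no_grad():
    policy_logits = peft_model(**inputs).logits

# Reference model: exactly the same weights with the adapter disabled,
# so no second copy of the base model has to live in GPU memory.
with torch.no_grad(), peft_model.disable_adapter():
    reference_logits = peft_model(**inputs).logits
```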

Pay attention to the following code:

```python
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
    expected_dtype = result.dtype
    x = x.to(lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
    output = output.to(expected_dtype)
```

The dtype of the lora_A / lora_B parameters is not known in advance, while large-model training runs in bf16 or fp16, so the activations have to be converted. In general, sticking with the defaults and not manually setting these parameters to another dtype works fine; it is not recommended to force the LoRA parameter layers to fp32, int4, or int8.

You can also check that ModulesToSaveWrapper contains no such dtype-conversion code, so customizing the parameter dtype there will cause runtime failures.
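A quick way to check what you actually ended up with is to print the dtype of every trainable parameter after wrapping; a sketch, assuming the peft_model from earlier:

```python
# All trainable parameters should be the adapter (and modules_to_save) weights; their dtype
# tells you whether the conversion branch in forward() above will be taken.
for name, param in peft_model.named_parameters():
    if param.requires_grad:
        print(name, param.dtype)
```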
