【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录

文章目录


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架 :BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标 :CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益 :BitDistiller需要的训练数据和资源更少,更具成本效益。




代码

https://github.com/DD-DuDa/BitDistiller.git

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbF

class Round(Function):
    @staticmethod
    def forward(self, input):
        sign = torch.sign(input)
        output = sign * torch.floor(torch.abs(input) + 0.5)
        return output

    @staticmethod
    def backward(self, grad_output):
        grad_input = grad_output.clone()
        return grad_input

# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    elif q_group_size == -1:
        w = w.reshape(-1, w.shape[-1])
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = - 2 ** (n_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w



@torch.no_grad()
def real_quantize_model_weight(
    model, w_bit, q_config,
    init_only=False
):
    from .qmodule import WQLinear
    from .pre_quant import get_blocks, get_named_linears, set_op_by_name
    assert q_config["zero_point"], "We only support zero_point quantization now."
    
    layers = get_blocks(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
        layer = layers[i]
        named_linears = get_named_linears(layer)
        # scale_activations(layer)

        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            else:
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                # scales = scales.t().contiguous()
                # zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
                
    torch.cuda.empty_cache()
    gc.collect()




def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):
    quantizer = SteN2F3Quantizer(q_group_size=q_group_size)
    w = quantizer(w)
    return w


class SteInt3AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 3
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, x.shape[-1])
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteInt2AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=64):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 2
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteN2F3Quantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
    
    def forward(self, x):
        org_w_shape = x.shape

        # reshape to groupsize
        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            qx = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            qx = x.reshape(-1, x.shape[-1])
        assert qx.dim() == 2

        # Get the Min Max
        max_val = qx.amax(dim=1, keepdim=True)
        min_val = qx.amin(dim=1, keepdim=True)

        
        scale_pos = torch.abs(max_val)
        scale_neg = torch.abs(min_val)

        dev = qx.device
        x_pos = torch.zeros_like(qx)
        x_neg = torch.zeros_like(qx)
        x_pos = torch.where(qx >= 0, qx, x_pos)
        x_neg = torch.where(qx < 0, qx, x_neg)
        q_pos = x_pos / scale_pos
        q_neg = x_neg / scale_neg

        q_pos, q_neg = self.round_pass(q_pos, q_neg, dev)

        qx = q_pos * scale_pos + q_neg * scale_neg

        qx = qx.reshape(org_w_shape)

        return qx
    
    def round_n2f3(self, q_pos, q_neg, dev):
        q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)
        q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)

        q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)
        q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)

        return q_pos, q_neg

    def round_pass(self, q_pos, q_neg, dev):
        y_grad_pos, y_grad_neg = q_pos, q_neg
        y_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)
        
        return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论 ),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术

相关推荐
武子康11 分钟前
大数据-207 数据挖掘 机器学习理论 - 多重共线性 矩阵满秩 线性回归算法
大数据·人工智能·算法·决策树·机器学习·矩阵·数据挖掘
Topstip1 小时前
在 Google Chrome 上查找并安装 SearchGPT 扩展
前端·人工智能·chrome·gpt·ai·chatgpt
python_知世1 小时前
AI时代:成为产品经理的核心路径
人工智能·深度学习·程序人生·自然语言处理·产品经理·计算机技术·大模型应用
这个男人是小帅1 小时前
【GCN】 代码详解 (1) 如何运行【pytorch】可运行版本
人工智能·pytorch·python·深度学习·分类
cv-daily1 小时前
优化模型训练过程中的显存使用率、GPU使用率
人工智能
Jason-河山2 小时前
人工智能技术:未来生活的“魔法师”
人工智能·生活
电子手信2 小时前
教育机构如何利用知识中台进行数字教学
大数据·人工智能·自然语言处理·自动化
深度学习实战训练营2 小时前
HyperGAT模型复现微博文本情绪多分类
人工智能·分类·数据挖掘
pzx_0012 小时前
【深度学习】梯度累加和直接用大的batchsize有什么区别
pytorch·深度学习
阿_旭2 小时前
实战| 使用深度学习分割和计算水体和农田面积【Pytorch附源码】
人工智能·pytorch·深度学习·ai·目标分割