awq_marlin和gptq_marlin量化算法简要介绍

量化算法的本质

量化算法的本质在于快速实现Y=XW计算,其中X,W往往都是FP16数据类型,在大模型推理过程中,输入X往往被称之为activation,而权重W被称之为weight,对于一个M,K,N的矩阵乘法(即X的形状为M,K,W形状为K,N,Y形状为M,N)的矩阵来说,最简单的实现方式就是调用cublas仓库,这里需要重点注意的是X和W的排列方式,尤其是W的排列方式,比如说下面这段代码

cpp 复制代码
cublasGemmEx(
        handle,
        CUBLAS_OP_T, 
        CUBLAS_OP_N, 
        N,           
        M,           
        K,           
        &alpha,//alpha = 1.0f
        W, CUDA_R_16F, ldb,//ldb = K
        X, CUDA_R_16F, lda,//lda = K
        &beta,//beta = 0.0f
        Y, CUDA_R_32F, ldo,//ldo = N
        CUBLAS_COMPUTE_32F,
        CUBLAS_GEMM_DEFAULT);

处理的就是Y=XW的计算过程,但是其中X是行主元数据(即X形状是M,K,步长stride是K,1),而W是列主元数据(即W形状是K,N,步长stride是1,K,这个可以通过W = torch.randn(N,K).t()这种方式得到)。

上面的这个矩阵乘法大家耳熟能详,但是这个矩阵乘法在大模型推理训练过程中会带来一些问题,比如说最直接的就是显存占用情况,假设我们只考虑Y=XW这个计算,如果W是一个数据量为8B(80亿参数)的矩阵,如果W的每个元素都是FP16,那么W需要占用显存(GB)= 参数量×数据比特数/(8×1024×1024×1024)=14.9,也就是说,仅仅考虑存储这个权重W就需要占用14.9GB显存,如果我们能够换一种思路,比如说把权重的数据类型换成INT8,此时显存马上可以降低一半变成7.45GB,如果进一步把权重数据类型变成INT4,那么显存继续降低变成3.73GB,也就是说,对于一个稍微普通的带显卡的笔记本,就可以实现这个推理过程了。

在输入X,也就是activation数据类型为F16的情况下,W数据类型为INT8,此时的量化称之为W8A16量化,如果W数据类型为INT4,此时量化称之为W4A16量化。

比如说W8A8量化,指的就是输入X数据类型为INT8,权重矩阵数据类型也是INT8。

矩阵的量化算法

这里我们先介绍一下矩阵的量化算法,即:一个形状为K,N,数据类型为FP16的权重矩阵W,是如何变成另一个形状为K,N,数据类型为INT4或者INT8的量化矩阵w_packed。下面我们以INT8来举例子说明:

per_tensor_quant_int8

这种量化最简单,先计算出全局的abs最大值global_max = max(abs(W.flatten())),此时引入一个scale = global_max /127,有了scale以后,下面使用这段伪代码

python 复制代码
for i in range(K):
    for j in range(N):
        val = W[i,j] / scale
        val = max(-127, min(127, val))
        w_packed[i,j] = val.to(torch.int8)

通过上面这种方式得到的结果,我们称之为对称量化,这种对称量化方式涉及到的参数有w_packed, scale和W,其中scale是一个长度为1的数据类型为F32的tensor,与之对应的还有一个非对称量化,非对称量化会多一个zero参数,但是实际大模型量化过程中用的最多的就是对称量化。

per_channel_quant_int8

有了上面关于per_tensor_quant_int8量化的介绍,此时理解per_channel_quant_int8量化就简单多了,对称的per_channel_quant_int8量化和上面的区别在于scale的形状变成了K,1,也就是说,原来需要计算整个矩阵abs(W)的全局最大值,现在需要针对每一行abs(Wi,:计算最大值,python代码实现可以参考

python 复制代码
def per_channel_quant_int8_torch(x, symmetric):
    if symmetric:
        x = x.float()
        absmax = x.abs().max(dim=-1).values
        absmax = absmax.clamp_min(1e-10).unsqueeze(-1)
        scale_x = absmax / 127
        x_q = x.mul(127 / absmax)
        x_q = torch.round(x_q).to(torch.int8)

        return x_q, scale_x, None
    else:
        w = x.float()
        w_min = w.min(dim=-1, keepdim=True)[0]
        w_max = w.max(dim=-1, keepdim=True)[0]

        w_scale = (w_max - w_min) / 255.0
        w_scale = torch.clamp(w_scale, min=1e-8)

        w_zero = -w_min / w_scale - 128.0

        w_q = torch.round(w / w_scale + w_zero)

        w_q = torch.clamp(w_q, -128, 127)

        w_packed = w_q.to(torch.int8)

        return w_packed, w_scale, w_zero

与之类似的还有per_channel_quant_fp8,per_tensor_quant_fp8,本质上没有区别,只不过最后量化结果数据类型不一致而已,至于per_channel_quant_int8和per_tensor_quant_int8的CUDA代码实现也非常简单,可以参考添加链接描述添加链接描述

真正的难点在于矩阵乘法。

量化模型的矩阵乘法

在量化大模型推理过程中,一般会提前提供已经量化好的权重以及对应的scale,也就是说,现在需要实现的计算过程拥有下面几个参数:

X:形状为M,K,往往是行主元,即步长stride=K,1,数据类型为FP16或者FP32

W:形状为K,N,如果是行主元,那么步长stride=N,1,如果是列主元,步长stride=1,K,数据类型可能为INT4或者是INT8

scale:当形状为M,1对应的是per_channel_quant,当形状为1,的时候,对应的是per_tensor_quant,数据类型为FP32

zero:这是optional参数,数据类型和形状往往和scale保持一致(但是awq_marlin_gemm,gptq_marlin_gemm这些量化模型可能会很不一样),如果zeros存在,那么对应的就是非对称量化,如果zeros不存在,对应的就是对称量化。

bias:这是一个optional参数,数据类型为FP32,形状往往为N,1

W8A8矩阵乘法

我们以W8A8对称量化算法来举例说明量化矩阵乘法,计算过程涉及的参数就变成了:

X:形状为M,K,往往是行主元,即步长stride=K,1,数据类型为INT8

W:形状为K,N,这里我们考虑列主元,步长stride=1,K,数据类型为INT8

x_scale:形状为M,1,数据类型为FP32

w_scale:形状为N,1,数据类型为FP32

bias:这是一个optional参数,数据类型为FP32,形状为N,1

需要实现的计算大概就是

Y = (x_scale * X)@ (w_scale * W) + bias,

在具体的实现过程中,有两种方案,

方案1:先调用cublas计算y_packed = x_packed@w_packed,由于此时x_packed, w_packed数据类型都是INT8的,使用cublas计算速度会特别快,这个时候相当于说,我们需要在CUDA层面额外引入一份显存来存储这个临时数据y_packed,cublas的计算流程可以参考下面这段代码

cpp 复制代码
    const int32_t alpha_I = 1;
    const int32_t beta_I = 0;

    cublasGemmEx(
        handle,
        CUBLAS_OP_T, 
        CUBLAS_OP_N,
        N,           
        M,           
        K,           
        &alpha_I,
        b, CUDA_R_8I, ldb,//ldb = K
        a, CUDA_R_8I, lda,//lda = K
        &beta_I,
        y_packed, CUDA_R_32I, ldo,//ldo=N
        CUBLAS_COMPUTE_32I,
        CUBLAS_GEMM_DEFAULT);

特别注意,上面的这段代码里面要求x_packed是行主元的形状为M,K的指针,而w_packed是列主元的,形状为K,N的指针,其中x_packed可以通过torch.randn(M,K)直接生成,而w_packed可以通过torch.randn(N,K).t()直接生成。有了y_packed以后,剩下的就是做后处理,根据x_scale和w_scale以及y_packed把结果还原出来,这部分比较简单。

可以看出,这个方案1其实需要实现两个kernel,第一个kernel调用cublas,第二个kernel进行后处理,这个方案非常直接简单,但是在性能上不占优势,CUDA代码里面,我们希望一个算子往往只占用一个kernel,这种做法肯定会比直接使用torch.matmul计算FP16的X@W要慢。

方案2:直接调用cutlass来计算整个过程,这个说起来很简单,但是实现起来非常复杂,这个的原始代码参考添加链接描述

GPTQ MARLIN矩阵乘法

Gptq marlin矩阵乘法计算的也是Y=XW,其中X,W往往都是FP16数据类型,在大模型推理过程中,输入X往往被称之为activation,而权重W被称之为weight,对于一个M,K,N的矩阵乘法(即X的形状为M,K,W形状为K,N,Y形状为M,N)的矩阵来说。这个算法的核心目的包括:

1:把浮点权重(W)量化压缩到 4bit/8bit(体积缩小 4~8 倍);

2:把量化后的权重重排成 MARLIN 专用格式(适配 GPU 硬件执行单元);

3:保证推理速度接近浮点、精度几乎无损。

模块1:将浮点数据类型的W量化得到量化权重w_q, w_s, w_z以及根据(w_q,w_s,w_z)反量化得到的w_ref

首先根据下面的逻辑做数据重排

python 复制代码
# 代码逻辑
w = w.reshape((-1, group_size, size_n))  # [K/group, group, N]
w = w.permute(1, 0, 2)                   # [group, K/group, N]
w = w.reshape((group_size, -1))          # [group, K/group * N]

然后获取对应group的最大值,最小值以及绝对最大值。

python 复制代码
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
abs_val = torch.max(abs(max_val), abs(min_val)

如果需要设置零点,那么就计算对应的w_s和w_z,参考

python 复制代码
max_q_val = quant_type.max()
min_q_val = quant_type.min()
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = (
    torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)

这里提到的quant_type就是对应的量化类型,如果是int8量化,那么对应的quant_type就是int8,此时对应的上下界max_q_val,min_q_val就是127,-127,用数学公式表达就是

如果不需要设置零点,那么只计算w_s,对应的代码参考

python 复制代码
w_s = torch.max(
    abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
    abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)

用数学公式表达就是

对应的量化权重计算参考

python 复制代码
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)

用数学公式表达就是

最后根据刚才计算的w_q,w_s,w_z得到一个反量化的w_ref=(w_q - w_z)w_s,值得说明的是,此时这个w_ref和最原始的w大概率不等价。

GPTQ 把权重压成 INT4/INT8,但直接存成 K,N 矩阵,GPU 跑不快。这是因为

1:GPU TensorCore(MMA)一次喜欢读 16×16 小块;

2:而且要连续内存、特定顺序才能用向量加载(LDG.128);

3:原生矩阵是 "行主序",不满足硬件读取模式。

Marlin perm 的本质: 把 INT4/INT8 权重重新切成 16×16 瓦片 → 打乱瓦片内部元素顺序 → 拼成 GPU 最喜欢的内存布局。重点:perm 不是随机乱排,是硬编码的、为了 TensorCore 读得快的固定重排。

这个marlin重排的过程很复杂,本人也不太能看懂。

awq_marlin_gemm和gptq_marlin_gemm的主要区别在于awq_marlin不支持zeros,

awq_marlin的实现源代码来自添加链接描述

gptq_marlin的实现源代码来自添加链接描述

本人针对awq,gptq的矩阵乘法做了一个识别简化,可以参考添加链接描述添加链接描述

相关推荐
buhuizhiyuci43 分钟前
【算法篇】动态规划——斐波那契数列模型
算法·动态规划
棱镜研途1 小时前
学习笔记丨模式识别与机器学习5大核心赛道解析(IC-IPPR 2026)
人工智能·神经网络·算法·机器学习·模式识别·学术会议·智能计算
SuperHeroWu71 小时前
【算法】逻辑回归虽然名字中有“回归“,但通常用于二分类任务。如何理解学习?
算法·回归·逻辑回归·二分类任务
gCode Teacher 格码致知1 小时前
Python教学:十六进制编码的显示方法-由Deepseek产生
开发语言·python·算法
05候补工程师1 小时前
【408数据结构】核心考点:图(Graph)精炼笔记与算法直觉
数据结构·经验分享·笔记·考研·算法·图论
靠沿1 小时前
【动态规划算法】专题三——简单多状态dp问题
算法·动态规划
吃好睡好便好1 小时前
矩阵秩的计算
人工智能·学习·线性代数·算法·机器学习·matlab·矩阵
计算机安禾1 小时前
【算法分析与设计】第35篇:后缀数据结构:后缀树与后缀数组的构造
大数据·人工智能·算法·机器学习·剪枝
计算机安禾1 小时前
【算法分析与设计】第38篇:最近点对与分治在几何中的应用
java·服务器·网络·数据库·算法