征程 6EM 常见 QConfig 配置解读与示例

一、引言

在工具链用户手册《量化感知训练(QAT)-开发指南-QConfig 详解》章节专门介绍了在 J6EM 上 qconfig 是怎么回事,从经历看,大家可能会存在看了依旧不懂,或懂了不知道怎么配置的情况,特别是一些 OE 包中示例没有的配置,例如固定某节点 scale、配置 linear weight int16 等操作。

qconfig 控制了模型所有节点的量化类型,例如是采用 int8 还是 int16 量化,是固定校准阶段的 scale 去 qat 还是不固定 scale 去 qat。

提供的模板可分为三类:基础模板、敏感度模板、自定义模板。本文将常见配置通过示例方式进行呈现。

二、基础模板

基础模板中 calibration / qat / qat_fixed_act_scale 区别在于使用的 observer 类型和 scale 更新逻辑,分别用于校准,不固定 activation scaleqat 训练,固定 activation scale qat 训练。

default 模板 ( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter ) 会做三件事:

  • 首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;
  • 然后,从 grid sample 算子的 grid 输入向前搜索,直到出现第一个 gemm 类算子或者 QuantStub,将中间的所有算子都设置为 int16。根据经验这里的 grid 一般表达范围较宽,int8 有较大可能不满足精度需求;
  • 最后,将其余算子设置为 int8。

int16 模板 ( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter ) 会做两件事:

  • 首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;
  • 其次,将其余算子设置为 int16。
Plain 复制代码
from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_calibration_qconfig_setter,
    default_qat_qconfig_setter,
    default_qat_fixed_act_qconfig_setter,
    qat_8bit_weight_16bit_act_qconfig_setter,
    qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    calibration_8bit_weight_16bit_act_qconfig_setter,
)
qat_or_calib_model = prepare(
    float_model,
    example_inputs=example_inputs,  # 用来感知图结构
    qconfig_setter=(

        default_qat_qconfig_setter,    # 根据需要配置setter模板
    ),
)

三、敏感度模板

敏感度模板有三个:

Plain 复制代码
sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter

三者的区别和基础模板中三者的区别类似,也是分别用于校准,不固定 activation scale qat 训练,固定 activation scale qat 训练。

敏感度模板的第一个输入是精度 debug 工具产生的敏感度结果,第二个参数可以指定 ratio 或 topk,敏感度模板会根据配置,将量化敏感度最高的 topk 个算子设置为 int16。搭配固定模板,可以实现混合精度调优。

若模型有多个输出,每个输出都会产生一个敏感度表,您可以设置多个敏感度模版。示例如下:

Plain 复制代码
from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_calibration_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
)

# 这两个pt文件是通过debug工具得到的
table1 = torch.load("output_0-0_L1_sensitive_ops.pt")
table2 = torch.load("output_0-1_L1_sensitive_ops.pt")

calibration_model = prepare(
    float_model,
    example_inputs=example_input,
    qconfig_setter=(
        sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table1, ratio=0.2),
        sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table2, ratio=0.2),
        default_calibration_qconfig_setter,
    ),
)

四、自定义模板

自定义模板为 ModuleNameQconfigSetter,需要传入模块名和对应自定义的 qconfig,一般用于设置 fixed scale、配置 linear weight int16 等特殊需求,可以和固定模板,敏感度模板搭配使用。示例如下:

Plain 复制代码
from horizon_plugin_pytorch.quantization.qconfig_template import (
    calibration_8bit_weight_16bit_act_qconfig_setter,
    ModuleNameQconfigSetter,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    get_qconfig,
    MSEObserver,
    MinMaxObserver,
    FixedScaleObserver,
    QConfig,
)
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize

# 手动设置某个算子的输出scale
op_name_output_fix_scale_qconfig = QConfig(
    output=FakeQuantize.with_args(
        observer=FixedScaleObserver,
        dtype=qint16,
        scale=0.0625,
    )
)

# 设置某个算子weight与输出activation的量化类型
# 校准时用MSEObserver,qat时用MinMaxObserver
# 没有weight的算子,配置了weight_dtype也不会起作用
calib_weight_act_both_int16_qconfig = get_qconfig(
    observer=MSEObserver,
    weight_dtype=qint16,
    out_dtype=qint16,
)

calib_weight_act_both_int8_qconfig = get_qconfig(
    observer=MSEObserver,
    weight_dtype=qint8,
    out_dtype=qint8,
)

qat_weight_act_both_int16_qconfig = get_qconfig(
    observer=MinMaxObserver,
    weight_dtype=qint16,
    out_dtype=qint16,
    fix_scale=True,    # 是否固定scale
)

放在一块简单示例如下:

Plain 复制代码
from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_qat_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    ModuleNameQconfigSetter,
)

table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")

# 自动替换生成的算子只能通过 ModuleNameQconfigSetter 配置自定义 qconfig。
module_name_to_qconfig = {
    "_generated_add_0": op_name_output_fix_scale_qconfig ,
}

qat_model = prepare(
    float_model,
    example_inputs=example_input,
    qconfig_setter=(
        ModuleNameQconfigSetter(module_name_to_qconfig),
        sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
        default_qat_qconfig_setter,
    ),
)

五、可运行的示例

  • 将网络中 linear2 的 weight 配置为 int16 量化、输入配置为 int8 量化、输出配置为 int16 量化,其他算子激活使用 int16 量化,weight 使用 int8 量化。
Plain 复制代码
import torch
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_M)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn


# 定义网络结构
class SmallModel(nn.Module):
    def __init__(self):
        super(SmallModel, self).__init__()
        # 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]
        self.linear1 = nn.Linear(256, 256)
        self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化
        self.relu = nn.ReLU()
        # 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]
        self.linear2 = nn.Linear(256, 60)
        # 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]
        self.linear3 = nn.Linear(60, 60)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        # 第一个 Linear
        x = self.linear1(x)  # [2, 100, 256]
        x = self.layernorm(x)  # [2, 100, 256]
        x = self.relu(x)  # [2, 100, 256]
        # 第二个 Linear
        x = self.linear2(x)  # [2, 100, 60]
        # 第三个 Linear
        x = self.linear3(x)
        x = self.dequant(x)
        return x

example_input = torch.randn(2, 100, 256)
model = SmallModel()

# 前向传播
output = model(example_input)
print("输出形状:", output.shape)

# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_M)

calib_weight_act_both_int16_qconfig = get_qconfig(
    observer=MSEObserver,
    weight_dtype=qint16,
    out_dtype=qint16,
)

# layernorm没有weight,配置了weight_dtype也不会起作用
calib_weight_act_both_int8_qconfig = get_qconfig(
    observer=MSEObserver,
    weight_dtype=qint8,
    out_dtype=qint8,
)

qat_weight_act_both_int16_qconfig = get_qconfig(
    observer=MinMaxObserver,
    weight_dtype=qint16,
    out_dtype=qint16,
    fix_scale=True,
)
# 节点名称,可以从model_check_result.txt中获取,也可以从敏感度文件中获取
module_name_to_qconfig = {
    "layernorm": calib_weight_act_both_int8_qconfig,
    "linear2": calib_weight_act_both_int16_qconfig,   
}

calib_model = prepare(model.eval(), example_input,
                      qconfig_setter=(
                          ModuleNameQconfigSetter(module_name_to_qconfig),
                          calibration_8bit_weight_16bit_act_qconfig_setter,
                          ),
                      )

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input)

calib_model.eval()                            
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input)

qat_bc = export(calib_model, example_input)
  • 配置 add 单算子输入和输出均使用固定 scale
Plain 复制代码
import torch
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_E)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver, FixedScaleObserver, QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn


class AddNet(nn.Module):
    def __init__(self):
        super(AddNet, self).__init__()
        self.quant_x = QuantStub()
        self.quant_y = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x, y):
        x = self.quant_x(x)
        y = self.quant_y(y)
        z = torch.add(x, y)
        z = self.dequant(z)
        return z

# 创建模型
model = AddNet()

# 生成两个相同形状的输入张量
torch.manual_seed(42)
x = torch.randn(1, 1, 2, 6)
y = torch.randn(1, 2, 2, 6)
example_input = (x,y)

# 前向传播
output = model(example_input[0], example_input[1])
print("float输出数据:", output)
print("输入形状:", example_input[0].shape)
print("输出形状:", output.shape)

# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_E)

add_input_fix_scale_qconfig = QConfig(
    output=FakeQuantize.with_args(
        observer=FixedScaleObserver,
        dtype=qint16,
        scale=0.03125,
    )
)
add_output_fix_scale_qconfig = QConfig(
    output=FakeQuantize.with_args(
        observer=FixedScaleObserver,
        dtype=qint16,
        scale=0.0625,
    )
)

# 节点名称,可以从model_check_result.txt中获取,也可以从敏感度文件中获取
module_name_to_qconfig = {
    "quant_x": add_input_fix_scale_qconfig,

    "quant_y": add_input_fix_scale_qconfig,

    "_generated_add_0": add_output_fix_scale_qconfig,
}

calib_model = prepare(model.eval(), example_input,
                      qconfig_setter=(
                          ModuleNameQconfigSetter(module_name_to_qconfig),
                          calibration_8bit_weight_16bit_act_qconfig_setter,
                          ),
                      )

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input[0], example_input[1])

calib_model.eval()                            
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input[0], example_input[1])
print("calib输出数据:", calib_out)

qat_bc = export(calib_model, example_input)

六、冻结部分网络结构 qat 的配置

补充常见冻结网络结构,去进行 qat 的做法

Plain 复制代码
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    prepare,
    set_fake_quantize,
    FakeQuantState,
)
#prepare QAT模型
qat_model = prepare(
    model,
    example_inputs=xxx,
    qconfig_setter=(
        xxx,
    )
)
#加载calib权重
qat_model.load_state_dict(torch.load("calib-checkpoint.ckpt"))
#QAT训练
qat_model.train()
#固定backbone部分的权重,requires_grad不影响drop bn的行为,需要与eval联合用
for param in qat_model.backbone.parameters():
    param.requires_grad = False
#固定backbone部分的scale,eval只影响drop bn的行为,如果发生了backward仍然会改变权重,需要与requires_grad联合使用
qat_model.backbone.eval()
set_fake_quantize(qat_model.backbone, FakeQuantState.VALIDATION)
#配置head的FakeQuant为QAT状态
set_fake_quantize(qat_model.head, FakeQuantState.QAT)
相关推荐
STY_fish_20121 小时前
手拆STL
java·c++·算法
小纭在努力1 小时前
【算法设计与分析】实验——改写二分搜索算法,众数问题(算法分析:主要算法思路),有重复元素的排列问题,整数因子分解问题(算法实现:过程,分析,小结)
数据结构·python·学习·算法·算法设计与分析·实验报告·实验
芜湖xin2 小时前
【题解-洛谷】B4278 [蓝桥杯青少年组国赛 2023] 简单算术题
算法·
理智的灰太狼2 小时前
题目 3298: 蓝桥杯2024年第十五届决赛真题-兔子集结
算法·职场和发展·蓝桥杯
kingmax542120085 小时前
【洛谷P9303题解】AC- [CCC 2023 J5] CCC Word Hunt
数据结构·c++·算法·广度优先
白熊1886 小时前
【机器学习基础】机器学习入门核心算法:XGBoost 和 LightGBM
人工智能·算法·机器学习
bai_lan_ya6 小时前
数据结构-排序-排序的七种算法(2)
数据结构·算法·排序算法
全域智图7 小时前
元胞自动机(Cellular Automata, CA)
人工智能·算法·机器学习
珂朵莉MM8 小时前
2022 RoboCom 世界机器人开发者大赛-本科组(省赛)解题报告 | 珂学家
人工智能·算法·职场和发展·深度优先·图论
独家回忆3648 小时前
每日算法-250601
数据结构·算法