征程 6 | 直方图量化配置与校准实例

本文基于 horizon_plugin_pytorch 量化工具链,详细介绍多输出网络的量化配置策略、HistogramObserver 使用、混合精度设置及校准流程。

1. 模型结构设计

1.1 多输出网络示例

Python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub

# -------------------------
# Backbone
# -------------------------
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x

# -------------------------
# Head A: 分类头
# -------------------------
class HeadA(nn.Module):
    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# -------------------------
# Head B: 回归头
# -------------------------
class HeadB(nn.Module):
    def __init__(self, in_channels, output_dim=4):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_channels * 8 * 8, output_dim)

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = self.flatten(x)
        x = self.fc(x)
        return x

# -------------------------
# 总网络(双输出)
# -------------------------
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.headA = HeadA(in_channels=64, num_classes=10)
        self.headB = HeadB(in_channels=64, output_dim=4)

        # 量化入口与出口
        self.quant = QuantStub()
        self.dequant_A = DeQuantStub()
        self.dequant_B = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)          # 量化入口
        feat = self.backbone(x)
        outA = self.headA(feat)
        outB = self.headB(feat)
        return self.dequant_A(outA), self.dequant_B(outB)  # 双输出反量化

1.2 关键设计说明

组件 作用
QuantStub 量化入口,将 FP32 输入转换为量化域
DeQuantStub 量化出口,将量化值还原为 FP32 输出
多 DeQuantStub 多输出网络需为每个输出配置独立的反量化节点

2. QConfig 配置策略

2.1 Observer 选择建议

推荐使用 HistogramObserver​,原因如下:

特性 MinMaxObserver HistogramObserver
统计方式 仅记录 min/max 构建完整直方图
分布感知 是(完整分布)
多方法支持 是(mse/percentile/kl 等)
离群值处理 敏感 自动处理

核心优势:HistogramObserver 将收集与计算分离,在不改变网络结构/权重/校准数据,一次校准后可通过 reset_scale 切换不同计算方法,无需重新跑校准。

2.2 配置示例:激活 HistogramObserver + 权重 MinMaxObserver

Plain 复制代码
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, HistogramObserver
from horizon_plugin_pytorch.quantization import qint8, get_qconfig

## ========= 方法1========
## 可以查看 MinMaxObserver 和 HistogramObserver 的 __init__ 方法了解有哪些可以设置的参数
qconfig = QConfig(
    weight=FakeQuantize.with_args(
        observer=MinMaxObserver,
        averaging_constant=0.01,              # 滑动平均系数
        dtype=qint8,
        qscheme=torch.per_channel_symmetric, # 权重使用per-channel量化,只有 weight 支持 per channel 量化
        ch_axis=0,    # Conv权重shape为 (out_channels, in_channels, H, W),第0维是输出通道                       
    ),
    output=FakeQuantize.with_args(
        observer=HistogramObserver,          # 激活使用HistogramObserver
        dtype=qint8,
        qscheme=torch.per_tensor_symmetric,  # 激活使用per-tensor量化
        ch_axis=-1, # per-tensor模式下,值无实际意义,只要为负即可
    ),
)

## ========= 方法2========
qconfig = get_qconfig(observer=HistogramObserver)

适用场景​:常规量化任务,权重分布稳定用 MinMaxObserver 即可,激活分布复杂用 HistogramObserver 精细处理。

注意​:根据敏感节点配置 int16 时,无需重新校准!权重 channel min/max 会被记录下来、直方图信息也会被记录下来。

3. QconfigSetter 与 Template 配置

3.1 完整 QconfigSetter 配置

Python 复制代码
from horizon_plugin_pytorch.quantization import QconfigSetter, qint8, qint16
from horizon_plugin_pytorch.quantization.qconfig_setter import (
    ModuleNameTemplate,
    ConvDtypeTemplate,
    MatmulDtypeTemplate,
    SensitivityTemplate,
)

# 敏感节点列表(需要int16的层)
out_weight_int16 = ["backbone.conv1",
                     "backbone.conv2",]

sensitive_list = []
for name in out_weight_int16:
    sensitive_list.append((name, "output"))
    sensitive_list.append((name, "weight"))

qconfig_setter = QconfigSetter(
    reference_qconfig=qconfig,      # 基准QConfig
    templates=[
        # 1. 默认dtype配置
        ModuleNameTemplate({"": qint8}),

        # 2. Conv层专用配置
        ConvDtypeTemplate(input_dtype=qint8, weight_dtype=qint8),

        # 3. Matmul层专用配置
        MatmulDtypeTemplate(input_dtypes=qint8),

        # 4. 敏感节点配置(int16)
        SensitivityTemplate(sensitive_list, 1.0),
    ],
)

3.2 Template 执行顺序与优先级

Python 复制代码
优先级:后配置的Template优先级更高

执行顺序:
ModuleNameTemplate -> ConvDtypeTemplate -> MatmulDtypeTemplate -> SensitivityTemplate
     ↓                    ↓                     ↓                      ↓
  全局默认            Conv覆盖              Matmul覆盖             敏感节点覆盖

3.3 SensitivityTemplate 参数说明

Plain 复制代码
SensitivityTemplate(sensitive_list, ratio)

# ratio参数:
#   - 1.0 (float): 表示100%,列表中所有节点都配置为int16
#   - 1   (int):   表示top1,仅配置最敏感的1个节点
#   - 0.5 (float): 表示50%,配置前50%敏感度的节点

4. 校准流程详解

4.1 标准校准流程

Plain 复制代码
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.march import set_march, March
from horizon_plugin_pytorch.quantization.hbdk4 import export

# 1. 设置芯片架构
set_march(March.NASH_M)

# 2. 准备校准模型
calib_net = prepare(model, (input_tensor,), qconfig_setter=qconfig_setter)
calib_net.eval()

# 3. 切换到校准模式
set_fake_quantize(calib_net, FakeQuantState.CALIBRATION)

# 4. 执行校准(统计量化参数)
with torch.no_grad():
    calib_net(input_tensor)

# 5. 切换到验证模式
set_fake_quantize(calib_net, FakeQuantState.VALIDATION)

# 6. 导出量化模型
qat_bc = export(calib_net, input_tensor)

4.2 FakeQuantState 状态说明

状态 作用 适用阶段
CALIBRATION 统计激活分布,不进行伪量化 校准阶段
VALIDATION 启用伪量化,模拟量化推理 验证/导出阶段
QAT 启用伪量化,支持梯度回传 QAT 训练阶段

5. HistogramObserver 高级用法

5.1 校准后切换计算方法

HistogramObserver 的核心优势:​校准后可切换不同计算方法,无需重新跑校准数据​。

重要提示:通过 reset_scale 重新计算 scale 后,scale 会保存在 state_dict 中。如果需要保存量化模型,请在 scale 更新后重新保存 state_dict:

Plain 复制代码
HistogramObserver.reset_scale(calib_net, method="percentile", dtype=qint16)
torch.save(calib_net.state_dict(), "calib_model.pth")  # 重新保存
Plain 复制代码
from horizon_plugin_pytorch.quantization.observer_v2 import HistogramObserver

# 校准完成后,切换为mse方法
HistogramObserver.reset_scale(
    calib_net,
    method="mse",
    dtype=qint8,
)

# 或切换为percentile方法(适合长尾分布)
# percentile 越小,截断越激进
HistogramObserver.reset_scale(
    calib_net,
    method="percentile",
    method_kwargs={"percentile": 0.999999},
    dtype=qint8,
)

# int16层使用完整范围(不截断)
# int16精度足够高,无需通过截断换取精度
HistogramObserver.reset_scale(
    calib_net,
    method="percentile",
    method_kwargs={"percentile": 1.0},  # percentile越小截断越激进
    dtype=qint16,
)

5.2 支持的计算方法

方法 说明 适用场景
mse 最小化量化误差 正态分布,默认推荐
percentile 百分位截断 长尾分布、存在离群值
kl 最小化 KL 散度 分布差异敏感场景

5.3 方法对比实验

Plain 复制代码
# 对比不同方法的效果
methods = ['mse', 'percentile', 'kl']
for method in methods:
    HistogramObserver.reset_scale(calib_net, method)
    # 评估精度
    acc = evaluate(calib_net)
    print(f"Method: {method}, Accuracy: {acc}")

6. 混合精度配置

重要提示:reset_dtype 会修改 dtype/scale,修改后需要重新保存 state_dict:

Plain 复制代码
# reset_dtype 后重新保存
for name, mod in calib_net.named_modules():
    if "headB" in name:
        ...
torch.save(calib_net.state_dict(), "calib_model.pth")

6.1 整个模块配置 int16

Plain 复制代码
from horizon_plugin_pytorch.quantization import qint16

# 将headB所有层配置为双int16(权重+激活)
for name, mod in calib_net.named_modules():
    if "headB" in name:
        weight_fake_quant = getattr(mod, "weight_fake_quant", None)
        if weight_fake_quant is not None:
            weight_fake_quant.reset_dtype(qint16)

        activation_post_process = getattr(mod, "activation_post_process", None)
        if activation_post_process is not None:
            activation_post_process.reset_dtype(qint16)

6.2 单个节点配置 int16

Plain 复制代码
# 针对敏感节点单独配置int16
calib_net.headA.fc1.weight_fake_quant.reset_dtype(qint16)
calib_net.headA._generated_relu_0.activation_post_process.reset_dtype(qint16)
# _generated_* 层是由prepare自动生成的量化节点
# 如: _generated_relu_0, _generated_adaptive_avg_pool2d_0 等

7. 完整可运行示例

7.1 示例:激活 HistogramObserver + 权重 MinMaxObserver

Plain 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub

from horizon_plugin_pytorch.quantization import (
    set_fake_quantize, FakeQuantState, prepare,
    QconfigSetter, qint8, qint16
)
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, HistogramObserver
from horizon_plugin_pytorch.quantization.qconfig_setter import (
    ModuleNameTemplate, ConvDtypeTemplate,
    MatmulDtypeTemplate, SensitivityTemplate
)
from horizon_plugin_pytorch.march import set_march, March
from horizon_plugin_pytorch.quantization.hbdk4 import export


# ============ 模型定义 ============
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x


class HeadA(nn.Module):
    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class HeadB(nn.Module):
    def __init__(self, in_channels, output_dim=4):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_channels * 8 * 8, output_dim)

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = self.flatten(x)
        x = self.fc(x)
        return x


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.headA = HeadA(in_channels=64, num_classes=10)
        self.headB = HeadB(in_channels=64, output_dim=4)
        self.quant = QuantStub()
        self.dequant_A = DeQuantStub()
        self.dequant_B = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        feat = self.backbone(x)
        outA = self.headA(feat)
        outB = self.headB(feat)
        return self.dequant_A(outA), self.dequant_B(outB)


# ============ 主流程 ============
if __name__ == "__main__":
    # 1. 创建模型和输入
    model = MyNet()
    input_tensor = torch.randn(1, 3, 32, 32) * 1000

    # 2. 设置芯片架构
    set_march(March.NASH_M)

    # 3. 配置敏感节点(backbone使用int16)
    out_weight_int16 = ["backbone.conv1",
                        "backbone.conv2",]

    sensitive_list = []
    for name in out_weight_int16:
        sensitive_list.append((name, "output"))
        sensitive_list.append((name, "weight"))

    # 4. 配置QConfig:激活HistogramObserver + 权重MinMaxObserver
    qconfig = QConfig(
        weight=FakeQuantize.with_args(
            observer=MinMaxObserver,
            averaging_constant=0.01,
            dtype=qint8,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        ),
        output=FakeQuantize.with_args(
            observer=HistogramObserver,
            dtype=qint8,
            qscheme=torch.per_tensor_symmetric,
            ch_axis=-1,
        ),
    )

    # 5. 配置QconfigSetter
    qconfig_setter = QconfigSetter(
        reference_qconfig=qconfig,
        templates=[
            ModuleNameTemplate({"": qint8}),
            ConvDtypeTemplate(input_dtype=qint8, weight_dtype=qint8),
            MatmulDtypeTemplate(input_dtypes=qint8),
            SensitivityTemplate(sensitive_list, 1.0)
        ],
    )

    # 6. 准备校准模型
    calib_net = prepare(model, (input_tensor,), qconfig_setter=qconfig_setter)
    calib_net.eval()

    # 7. 执行校准
    set_fake_quantize(calib_net, FakeQuantState.CALIBRATION)
    with torch.no_grad():
        calib_net(input_tensor)

    # 8. 可选:切换HistogramObserver的计算方法(默认是mse)
    # 如需切换为percentile方法(适合长尾分布):
    # HistogramObserver.reset_scale(
    #     calib_net,
    #     method="percentile",
    #     method_kwargs={"percentile": 0.999999},
    #     dtype=qint8,
    # )

    # 9. 混合精度配置:headB全部使用int16
    for name, mod in calib_net.named_modules():
        if "headB" in name:
            if hasattr(mod, "weight_fake_quant") and mod.weight_fake_quant is not None:
                mod.weight_fake_quant.reset_dtype(qint16)
            if hasattr(mod, "activation_post_process") and mod.activation_post_process is not None:
                mod.activation_post_process.reset_dtype(qint16)
    # reset_dtype 后重新 reset_scale 以获得最优缩放
    HistogramObserver.reset_scale(
        calib_net,
        method="mse",
        dtype=qint16,
    )

    # 10. 切换验证模式并导出
    calib_net.eval()
    set_fake_quantize(calib_net, FakeQuantState.VALIDATION)
    qat_bc = export(calib_net, input_tensor)

    print("示例:量化模型导出成功!")

7.2. 常见问题与解决方案

Q1: AttributeError: 'NoneType' object has no attribute 'reset_dtype'

原因​​:部分模块(如 ReLU、BatchNorm)的 activation_post_process 为 None。

解决​:

Plain 复制代码
if hasattr(mod, "activation_post_process") and mod.activation_post_process is not None:
    mod.activation_post_process.reset_dtype(qint16)

Q2: 敏感

相关推荐
地平线开发者2 小时前
征程 6E/M Matrix 开发评板使用系列(一):开箱与点亮
算法·自动驾驶
Jerry2 小时前
LeetCode 59. 螺旋矩阵 II
算法
可编程芯片开发2 小时前
基于FOC控制器的BLDC无刷直流电机控制系统matlab编程与仿真
算法
aaaameliaaa3 小时前
进制练习题【找出只出现一次的数字、交换两个变量(不创建临时变量)、统计二进制中1的个数、打印整数二进制的奇数位和偶数位、求两个数二进制中不同位的个数】
c语言·数据结构·笔记·算法
QiLinkOS4 小时前
第三视觉理解徐玉生与他的商业活动(28)
大数据·c++·人工智能·算法·开源协议
wabs6665 小时前
关于动态规划【力扣1143.最长公共子序列的思考】
算法·leetcode·动态规划
剑挑星河月5 小时前
54.螺旋矩阵
java·算法·leetcode·矩阵
Robot_Nav6 小时前
MPPI 局部规划器实验设计讲解
人工智能·算法·mppi
mingo_敏6 小时前
Mean-Teacher 均值教师自训练框架详解
算法·均值算法