本文基于 horizon_plugin_pytorch 量化工具链,详细介绍多输出网络的量化配置策略、HistogramObserver 使用、混合精度设置及校准流程。
1. 模型结构设计
1.1 多输出网络示例
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
# -------------------------
# Backbone
# -------------------------
class Backbone(nn.Module):
    """Shared feature extractor made of two Conv-BN-ReLU stages (3 -> 32 -> 64).

    Both convolutions use a 3x3 kernel with padding 1 and the default stride,
    so the spatial size of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the 64-channel feature map computed from a 3-channel input."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x
# -------------------------
# Head A: 分类头
# -------------------------
class HeadA(nn.Module):
    """Classification head: global average pooling followed by two linear layers.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of values produced per sample.
    """

    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a feature map to a ``(batch, num_classes)`` tensor."""
        x = self.avgpool(x)
        # Collapse the pooled (C, 1, 1) dimensions into a flat vector per sample.
        x = torch.flatten(x, 1)
        # The functional ReLU is kept as-is: a later section of this guide
        # addresses the node generated for it (`_generated_relu_0`) by name.
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# -------------------------
# Head B: 回归头
# -------------------------
class HeadB(nn.Module):
    """Regression head: pool to an 8x8 grid, flatten, then one linear layer.

    Args:
        in_channels: Channel count of the incoming feature map.
        output_dim: Number of values produced per sample.
    """

    def __init__(self, in_channels, output_dim=4):
        super().__init__()
        self.flatten = nn.Flatten()
        # The linear layer expects the flattened 8x8 grid produced in forward().
        self.fc = nn.Linear(in_channels * 8 * 8, output_dim)

    def forward(self, x):
        """Map a feature map to a ``(batch, output_dim)`` tensor."""
        # Pooling to a fixed 8x8 grid keeps the flattened size equal to
        # in_channels * 8 * 8 regardless of the input resolution.
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = self.flatten(x)
        x = self.fc(x)
        return x
# -------------------------
# 总网络(双输出)
# -------------------------
class MyNet(nn.Module):
    """Two-output network: one shared backbone feeding two independent heads.

    The input passes through a single QuantStub; each output has its own
    DeQuantStub.
    """

    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.headA = HeadA(in_channels=64, num_classes=10)
        self.headB = HeadB(in_channels=64, output_dim=4)
        # Quantization entry point and one exit point per output.
        self.quant = QuantStub()
        self.dequant_A = DeQuantStub()
        self.dequant_B = DeQuantStub()

    def forward(self, x):
        """Return the tuple ``(headA output, headB output)``, both dequantized."""
        x = self.quant(x)  # quantization entry point
        feat = self.backbone(x)
        outA = self.headA(feat)
        outB = self.headB(feat)
        # Each output goes through its own dequantization node.
        return self.dequant_A(outA), self.dequant_B(outB)
1.2 关键设计说明
| 组件 | 作用 |
|---|---|
| QuantStub | 量化入口,将 FP32 输入转换为量化域 |
| DeQuantStub | 量化出口,将量化值还原为 FP32 输出 |
| 多 DeQuantStub | 多输出网络需为每个输出配置独立的反量化节点 |
2. QConfig 配置策略
2.1 Observer 选择建议
推荐使用 HistogramObserver,原因如下:
| 特性 | MinMaxObserver | HistogramObserver |
|---|---|---|
| 统计方式 | 仅记录 min/max | 构建完整直方图 |
| 分布感知 | 否 | 是(完整分布) |
| 多方法支持 | 否 | 是(mse/percentile/kl 等) |
| 离群值处理 | 敏感 | 自动处理 |
核心优势:HistogramObserver 将统计收集与 scale 计算分离。在不改变网络结构、权重和校准数据的前提下,只需校准一次,之后即可通过 reset_scale 切换不同的计算方法,无需重新跑校准。
2.2 配置示例:激活 HistogramObserver + 权重 MinMaxObserver
Plain
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, HistogramObserver
from horizon_plugin_pytorch.quantization import qint8, get_qconfig
## ========= Option 1: build the QConfig explicitly =========
## See the __init__ methods of MinMaxObserver and HistogramObserver for the
## parameters that can be configured.
weight_fq = FakeQuantize.with_args(
    observer=MinMaxObserver,
    averaging_constant=0.01,  # moving-average coefficient
    dtype=qint8,
    # Weights are quantized per channel; only weights support per-channel quantization.
    qscheme=torch.per_channel_symmetric,
    # Conv weights have shape (out_channels, in_channels, H, W); dim 0 is the output channel.
    ch_axis=0,
)
activation_fq = FakeQuantize.with_args(
    observer=HistogramObserver,  # activations use HistogramObserver
    dtype=qint8,
    qscheme=torch.per_tensor_symmetric,  # activations are quantized per tensor
    ch_axis=-1,  # has no effect in per-tensor mode; any negative value works
)
qconfig = QConfig(weight=weight_fq, output=activation_fq)
## ========= Option 2: use the helper =========
## This rebinds `qconfig`, so keep only one of the two options in real code.
qconfig = get_qconfig(observer=HistogramObserver)
适用场景:常规量化任务,权重分布稳定用 MinMaxObserver 即可,激活分布复杂用 HistogramObserver 精细处理。
注意:根据敏感节点配置 int16 时,无需重新校准!权重的逐通道 min/max 和激活的直方图信息都会在校准时被记录下来。
3. QconfigSetter 与 Template 配置
3.1 完整 QconfigSetter 配置
Python
from horizon_plugin_pytorch.quantization import QconfigSetter, qint8, qint16
from horizon_plugin_pytorch.quantization.qconfig_setter import (
ModuleNameTemplate,
ConvDtypeTemplate,
MatmulDtypeTemplate,
SensitivityTemplate,
)
# Sensitive layers that should be configured as int16.
out_weight_int16 = [
    "backbone.conv1",
    "backbone.conv2",
]
# For every layer add a (name, "output") entry followed by a (name, "weight")
# entry, so both the activation and the weight of that layer are listed.
sensitive_list = [
    (name, target)
    for name in out_weight_int16
    for target in ("output", "weight")
]
qconfig_setter = QconfigSetter(
    reference_qconfig=qconfig,  # baseline QConfig
    templates=[
        # 1. Default dtype for every module.
        ModuleNameTemplate({"": qint8}),
        # 2. Conv-specific configuration.
        ConvDtypeTemplate(input_dtype=qint8, weight_dtype=qint8),
        # 3. Matmul-specific configuration.
        MatmulDtypeTemplate(input_dtypes=qint8),
        # 4. Sensitive-node configuration (int16).
        SensitivityTemplate(sensitive_list, 1.0),
    ],
)
3.2 Template 执行顺序与优先级
Python
优先级:后配置的Template优先级更高
执行顺序:
ModuleNameTemplate -> ConvDtypeTemplate -> MatmulDtypeTemplate -> SensitivityTemplate
↓ ↓ ↓ ↓
全局默认 Conv覆盖 Matmul覆盖 敏感节点覆盖
3.3 SensitivityTemplate 参数说明
Plain
SensitivityTemplate(sensitive_list, ratio)
# The `ratio` argument:
# - 1.0 (float): means 100%, every node in the list is configured as int16
# - 1 (int): means top-1, only the single most sensitive node is configured
# - 0.5 (float): means 50%, the top 50% most sensitive nodes are configured
4. 校准流程详解
4.1 标准校准流程
Plain
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.march import set_march, March
from horizon_plugin_pytorch.quantization.hbdk4 import export
# 1. Select the target chip architecture.
set_march(March.NASH_M)
# 2. Build the calibration model.
calib_net = prepare(model, (input_tensor,), qconfig_setter=qconfig_setter)
calib_net.eval()
# 3. Switch to calibration mode.
set_fake_quantize(calib_net, FakeQuantState.CALIBRATION)
# 4. Run calibration (collect the quantization statistics).
with torch.no_grad():
    calib_net(input_tensor)
# 5. Switch to validation mode.
set_fake_quantize(calib_net, FakeQuantState.VALIDATION)
# 6. Export the quantized model.
qat_bc = export(calib_net, input_tensor)
4.2 FakeQuantState 状态说明
| 状态 | 作用 | 适用阶段 |
|---|---|---|
| CALIBRATION | 统计激活分布,不进行伪量化 | 校准阶段 |
| VALIDATION | 启用伪量化,模拟量化推理 | 验证/导出阶段 |
| QAT | 启用伪量化,支持梯度回传 | QAT 训练阶段 |
5. HistogramObserver 高级用法
5.1 校准后切换计算方法
HistogramObserver 的核心优势:校准后可切换不同计算方法,无需重新跑校准数据。
重要提示:通过 reset_scale 重新计算 scale 后,scale 会保存在 state_dict 中。如果需要保存量化模型,请在 scale 更新后重新保存 state_dict:
Plain
HistogramObserver.reset_scale(calib_net, method="percentile", dtype=qint16)
torch.save(calib_net.state_dict(), "calib_model.pth")  # 重新保存
Plain
from horizon_plugin_pytorch.quantization.observer_v2 import HistogramObserver
# After calibration has finished, switch to the mse method.
HistogramObserver.reset_scale(calib_net, method="mse", dtype=qint8)
# Or switch to the percentile method (suited to long-tailed distributions).
# The smaller the percentile, the more aggressive the clipping.
HistogramObserver.reset_scale(
    calib_net,
    dtype=qint8,
    method="percentile",
    method_kwargs=dict(percentile=0.999999),
)
# int16 layers use the full range (no clipping): int16 is precise enough
# that there is no need to trade range for precision.
HistogramObserver.reset_scale(
    calib_net,
    dtype=qint16,
    method="percentile",
    method_kwargs=dict(percentile=1.0),  # 1.0 keeps the whole range
)
5.2 支持的计算方法
| 方法 | 说明 | 适用场景 |
|---|---|---|
| mse | 最小化量化误差 | 正态分布,默认推荐 |
| percentile | 百分位截断 | 长尾分布、存在离群值 |
| kl | 最小化 KL 散度 | 分布差异敏感场景 |
5.3 方法对比实验
Plain
# Compare the accuracy obtained with each scale-computation method.
# `evaluate` is not defined in this guide; supply your own function that
# returns the accuracy of the given model.
methods = ['mse', 'percentile', 'kl']
for method in methods:
    # Keyword form, matching the other reset_scale calls in this guide.
    HistogramObserver.reset_scale(calib_net, method=method)
    # Evaluate accuracy.
    acc = evaluate(calib_net)
    print(f"Method: {method}, Accuracy: {acc}")
6. 混合精度配置
重要提示:reset_dtype 会修改 dtype/scale,修改后需要重新保存 state_dict:
Plain
# reset_dtype 后重新保存
for name, mod in calib_net.named_modules():
    if "headB" in name:
        ...
torch.save(calib_net.state_dict(), "calib_model.pth")
6.1 整个模块配置 int16
Plain
from horizon_plugin_pytorch.quantization import qint16
# Configure every layer under headB as int16 for both weight and activation.
for name, mod in calib_net.named_modules():
    if "headB" not in name:
        continue
    # Not every module carries both fake-quantize nodes, and an existing
    # attribute may be None, so look each one up defensively.
    weight_fake_quant = getattr(mod, "weight_fake_quant", None)
    if weight_fake_quant is not None:
        weight_fake_quant.reset_dtype(qint16)
    activation_post_process = getattr(mod, "activation_post_process", None)
    if activation_post_process is not None:
        activation_post_process.reset_dtype(qint16)
6.2 单个节点配置 int16
Plain
# Configure int16 for individual sensitive nodes only.
head_a = calib_net.headA
head_a.fc1.weight_fake_quant.reset_dtype(qint16)
head_a._generated_relu_0.activation_post_process.reset_dtype(qint16)
# `_generated_*` layers are quantization nodes created automatically by prepare,
# e.g. _generated_relu_0, _generated_adaptive_avg_pool2d_0.
7. 完整可运行示例
7.1 示例:激活 HistogramObserver + 权重 MinMaxObserver
Plain
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
from horizon_plugin_pytorch.quantization import (
set_fake_quantize, FakeQuantState, prepare,
QconfigSetter, qint8, qint16
)
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, HistogramObserver
from horizon_plugin_pytorch.quantization.qconfig_setter import (
ModuleNameTemplate, ConvDtypeTemplate,
MatmulDtypeTemplate, SensitivityTemplate
)
from horizon_plugin_pytorch.march import set_march, March
from horizon_plugin_pytorch.quantization.hbdk4 import export
# ============ 模型定义 ============
class Backbone(nn.Module):
    """Shared feature extractor made of two Conv-BN-ReLU stages (3 -> 32 -> 64).

    Both convolutions use a 3x3 kernel with padding 1 and the default stride,
    so the spatial size of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the 64-channel feature map computed from a 3-channel input."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x
class HeadA(nn.Module):
    """Classification head: global average pooling followed by two linear layers.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of values produced per sample.
    """

    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a feature map to a ``(batch, num_classes)`` tensor."""
        x = self.avgpool(x)
        # Collapse the pooled (C, 1, 1) dimensions into a flat vector per sample.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
class HeadB(nn.Module):
    """Regression head: pool to an 8x8 grid, flatten, then one linear layer.

    Args:
        in_channels: Channel count of the incoming feature map.
        output_dim: Number of values produced per sample.
    """

    def __init__(self, in_channels, output_dim=4):
        super().__init__()
        self.flatten = nn.Flatten()
        # The linear layer expects the flattened 8x8 grid produced in forward().
        self.fc = nn.Linear(in_channels * 8 * 8, output_dim)

    def forward(self, x):
        """Map a feature map to a ``(batch, output_dim)`` tensor."""
        # Pooling to a fixed 8x8 grid keeps the flattened size equal to
        # in_channels * 8 * 8 regardless of the input resolution.
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = self.flatten(x)
        x = self.fc(x)
        return x
class MyNet(nn.Module):
    """Two-output network: one shared backbone feeding two independent heads.

    The input passes through a single QuantStub; each output has its own
    DeQuantStub.
    """

    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.headA = HeadA(in_channels=64, num_classes=10)
        self.headB = HeadB(in_channels=64, output_dim=4)
        # Quantization entry point and one exit point per output.
        self.quant = QuantStub()
        self.dequant_A = DeQuantStub()
        self.dequant_B = DeQuantStub()

    def forward(self, x):
        """Return the tuple ``(headA output, headB output)``, both dequantized."""
        x = self.quant(x)
        feat = self.backbone(x)
        outA = self.headA(feat)
        outB = self.headB(feat)
        return self.dequant_A(outA), self.dequant_B(outB)
# ============ 主流程 ============
if __name__ == "__main__":
    # End-to-end flow: build the model, calibrate it with a HistogramObserver
    # on activations and a MinMaxObserver on weights, switch headB to int16,
    # then export the quantized model.

    # 1. Create the model and a sample input.
    model = MyNet()
    input_tensor = torch.randn(1, 3, 32, 32) * 1000

    # 2. Select the target chip architecture.
    set_march(March.NASH_M)

    # 3. Configure the sensitive nodes (the backbone convolutions use int16).
    out_weight_int16 = [
        "backbone.conv1",
        "backbone.conv2",
    ]
    # For every layer add a (name, "output") entry followed by a
    # (name, "weight") entry.
    sensitive_list = [
        (name, target)
        for name in out_weight_int16
        for target in ("output", "weight")
    ]

    # 4. Build the QConfig: HistogramObserver for activations,
    #    MinMaxObserver for weights.
    qconfig = QConfig(
        weight=FakeQuantize.with_args(
            observer=MinMaxObserver,
            averaging_constant=0.01,
            dtype=qint8,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        ),
        output=FakeQuantize.with_args(
            observer=HistogramObserver,
            dtype=qint8,
            qscheme=torch.per_tensor_symmetric,
            ch_axis=-1,
        ),
    )

    # 5. Build the QconfigSetter.
    qconfig_setter = QconfigSetter(
        reference_qconfig=qconfig,
        templates=[
            ModuleNameTemplate({"": qint8}),
            ConvDtypeTemplate(input_dtype=qint8, weight_dtype=qint8),
            MatmulDtypeTemplate(input_dtypes=qint8),
            SensitivityTemplate(sensitive_list, 1.0),
        ],
    )

    # 6. Prepare the calibration model.
    calib_net = prepare(model, (input_tensor,), qconfig_setter=qconfig_setter)
    calib_net.eval()

    # 7. Run calibration.
    set_fake_quantize(calib_net, FakeQuantState.CALIBRATION)
    with torch.no_grad():
        calib_net(input_tensor)

    # 8. Optional: switch the HistogramObserver computation method
    #    (per this guide the default is mse).
    #    To switch to the percentile method (suited to long-tailed distributions):
    # HistogramObserver.reset_scale(
    #     calib_net,
    #     method="percentile",
    #     method_kwargs={"percentile": 0.999999},
    #     dtype=qint8,
    # )

    # 9. Mixed precision: everything under headB uses int16.
    for name, mod in calib_net.named_modules():
        if "headB" not in name:
            continue
        # Not every module carries both fake-quantize nodes, and an existing
        # attribute may be None, so look each one up defensively.
        weight_fake_quant = getattr(mod, "weight_fake_quant", None)
        if weight_fake_quant is not None:
            weight_fake_quant.reset_dtype(qint16)
        activation_post_process = getattr(mod, "activation_post_process", None)
        if activation_post_process is not None:
            activation_post_process.reset_dtype(qint16)

    # After reset_dtype, call reset_scale again to recompute the scales.
    HistogramObserver.reset_scale(
        calib_net,
        method="mse",
        dtype=qint16,
    )

    # 10. Switch to validation mode and export.
    calib_net.eval()
    set_fake_quantize(calib_net, FakeQuantState.VALIDATION)
    qat_bc = export(calib_net, input_tensor)
    print("示例:量化模型导出成功!")
7.2 常见问题与解决方案
Q1: AttributeError: 'NoneType' object has no attribute 'reset_dtype'
原因:部分模块(如 ReLU、BatchNorm)的 activation_post_process 为 None。
解决:
Plain
# Guard against modules whose activation_post_process is missing or None
# before calling reset_dtype on it.
activation_post_process = getattr(mod, "activation_post_process", None)
if activation_post_process is not None:
    activation_post_process.reset_dtype(qint16)