征程 6 | 工具链 QAT ObserverBase 源码解析

1. 概述

ObserverBase 是 horizon_plugin_pytorch 量化框架中所有 Observer 的抽象基类。它定义了量化校准器的统一接口和核心功能,为各种量化策略(MinMax、MSE、KL 等)提供了基础架构。

2. ABCMeta 深度解析

2.1 Python 元类机制

在 Python 中,​类也是对象 ​,类是由元类(metaclass) 创建的:

默认情况下,所有类都由元类创建。当指定 metaclass=ABCMeta 时,类的创建过程由 ABCMeta 控制。

示例如下:

Plain 复制代码
from abc import ABCMeta, abstractmethod

class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    
    @abstractmethod
    def forward(self, x):
        pass

2.2 @abstractmethod 装饰器

Plain 复制代码
def abstractmethod(funcobj):
    """标记方法为抽象方法"""
    funcobj.__isabstractmethod__ = True  # 仅设置标志位
    return funcobj

2.3 ObserverBase 中的应用

Plain 复制代码
# 基类定义抽象方法
class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, x):
        pass
# ObserverBase.__abstractmethods__ = frozenset({'forward'})

# 子类实现
class MinMaxObserver(ObserverBase):
    def forward(self, x_orig):
        return x_orig
# MinMaxObserver.__abstractmethods__ = frozenset() → 可实例化

3. ObserverBase 完整源码

Plain 复制代码
class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    r"""Base observer Module.

    Any observer implementation should derive from this class.

    Concrete observers should follow the same API. In forward, they will update
    the statistics of the observed Tensor. And they should provide a
    `calculate_qparams` function that computes the quantization parameters
    given the collected statistics.

    Args:
        averaging_constant: Averaging constant for min/max.
        ch_axis: Channel axis.
        dtype: Quantized data type.
        qscheme: Quantization scheme to be used.
        quant_min: Min quantization value. Will follow dtype if unspecified.
        quant_max: Max quantization value. Will follow dtype if unspecified.
        is_sync_quantize: If sync statistics when training with multiple
            devices.
        factory_kwargs: kwargs which are passed to factory functions for
            min_val and max_val.
    """

    _version = 3

    eps: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor
    is_sync_quantize: Optional[bool] = True

    @typechecked
    def __init__(
        self,
        averaging_constant: float = 0.01,
        ch_axis: int = -1,
        dtype: Union[torch.dtype, QuantDType] = qint8,
        qscheme: torch.qscheme = torch.per_tensor_symmetric,
        quant_min: int = None,
        quant_max: int = None,
        is_sync_quantize: Optional[bool] = None,
        factory_kwargs: Dict = None,
        compute_scale_strategy=ComputeScaleStrategy.STATISTIC,
    ):
        super(ObserverBase, self).__init__()

        if qscheme == torch.per_channel_symmetric:
            assert (
                ch_axis >= 0
            ), "ch_axis should be non-negative when using per_channel_symmetric qcsheme"
        else:
            assert (
                ch_axis < 0
            ), "ch_axis should be negative when using per_tensor_symmetric qcsheme"
        dtype = get_horizon_quant_dtype(dtype)
        assert qscheme in (
            torch.per_tensor_symmetric,
            torch.per_channel_symmetric,
        ), (
            "only support per_tensor_symmetric and per_channel_symmetric "
            "qscheme"
        )

        self.averaging_constant = averaging_constant
        self.ch_axis = ch_axis
        self.dtype = dtype
        self.qscheme = qscheme

        self._set_quant_min_max(self.dtype, quant_min, quant_max)

        if is_sync_quantize is not None:
            self.is_sync_quantize = is_sync_quantize

        self.compute_scale_strategy = compute_scale_strategy

        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer(
            "eps",
            torch.tensor([torch.finfo(torch.float32).eps], **factory_kwargs),
        )
        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))

    def _set_quant_min_max(
        self,
        dtype,
        quant_min=None,
        quant_max=None,
    ):
        if (quant_min is not None) and (quant_max is not None):
            assert quant_min < quant_max, (
                "qmin must be strictly less than qmax for user-specified "
                "quantization range."
            )
            assert (
                quant_min <= 0 <= quant_max
            ), "Used-specified quantization range must include 0."
            assert qinfo(dtype).min <= quant_min, "quant_min out of bound"
            assert quant_max <= qinfo(dtype).max, "quant_max out of bound"
            self.quant_min, self.quant_max = quant_min, quant_max
        else:
            self.quant_min, self.quant_max = (
                qinfo(self.dtype).min,
                qinfo(self.dtype).max,
            )

    def reset_dtype(self, dtype):
        dtype = get_horizon_quant_dtype(dtype)
        if dtype == self.dtype:
            return
        self.dtype = dtype
        self._set_quant_min_max(self.dtype)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # buffers has been renamed from min/max_vals to min/max_val
        buffer_name_mapping = {"min_vals": "min_val", "max_vals": "max_val"}
        for old_name in buffer_name_mapping:
            k = prefix + old_name
            if k in state_dict:
                v = state_dict.pop(k)
                state_dict[prefix + buffer_name_mapping[old_name]] = v

        eps_key = prefix + "eps"
        if eps_key not in state_dict:
            # eps was moved to a buffer in version 2
            eps = torch.tensor([torch.finfo(torch.float32).eps])
            state_dict[eps_key] = eps

        local_state = ["min_val", "max_val"]
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                # if ndim=0, make it ndim=1
                state_dict[key] = state_dict[key].reshape(-1)

                val = state_dict[key]

                # Custom handling to allow loading min_val or max_val
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == "min_val" and hasattr(self, "min_val"):
                    self.min_val.resize_(val.shape)
                elif hasattr(self, "max_val"):
                    self.max_val.resize_(val.shape)

        super(ObserverBase, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _load_from_state_dict_script(
        self,
        state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
        prefix: str,
        local_metadata: Dict[str, torch.Tensor],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        self._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def sync_minmax(self, min_val, max_val):
        if dist.is_initialized() and min_val.is_cuda:
            dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
            dist.all_reduce(max_val, op=dist.ReduceOp.MAX)

    def calculate_qparams(self):
        r"""Calculate the quantization parameters.

        Returns:
            scales: Scales tensor of shape (#channels,)
            zero_points: Zero points tensor of shape (#channels,)
        """
        if self.min_val.numel() == 0 or self.max_val.numel() == 0:
            warnings.warn(
                "Must run observer before calling calculate_qparams. "
                "Returning default scale and zero point. "
                "This is an expected behavior if you use KLObserver "
                "and set 1 < update_interval <= total steps. ",
            )
            return torch.tensor(
                [1.0], device=self.min_val.device
            ), torch.tensor([0], device=self.min_val.device)

        scale = _compute_scale_symmetric(
            self.min_val,
            self.max_val,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.compute_scale_strategy,
        )

        return scale, None

    def repr_msgs(self):
        msges = []
        # only print minmax value for per tensor
        if hasattr(self, "min_val") and self.min_val.numel() == 1:
            msges.append("min_val={}".format(self.min_val.item()))
        if hasattr(self, "max_val") and self.max_val.numel() == 1:
            msges.append("max_val={}".format(self.max_val.item()))
        return msges

    def extra_repr(self):
        return ",".join(self.repr_msgs())

    @abstractmethod
    def forward(self, x):
        pass

    with_args = classmethod(_with_args)

4. 核心属性详解

4.1 量化配置属性

Plain 复制代码
# 基础量化参数
self.averaging_constant: float    # 移动平均系数
self.ch_axis: int                  # 通道轴 (per_channel量化时使用)
self.dtype: QuantDType            # 量化数据类型 (qint8, qint4等)
self.qscheme: torch.qscheme       # 量化方案 (per_tensor/per_channel)
self.quant_min: int               # 量化最小值
self.quant_max: int               # 量化最大值
self.is_sync_quantize: bool       # 多卡同步统计量
self.compute_scale_strategy       # scale计算策略 (STATISTIC/POT/FP16等)

4.2 统计量缓冲区

Plain 复制代码
self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))
self.register_buffer("min_val", torch.tensor([]))
self.register_buffer("max_val", torch.tensor([]))

使用 register_buffer 注册的原因:

  • 不参与梯度计算:统计量不是模型参数
  • 随模型迁移设备:model.cuda() 时自动迁移
  • 可保存到 state_dict:校准结果可持久化

5. 核心方法详解

5.1 init - 初始化

参数说明:

参数 默认值 说明
averaging_constant 0.01 移动平均系数,值越大当前 batch 权重越高
ch_axis -1 通道轴,负数表示 per_tensor,非负表示 per_channel
dtype qint8 量化数据类型
qscheme per_tensor_symmetric 量化方案
quant_min/max None 自定义量化范围,None 时根据 dtype 自动设置
is_sync_quantize TRUE 多卡训练时是否同步统计量

关键校验逻辑:

Plain 复制代码
# per_channel 必须指定有效的 ch_axis
if qscheme == torch.per_channel_symmetric:
    assert ch_axis >= 0, "ch_axis should be non-negative"
else:
    assert ch_axis < 0, "ch_axis should be negative for per_tensor"

# 仅支持对称量化
assert qscheme in (
    torch.per_tensor_symmetric,
    torch.per_channel_symmetric,
)

5.2 forward - 更新统计信息(抽象方法)

设计意图:

  • 子类必须实现此方法(由 ABCMeta 强制)
  • 在校准阶段,每个 forward pass 收集激活值的统计信息
  • 返回原始输入(不修改数据流)

典型实现模式:

Plain 复制代码
def forward(self, x_orig):
    # 1. 计算当前 batch 的统计量
    min_val_cur, max_val_cur = compute_statistics(x_orig)
    
    # 2. 多卡同步(可选)
    if self.is_sync_quantize:
        self.sync_minmax(min_val_cur, max_val_cur)
    
    # 3. 更新累计统计量(移动平均)
    self.min_val = update_statistics(self.min_val, min_val_cur)
    self.max_val = update_statistics(self.max_val, max_val_cur)
    
    return x_orig  # 原样返回,不干扰前向传播

5.3 calculate_qparams - 计算量化参数

核心计算逻辑(_compute_scale_symmetric):

Plain 复制代码
def _compute_scale_symmetric(min_val, max_val, quant_min, quant_max, eps, strategy):
    # 对称量化公式:scale = max(|min|, |max|) / (quant_range / 2)
    scale = (
        torch.max(-min_val, max_val)
        .clamp_min(0)
        .div(float(quant_max - quant_min) / 2)
        .clamp_min(eps)
    )
    
    # 可选的 scale 约束策略
    if strategy == ComputeScaleStrategy.KPOT:    # K-POT (可训练POT)
        scale = k_pot_scale(scale)
    elif strategy == ComputeScaleStrategy.POT:   # Power-of-Two
        scale = 2 ** torch.ceil(torch.log2(scale))
    elif strategy == ComputeScaleStrategy.FP16:  # FP16 精度
        scale = _get_fp16_scale(scale)
    
    return scale

5.4 sync_minmax - 多卡同步

Plain 复制代码
def sync_minmax(self, min_val, max_val):
    if dist.is_initialized() and min_val.is_cuda:
        dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
        dist.all_reduce(max_val, op=dist.ReduceOp.MAX)

原理:

  • 使用 all_reduce 聚合多卡的统计量
  • MIN 操作取所有卡的最小值
  • MAX 操作取所有卡的最大值
  • 确保多卡训练时校准结果一致

5.5 _load_from_state_dict - 状态加载

关键功能:

  • 版本兼容(处理旧版名称 min_vals → min_val)
  • 动态调整 buffer 大小
  • 支持从校准模型加载参数到 QAT 模型

6. 类继承体系

Plain 复制代码
ObserverBase (抽象基类)
    │
    ├── MinMaxObserver      # 移动平均 min/max 统计
    │       │
    │       └── ClipObserver # 带截断的 min/max 统计
    │
    ├── FixedScaleObserver   # 固定 scale(不统计)
    │
    ├── PercentileObserver   # 百分位统计
    │
    ├── MSEObserver          # 最小化 MSE 搜索最优 scale
    │
    ├── KLObserver           # KL 散度校准
    │
    ├── MixObserver          # 混合多种方法
    │
    └── HistogramObserver    # 直方图统计(支持多种度量)

7. 设计亮点

  1. 统一接口:所有 Observer 遵循相同的 API,便于替换和扩展
  2. 抽象基类约束:通过 ABCMeta 强制子类实现 forward 方法
  3. 状态持久化:统计量作为 buffer 保存,支持校准结果复用
  4. 分布式支持:内置多卡同步机制
  5. 版本兼容:_load_from_state_dict 处理历史版本兼容
  6. 灵活配置:支持多种量化方案、数据类型、scale 策略

8.与 PyTorch 原生 Observer 的对比

特性 PyTorch ObserverBase Horizon ObserverBase
量化方案 支持非对称量化 仅支持对称量化
scale 约束 POT/FP16/KPOT 策略
分布式同步 需自行实现 内置 sync_minmax
数据类型 标准 torch.dtype 扩展 QuantDType (qint4 等)
版本管理 _version 字段支持迁移
相关推荐
地平线开发者2 小时前
【地平线 征程 6 工具链进阶教程】QAT 训练常见问题和排查
算法
地平线开发者2 小时前
征程 6 | 直方图量化配置与校准实例
算法
地平线开发者2 小时前
征程 6E/M Matrix 开发评板使用系列(一):开箱与点亮
算法·自动驾驶
Jerry3 小时前
LeetCode 59. 螺旋矩阵 II
算法
可编程芯片开发3 小时前
基于FOC控制器的BLDC无刷直流电机控制系统matlab编程与仿真
算法
aaaameliaaa3 小时前
进制练习题【找出只出现一次的数字、交换两个变量(不创建临时变量)、统计二进制中1的个数、打印整数二进制的奇数位和偶数位、求两个数二进制中不同位的个数】
c语言·数据结构·笔记·算法
QiLinkOS5 小时前
第三视觉理解徐玉生与他的商业活动(28)
大数据·c++·人工智能·算法·开源协议
wabs6666 小时前
关于动态规划【力扣1143.最长公共子序列的思考】
算法·leetcode·动态规划