征程 6 | 自定义查表算子实现量化部署

一、简介

模型部署时,常常面临模型精度与执行效率之间的权衡。查表(Look-Up Table, LUT)是一种在推理过程中用空间换时间的技巧,尤其适用于非线性映射函数的加速。本文将介绍如何在地平线平台上通过 SegmentLUT 实现一个自定义查表算子,并对它进行量化与导出。

本文以实现一个 DecInt 模块为例,将输入的 12bit 图像数据(值域为 [0, 4095]),通过自定义查找表进行非线性解码,模拟某种数值映射,然后进行归一化处理,最终输出一个 [-1, 1] 范围的结果图像。

查表算子常用于:实现数据非线性幅度调整、进行颜色空间的转换或者对比度的增强等场景。

二、SegmentLUT 模块

SegmentLUT 是地平线提供的用于查表操作的模块,具备以下特性:

  • 支持将输入索引映射到查找表指定输出;
  • 支持对 LUT 函数进行量化;
  • 可导出为 HBIR,用于 BPU 编译执行。
Plain 复制代码
from horizon_plugin_pytorch.nn import SegmentLUT
class SegmentLUT(Module):
    """
    Simulate any elementwise function by:
        Segment Look Up Table for int16 input.
        Look Up Table for int8 input.

    Args:
        simulated_func (Callable): Simulated function.
        is_centrosymmetric (bool): Whether F(x)=-F(-x). Deprecated.
            Use symmetric_mode instead.
        dividing_points (Optional[Union[List[Real], Tuple[Real]]]):
            Manually set the max input value of each segment.
            Defaults to None.
        input_range (Optional[Union[List[Real], Tuple[Real]]]):
            Manually set the valid input range.
        auto_divide_strategy (str):
            Strategy used to generate dividing points when
            dividing_points is None, only support 'evenly' and 'curvature'.
        inverse_func (Callable): The inverse function of the simulated function
            used to compute the input range in int-infer stage.
            !!!Note: Can only be used in monotonically decreasing function!!!
            Otherwise, the result of int-infer may be unexpected.
            Default to None
        gradients (Optional[Union[List[Real], Tuple[Real]]]):
            Manually set the gradient of linear interval on both side.
        symmetric_mode (str): It can be "ASYM", "YSYM" or "CSYM".
            Defaults to "ASYM".
            "ASYM": No symmetry.
            "YSYM": F(x)=F(-x).
            "CSYM": F(x)=-F(-x).
    """

三、构建查找表函数

定义一个将 [0, 4095] 映射到 [0, 16777215] 的函数。这里使用 numpy.interp 插值得到完整 4096 项的查表数组:

Plain 复制代码
def get_dec_func():
    # 样本点
    # 代表的是位置索引
    x = [0, 938, 1851, 2396, 3251, 4095]
    # 对应的是与x位置相对应的目标值
    y = [0, 3132, 105740, 387380, 3818601, 16777215]
    
    # 插值构造查表表项(共 4096 项)
    lookup_table = np.interp(np.arange(4096), x, y).astype(np.int32)
    to_dec = torch.from_numpy(lookup_table).to(torch.float32)

    def to_dec_func(x):
        indices = x.to(torch.int64)
        dec = torch.take(to_dec.to(x.device), indices)
        return dec
    return to_dec_func

通过这种方式,任意输入 [0, 4095] 的整数值,都可以在查表中获得对应的输出结果,实现近似非线性函数。

在这里,重点解释下

Plain 复制代码
lookup_table = np.interp(np.arange(4096), x, y).astype(np.int32)

借助 np.interp 函数来完成线性插值工作,进而生成查找表

  1. np.arange(4096)会生成从 0 到 4095 的整数序列,这些整数代表了查找表的索引。
  2. np.interp(...)会在已知的点(x, y)之间进行线性插值,从而计算出每个索引所对应的插值结果。
  3. .astype(np.int32)会把插值得到的浮点数结果转换为 32 位整数,这样做是为了符合查找表对整数的使用需求。

四、示例运行

示例很简单,主体包括:

  • 为了方便部署,插入量化、反量化节点
  • 调用 SegmentLUT 查表
  • 数据归一化
Plain 复制代码
class DecInt(nn.Module):
    def __init__(self):
        super(DecInt, self).__init__()
        to_dec_func = get_dec_func()
        self.to_dec = SegmentLUT(to_dec_func, is_centrosymmetric=True)
        self.sub = FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, img):
        img = self.quant(img)
        img = self.to_dec(img)

        # 找到每张图像的最大值与最小值
        max_vals = img.view(img.size(0), -1).max(dim=1)[0].view(-1, 1, 1, 1)
        min_vals = img.view(img.size(0), -1).min(dim=1)[0].view(-1, 1, 1, 1)
        # 归一化到 [0, 1]
        diff = (max_vals - min_vals) + 1e-6
        img = img - min_vals
        img = torch.clamp(img / diff, 0, 1.0)
        # 乘以 2,再减去 1,将其变换到 [-1, 1] 区间
        img = img * 2
        img = self.sub.sub(img, 1)
        return self.dequant(img)

全部示例代码:

Plain 复制代码
from horizon_plugin_pytorch.nn import SegmentLUT
from horizon_plugin_pytorch.nn.quantized import FloatFunctional
from horizon_plugin_pytorch.quantization import QuantStub, qint16
from torch.quantization import DeQuantStub
import numpy as np
import torch
import copy
import torch.nn as nn
from horizon_plugin_pytorch.quantization import QConfig, FakeQuantize, observer_v2
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_M)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState, FixedScaleObserver, FakeCast
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn
from horizon_plugin_pytorch.quantization import hbdk4 as hb4
from hbdk4.compiler import convert, save, hbm_perf, visualize, compile

def get_dec_func():
    # LUT 映射表:输入值 12bit范围
    x = [0, 938, 1851, 2396, 3251, 4095]
    # 对应的输出值
    y = [0, 3132, 105740, 387380, 3818601, 16777215]
    lookup_table = np.interp(np.arange(4096), list(x), list(y)).astype(np.int32)
    print(len(lookup_table))
    exit()
    to_dec = torch.from_numpy(lookup_table).to(torch.float32)
    def to_dec_func(x):
        indices = x.to(torch.int64)
        dec = torch.take(to_dec.to(x.device), indices)
        return dec
    return to_dec_func

class DecInt(nn.Module):
    def __init__(self):
        super(DecInt, self).__init__()
        to_dec_func = get_dec_func()
        self.to_dec = SegmentLUT(to_dec_func, is_centrosymmetric=True)
        self.sub = FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, img):
        img = self.quant(img)
        img = self.to_dec(img)

        # 找到每张图像的最大值与最小值
        max_vals = img.view(img.size(0), -1).max(dim=1)[0].view(-1, 1, 1, 1)
        min_vals = img.view(img.size(0), -1).min(dim=1)[0].view(-1, 1, 1, 1)
        # 归一化到 [0, 1]
        diff = (max_vals - min_vals) + 1e-6
        img = img - min_vals
        img = torch.clamp(img / diff, 0, 1.0)
        # 乘以 2,再减去 1,将其变换到 [-1, 1] 区间
        img = img * 2
        img = self.sub.sub(img, 1)
        return self.dequant(img)

# 构造模拟输入数据:batch_size=1,单通道图像 4x4,值范围是0~4095(12bit)
input_data = torch.randint(low=0, high=4096, size=(1, 1, 4, 4), dtype=torch.float32)

# 初始化模型
model = DecInt()
model.eval()

# 执行推理
with torch.no_grad():
    output = model(input_data)

print("输入图像:", input_data)
print("输出:", output)

# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_M)

calib_model = prepare(model.eval(), input_data, 
                      qconfig_setter=(
                          calibration_8bit_weight_16bit_act_qconfig_setter,
                          ),
                      )

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(input_data)

calib_model.eval()                            
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(input_data)
print("calib输出数据:", calib_out)

qat_bc = export(calib_model, input_data)
# save(qat_bc, "qat.bc")
# visualize(qat_bc, "qat.onnx")
hb_quantized_model = convert(qat_bc, March.NASH_M)
# save(hb_quantized_model,"quantized.bc")
# 可视化
visualize(hb_quantized_model, "quantized.onnx")

验证运行在 BPU 上

相关推荐
一株月见草哇1 小时前
Matlab(4)
人工智能·算法·matlab
hans汉斯1 小时前
基于深度学习的苹果品质智能检测算法研究
人工智能·深度学习·算法
火车叨位去19491 小时前
力扣top100(day01-05)--矩阵
算法·leetcode·矩阵
mit6.8241 小时前
[Robotics_py] 机器人运动模型 | `update`函数 | 微积分&矩阵
人工智能·python·算法
火车叨位去19494 小时前
力扣top100(day02-05)--二叉树 02
算法·leetcode·职场和发展
James. 常德 student4 小时前
leetcode-hot-100 (图论)
算法·leetcode·图论
郝学胜-神的一滴4 小时前
C++中的`auto`与`std::any`:功能、区别与选择建议
开发语言·c++·程序人生·算法
坫鹏加油站5 小时前
【Python】支持向量机SVM
python·算法·机器学习·支持向量机
向日葵的小屋6 小时前
csp知识基础——贪心算法
算法·贪心算法