征程 6 | 自定义查表算子实现量化部署

一、简介

模型部署时,常常面临模型精度与执行效率之间的权衡。查表(Look-Up Table, LUT)是一种在推理过程中用空间换时间的技巧,尤其适用于非线性映射函数的加速。本文将介绍如何在地平线平台上通过 SegmentLUT 实现一个自定义查表算子,并对它进行量化与导出。

本文以实现一个 DecInt 模块为例,将输入的 12bit 图像数据(值域为 [0, 4095]),通过自定义查找表进行非线性解码,模拟某种数值映射,然后进行归一化处理,最终输出一个 [-1, 1] 范围的结果图像。

查表算子常用于:实现数据非线性幅度调整、进行颜色空间的转换或者对比度的增强等场景。

二、SegmentLUT 模块

SegmentLUT 是地平线提供的用于查表操作的模块,具备以下特性:

  • 支持将输入索引映射到查找表指定输出;
  • 支持对 LUT 函数进行量化;
  • 可导出为 HBIR,用于 BPU 编译执行。
Plain 复制代码
from horizon_plugin_pytorch.nn import SegmentLUT
class SegmentLUT(Module):
    """
    Simulate any elementwise function by:
        Segment Look Up Table for int16 input.
        Look Up Table for int8 input.

    Args:
        simulated_func (Callable): Simulated function.
        is_centrosymmetric (bool): Whether F(x)=-F(-x). Deprecated.
            Use symmetric_mode instead.
        dividing_points (Optional[Union[List[Real], Tuple[Real]]]):
            Manually set the max input value of each segment.
            Defaults to None.
        input_range (Optional[Union[List[Real], Tuple[Real]]]):
            Manually set the valid input range.
        auto_divide_strategy (str):
            Strategy used to generate dividing points when
            dividing_points is None, only support 'evenly' and 'curvature'.
        inverse_func (Callable): The inverse function of the simulated function
            used to compute the input range in int-infer stage.
            !!!Note: Can only be used in monotonically decreasing function!!!
            Otherwise, the result of int-infer may be unexpected.
            Default to None
        gradients (Optional[Union[List[Real], Tuple[Real]]]):
            Manually set the gradient of linear interval on both side.
        symmetric_mode (str): It can be "ASYM", "YSYM" or "CSYM".
            Defaults to "ASYM".
            "ASYM": No symmetry.
            "YSYM": F(x)=F(-x).
            "CSYM": F(x)=-F(-x).
    """

三、构建查找表函数

定义一个将 [0, 4095] 映射到 [0, 16777215] 的函数。这里使用 numpy.interp 插值得到完整 4096 项的查表数组:

Plain 复制代码
def get_dec_func():
    # 样本点
    # 代表的是位置索引
    x = [0, 938, 1851, 2396, 3251, 4095]
    # 对应的是与x位置相对应的目标值
    y = [0, 3132, 105740, 387380, 3818601, 16777215]
    
    # 插值构造查表表项(共 4096 项)
    lookup_table = np.interp(np.arange(4096), x, y).astype(np.int32)
    to_dec = torch.from_numpy(lookup_table).to(torch.float32)

    def to_dec_func(x):
        indices = x.to(torch.int64)
        dec = torch.take(to_dec.to(x.device), indices)
        return dec
    return to_dec_func

通过这种方式,任意输入 [0, 4095] 的整数值,都可以在查表中获得对应的输出结果,实现近似非线性函数。

在这里,重点解释下

Plain 复制代码
lookup_table = np.interp(np.arange(4096), x, y).astype(np.int32)

借助 np.interp 函数来完成线性插值工作,进而生成查找表

  1. np.arange(4096)会生成从 0 到 4095 的整数序列,这些整数代表了查找表的索引。
  2. np.interp(...)会在已知的点(x, y)之间进行线性插值,从而计算出每个索引所对应的插值结果。
  3. .astype(np.int32)会把插值得到的浮点数结果转换为 32 位整数,这样做是为了符合查找表对整数的使用需求。

四、示例运行

示例很简单,主体包括:

  • 为了方便部署,插入量化、反量化节点
  • 调用 SegmentLUT 查表
  • 数据归一化
Plain 复制代码
class DecInt(nn.Module):
    def __init__(self):
        super(DecInt, self).__init__()
        to_dec_func = get_dec_func()
        self.to_dec = SegmentLUT(to_dec_func, is_centrosymmetric=True)
        self.sub = FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, img):
        img = self.quant(img)
        img = self.to_dec(img)

        # 找到每张图像的最大值与最小值
        max_vals = img.view(img.size(0), -1).max(dim=1)[0].view(-1, 1, 1, 1)
        min_vals = img.view(img.size(0), -1).min(dim=1)[0].view(-1, 1, 1, 1)
        # 归一化到 [0, 1]
        diff = (max_vals - min_vals) + 1e-6
        img = img - min_vals
        img = torch.clamp(img / diff, 0, 1.0)
        # 乘以 2,再减去 1,将其变换到 [-1, 1] 区间
        img = img * 2
        img = self.sub.sub(img, 1)
        return self.dequant(img)

全部示例代码:

Plain 复制代码
from horizon_plugin_pytorch.nn import SegmentLUT
from horizon_plugin_pytorch.nn.quantized import FloatFunctional
from horizon_plugin_pytorch.quantization import QuantStub, qint16
from torch.quantization import DeQuantStub
import numpy as np
import torch
import copy
import torch.nn as nn
from horizon_plugin_pytorch.quantization import QConfig, FakeQuantize, observer_v2
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_M)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState, FixedScaleObserver, FakeCast
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn
from horizon_plugin_pytorch.quantization import hbdk4 as hb4
from hbdk4.compiler import convert, save, hbm_perf, visualize, compile

def get_dec_func():
    # LUT 映射表:输入值 12bit范围
    x = [0, 938, 1851, 2396, 3251, 4095]
    # 对应的输出值
    y = [0, 3132, 105740, 387380, 3818601, 16777215]
    lookup_table = np.interp(np.arange(4096), list(x), list(y)).astype(np.int32)
    print(len(lookup_table))
    exit()
    to_dec = torch.from_numpy(lookup_table).to(torch.float32)
    def to_dec_func(x):
        indices = x.to(torch.int64)
        dec = torch.take(to_dec.to(x.device), indices)
        return dec
    return to_dec_func

class DecInt(nn.Module):
    def __init__(self):
        super(DecInt, self).__init__()
        to_dec_func = get_dec_func()
        self.to_dec = SegmentLUT(to_dec_func, is_centrosymmetric=True)
        self.sub = FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, img):
        img = self.quant(img)
        img = self.to_dec(img)

        # 找到每张图像的最大值与最小值
        max_vals = img.view(img.size(0), -1).max(dim=1)[0].view(-1, 1, 1, 1)
        min_vals = img.view(img.size(0), -1).min(dim=1)[0].view(-1, 1, 1, 1)
        # 归一化到 [0, 1]
        diff = (max_vals - min_vals) + 1e-6
        img = img - min_vals
        img = torch.clamp(img / diff, 0, 1.0)
        # 乘以 2,再减去 1,将其变换到 [-1, 1] 区间
        img = img * 2
        img = self.sub.sub(img, 1)
        return self.dequant(img)

# 构造模拟输入数据:batch_size=1,单通道图像 4x4,值范围是0~4095(12bit)
input_data = torch.randint(low=0, high=4096, size=(1, 1, 4, 4), dtype=torch.float32)

# 初始化模型
model = DecInt()
model.eval()

# 执行推理
with torch.no_grad():
    output = model(input_data)

print("输入图像:", input_data)
print("输出:", output)

# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_M)

calib_model = prepare(model.eval(), input_data, 
                      qconfig_setter=(
                          calibration_8bit_weight_16bit_act_qconfig_setter,
                          ),
                      )

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(input_data)

calib_model.eval()                            
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(input_data)
print("calib输出数据:", calib_out)

qat_bc = export(calib_model, input_data)
# save(qat_bc, "qat.bc")
# visualize(qat_bc, "qat.onnx")
hb_quantized_model = convert(qat_bc, March.NASH_M)
# save(hb_quantized_model,"quantized.bc")
# 可视化
visualize(hb_quantized_model, "quantized.onnx")

验证运行在 BPU 上

相关推荐
j_xxx404_34 分钟前
数据结构:栈和队列力扣算法题
c语言·数据结构·算法·leetcode·链表
南莺莺38 分钟前
假设一个算术表达式中包含圆括号、方括号和花括号3种类型的括号,编写一个算法来判别,表达式中的括号是否配对,以字符“\0“作为算术表达式的结束符
c语言·数据结构·算法·
THMAIL1 小时前
深度学习从入门到精通 - 神经网络核心原理:从生物神经元到数学模型蜕变
人工智能·python·深度学习·神经网络·算法·机器学习·逻辑回归
野犬寒鸦1 小时前
力扣hot100:旋转图像(48)(详细图解以及核心思路剖析)
java·数据结构·后端·算法·leetcode
墨染点香1 小时前
LeetCode 刷题【61. 旋转链表】
算法·leetcode·职场和发展
一枝小雨1 小时前
【OJ】C++ vector类OJ题
数据结构·c++·算法·leetcode·oj题
Tisfy2 小时前
LeetCode 3516.找到最近的人:计算绝对值大小
数学·算法·leetcode·题解
自信的小螺丝钉2 小时前
Leetcode 206. 反转链表 迭代/递归
算法·leetcode·链表
博大世界2 小时前
解剖智驾“大脑”:一文读懂自动驾驶系统软件架构
人工智能·机器学习·自动驾驶
黑色的山岗在沉睡2 小时前
LeetCode 189. 轮转数组
java·算法·leetcode