征程 6 工具链 BEVPoolV2 算子使用教程 【2】-BEVPoolV2 QAT 链路实现示例

1.引言

在上一篇帖子中,我们已详尽阐述了 BEVPoolV2 相较于 BEVPoolV1 的改进之处,并对 BEVPoolV2 实现的代码进行了解析。想必大家对 BEVPoolV2 算子的功能及实现已有了一定程度的理解,此篇帖子将展示 征程 6 工具链 BEVPoolV2 单算子 QAT 链路的实现范例,以进一步增进用户对 BEVPoolV2 算子使用的认知。

2.QAT 代码实现

征程 6 工具链对齐 mmdet3d 的实现,目前已经支持了 BEVPoolV2 算子,QAT 链路中的核心函数如下:

Plain 复制代码
def bev_pool_v2(
    depth: Tensor,
    feat: Tensor,
    ranks_depth: Tensor,
    ranks_feat: Tensor,
    ranks_bev: Tensor,
    interval_starts: Tensor,
    interval_lengths: Tensor,
    bev_feat_shape,
):
    """BEVPoolv2 implementation for Lift-Splat-Shoot view transformation.
    This impl is same as following 
except
 the layout of inout feature:
    https://github.com/HuangJunJie2017/BEVDet/blob/dev3.0/mmdet3d/ops/bev_pool_v2/bev_pool.py
    Args:
        depth (Tensor[b, n, d, h, w]): Input depth.
        feat (Tensor[b, n, c, h, w]): Input features.
        ranks_depth (Tensor[n_points]): Depth index of points.
        ranks_feat (Tensor[n_points]): Feat index of points.
        ranks_bev (Tensor[n_points]): Output index of points.
        interval_starts (Tensor[n_pillars]): Starting position in ranks_xxx for each pooled point.  # noqa: E501
        interval_lengths (Tensor[n_pillars]): How many points in each pooled point.  # noqa: E501
        bev_feat_shape: Output shape in [b, z_out, h_out, w_out, c] or
            [z_out, h_out, w_out] or [h_out, w_out] format.
            When z_out is not given, its value will be 1 by default.
    Returns:
        Tensor[b, c, z_out, h_out, w_out]: Output features.
    """
    if len(bev_feat_shape) not in (2, 3, 5):
        raise ValueError("Illegal bev_feat_shape length")
    if len(bev_feat_shape) < 5:
        bev_feat_shape = tuple(bev_feat_shape)
        if len(bev_feat_shape) == 2:
            bev_feat_shape = (1,) + bev_feat_shape
        b = feat.size(0)
        c = feat.size(2)
        bev_feat_shape = (b,) + tuple(bev_feat_shape) + (c,)
    if has_torch_function((depth, feat)):
        return handle_torch_function(
            bev_pool_v2,
            (depth, feat),
            depth,
            feat,
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            bev_feat_shape, )
    x = torch.ops.horizon.bev_pool_v2(
        depth,
        feat,
        ranks_depth,
        ranks_feat,
        ranks_bev,
        interval_starts,
        interval_lengths,
        bev_feat_shape,
    )
    return x

docker 中代码路径: /usr/local/lib/python3.10/dist-packages/horizon_plugin_pytorch/nn/bev_pool_v2.py

详细说明 BEVPoolV2 算子在整个 QAT 链路使用流程

下面我们将以一个简单的单算子示例来详细说明 BEVPoolV2 算子在整个 QAT 链路使用流程。

首先,我们需要了解 QAT 链路的基本概念和工作原理,读者可以自行去学习 征程 6 工具链用户手册的快速上手章节。接下来,我们将详细介绍 BEVPoolV2 算子在 QAT 链路中的使用流程,涉及模型搭建、QAT 模型改造、模型导出与编译等。

本示例只为演示流程,未涉及到浮点训练和量化训练等流程。

3.输入准备

在进行 演示 QAT 链路之前,我们首先进行输入数据构建,这里要特别注意的是, **BEVPoolV2 算子的性能和输入索引强相关,建议构建模型的时候使用真实输入。**后面会结合代码进行说明。

4.示例代码

本示例代码基本遵循以下图中的 QAT 链路流程:

Plain 复制代码
import copy
import torch
import torch.nn as nn
import numpy as np
from horizon_plugin_pytorch.nn.bev_pool_v2 import BevPoolV2
from horizon_plugin_pytorch.quantization.hbdk4 import export
from torch.quantization import DeQuantStub
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    set_fake_quantize,
    FakeQuantState,
)
from horizon_plugin_pytorch.quantization.qconfig_template import default_calibration_qconfig_setter
from horizon_plugin_pytorch.quantization.prepare import prepare, PrepareMethod
from horizon_plugin_pytorch.march import March, set_march
from hbdk4.compiler import convert, compile, save
def load_input(b, d, h_out, w_out, c):
    #load 真实输入
    #b:batch
    #d:depth数
    #h_out, w_out:输出特征图大小
    #c:通道数
    depth = torch.Tensor(np.load("real_inputs/depth.npy"))
    feat = torch.Tensor(np.load("real_inputs/feat.npy"))
    ranks_depth = torch.Tensor(np.load("real_inputs/new_ranks_depth.npy")).type(torch.int32) #
    ranks_feat = torch.Tensor(np.load("real_inputs/new_ranks_feat.npy")).type(torch.int32)
    ranks_bev = torch.Tensor(np.load("real_inputs/new_ranks_bev.npy")).type(torch.int32)
    interval_starts = torch.Tensor(np.load("real_inputs/new_interval_starts.npy")).type(torch.int32)
    interval_lengths = torch.Tensor(np.load("real_inputs/new_interval_lengths.npy")).type(torch.int32)
    bev_feat_shape = (b, d, h_out, w_out, c)
    return depth, feat, ranks_depth, ranks_feat, ranks_bev, interval_starts, interval_lengths, bev_feat_shape
#step1;构建复现浮点模型
class SimpleBEVModel(nn.Module):
    def 
__init__
(self,bev_feat_shape):
        super(SimpleBEVModel, self).
__init__
()
        self.bev_feat_shape = bev_feat_shape
        self.bev_pool = BevPoolV2(self.bev_feat_shape)
        self.quant1 = QuantStub()
        self.quant2 = QuantStub()
        self.dequant = DeQuantStub()
        _, _, self.ranks_depth, self.ranks_feat, self.ranks_bev, self.interval_starts, self.interval_lengths, _ = load_input(1, 1,640, 128, 64)
    def forward(self, data):
        depth = data["depth"]
        feat = data["feat"]
        #step2:改造模型
        #在输入/输出分别插入QuantStub和DeQuantStub
        depth = self.quant1(depth)
        feat = self.quant2(feat)
        #调用BevPoolV2算子
        bev_feat = self.bev_pool(depth, feat, self.ranks_depth, self.ranks_feat, self.ranks_bev, self.interval_starts, self.interval_lengths)
        print("output shape:",bev_feat.shape)
        bev_feat = self.dequant(bev_feat)
        return bev_feat
if 
name
 == '
__main__
':

    b, d, h_out, w_out, c=1,1,640, 128, 64
    depth, feat, ranks_depth, ranks_feat, ranks_bev, interval_starts, interval_lengths, bev_feat_shape = load_input(
        b, d, h_out, w_out, c
    )
    print(f"Depth shape: {depth.shape} {depth.dtype}")
    print(f"Feat shape: {feat.shape} {feat.dtype}")
    print(f"Ranks depth shape: {ranks_depth.shape} {ranks_depth.dtype}")
    print(f"Ranks feat shape: {ranks_feat.shape} {ranks_feat.dtype}")
    print(f"Ranks bev shape: {ranks_bev.shape} {ranks_bev.dtype}")
    print(f"Interval starts shape: {interval_starts.shape} {interval_starts.dtype}")
    print(f"Interval lengths shape: {interval_lengths.shape} {interval_lengths.dtype}")
    print(f"BEV feat shape: {bev_feat_shape}")

    model = SimpleBEVModel(bev_feat_shape)
    example_inputs = dict(
        depth=depth,
        feat=feat,
    )
    import logging
    logging.basicConfig(filename='error.log', level=logging.ERROR)
    try:
        res_float = model(example_inputs)
        pass
    except Exception as e:
        logging.error("An error occurred: %s", e, exc_info=True)
    #配置march
    set_march(March.NASH_M)
    #step3:将浮点模型 prepare为伪量化模型
    calib_model = prepare(
        copy.deepcopy(model),
        example_inputs=(example_inputs,),
        qconfig_setter=(
            default_calibration_qconfig_setter,
        ),
        method=PrepareMethod.JIT_STRIP,
    )
    calib_model.eval()
    set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
    res_calib = calib_model(example_inputs)
    #step4:export出 qat.bc
    qat_hbir = export(
        calib_model,
        example_inputs,
        name="bevpool",
    )
    save(qat_hbir,"bevpoolv2_qat.bc")
    #step5:将qat.bc convert为 quantized.bc
    quanti_hbir = convert(qat_hbir, "nash-e")
    save(quanti_hbir, "bevpoolv2_quantized.bc")
    compile(
        quanti_hbir,
        path="bevpoolv2.hbm",
        march='nash-e',
        opt=2,
        jobs=64,
        balance=100,
        progress_bar=True,
    )

运行此示例代码后,目录下会有 3 个文件生成:

  • bevpoolv2_qat.bc: 单算子伪量化 bc

  • bevpoolv2_quantized.bc:单算子定点 bc

  • bevpoolv2.hbm:上板部署的 hbm

5.模型可视化

获取以上模型后,可视化查看输入输出属性是否符合预期。

可视化方式可以使用 hb_model_info 命令行工具或者 visualize 接口来可视化 bc/hbm 模型。

bevpoolv2_qat.bc可视化:

bevpoolv2_quantized.bc可视化:

相关推荐
紫雾凌寒2 小时前
计算机视觉应用|自动驾驶的感知革命:多传感器融合架构的技术演进与落地实践
人工智能·机器学习·计算机视觉·架构·自动驾驶·多传感器融合·waymo
安忘2 小时前
LeetCode 热题 -189. 轮转数组
算法·leetcode·职场和发展
Y1nhl2 小时前
力扣hot100_二叉树(4)_python版本
开发语言·pytorch·python·算法·leetcode·机器学习
曼诺尔雷迪亚兹3 小时前
2025年四川烟草工业计算机岗位备考详细内容
数据结构·数据库·计算机网络·算法
蜡笔小新..4 小时前
某些网站访问很卡 or 力扣网站经常进不去(2025/3/10)
算法·leetcode·职场和发展
IT猿手4 小时前
2025最新群智能优化算法:基于RRT的优化器(RRT-based Optimizer,RRTO)求解23个经典函数测试集,MATLAB
开发语言·人工智能·算法·机器学习·matlab
刘大猫264 小时前
五、MyBatis的增删改查模板(参数形式包括:String、对象、集合、数组、Map)
人工智能·算法·智能合约
修己xj5 小时前
算法系列之深度/广度优先搜索解决水桶分水的最优解及全部解
算法
_GR5 小时前
2019年蓝桥杯第十届C&C++大学B组真题及代码
c语言·数据结构·c++·算法·蓝桥杯
დ旧言~5 小时前
贪心算法三
算法·leetcode·贪心算法·动态规划·推荐算法