2.6、微调算法

前言

在 PPQ 中我们目前提供两种不同的算法帮助你微调网络

这些算法将使用 calibration dataset 中的数据,对网络权重展开重训练

    1. 经过训练的网络不保证中间结果与原来能够对齐,在进行误差分析时你需要注意这一点
    1. 在训练中使用 with ENABLE_CUDA_KERNEL(): 子句将显著加速训练过程
    1. 训练过程的缓存数据将被贮存在 gpu 上,这可能导致你显存溢出,你可以修改参数将缓存设备改为 cpu

code

from typing import Iterable

import torch
import torchvision

from ppq import (QuantizationSettingFactory, TargetPlatform,
                 graphwise_error_analyse)
from ppq.api import QuantizationSettingFactory, quantize_torch_model
from ppq.api.interface import ENABLE_CUDA_KERNEL
from ppq.executor.torch import TorchExecutor

# ------------------------------------------------------------
# 在 PPQ 中我们目前提供两种不同的算法帮助你微调网络
# 这些算法将使用 calibration dataset 中的数据,对网络权重展开重训练
# 1. 经过训练的网络不保证中间结果与原来能够对齐,在进行误差分析时你需要注意这一点
# 2. 在训练中使用 with ENABLE_CUDA_KERNEL(): 子句将显著加速训练过程
# 3. 训练过程的缓存数据将被贮存在 gpu 上,这可能导致你显存溢出,你可以修改参数将缓存设备改为 cpu
# ------------------------------------------------------------

BATCHSIZE   = 32
INPUT_SHAPE = [BATCHSIZE, 3, 224, 224]
DEVICE      = 'cuda'
PLATFORM    = TargetPlatform.PPL_CUDA_INT8

def load_calibration_dataset() -> Iterable:
    # ------------------------------------------------------------
    # 让我们从创建 calibration 数据开始做起, PPQ 需要你送入 32 ~ 1024 个样本数据作为校准数据集
    # 它们应该尽可能服从真实样本的分布,量化过程如同训练过程一样存在可能的过拟合问题
    # 你应当保证校准数据是经过正确预处理的、有代表性的数据,否则量化将会失败;校准数据不需要标签;数据集不能乱序
    # ------------------------------------------------------------
    return [torch.rand(size=INPUT_SHAPE) for _ in range(32)]
CALIBRATION = load_calibration_dataset()

def collate_fn(batch: torch.Tensor) -> torch.Tensor:
    return batch.to(DEVICE)

# ------------------------------------------------------------
# 我们使用 mobilenet v2 作为一个样例模型
# PPQ 将会使用 torch.onnx.export 函数 把 pytorch 的模型转换为 onnx 模型
# 对于复杂的 pytorch 模型而言,你或许需要自己完成 pytorch 模型到 onnx 的转换过程
# ------------------------------------------------------------
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)

# ------------------------------------------------------------
# PPQ 提供基于 LSQ 的网络微调过程,这是推荐的做法
# 你将使用 Quant Setting 来调用微调过程,并调整微调参数
# ------------------------------------------------------------
QSetting = QuantizationSettingFactory.default_setting()
QSetting.lsq_optimization                            = True
QSetting.lsq_optimization_setting.block_size         = 4
QSetting.lsq_optimization_setting.lr                 = 1e-5
QSetting.lsq_optimization_setting.gamma              = 0
QSetting.lsq_optimization_setting.is_scale_trainable = True
QSetting.lsq_optimization_setting.collecting_device  = 'cuda'

# ------------------------------------------------------------
# 如果你使用 ENABLE_CUDA_KERNEL 方法
# PPQ 将会尝试编译自定义的高性能量化算子,这一过程需要编译环境的支持
# 如果你在编译过程中发生错误,你可以删除此处对于 ENABLE_CUDA_KERNEL 方法的调用
# 这将显著降低 PPQ 的运算速度;但即使你无法编译这些算子,你仍然可以使用 pytorch 的 gpu 算子完成量化
# ------------------------------------------------------------
with ENABLE_CUDA_KERNEL():
    quantized = quantize_torch_model(
        model=model, calib_dataloader=CALIBRATION,
        calib_steps=32, input_shape=INPUT_SHAPE,
        setting=QSetting, collate_fn=collate_fn, platform=PLATFORM,
        onnx_export_file='./model.onnx', device=DEVICE, verbose=0)

    # ------------------------------------------------------------
    # 当我们完成训练后,我们将调用 graphwise_error_analyse 方法分析网络误差
    # 经过训练的中间层误差可能很大,但这不是我们所关心的 ------ 训练方法只优化最终输出的误差
    # 一个量化良好的网络,最后输出层的误差不应大于 10%
    # ------------------------------------------------------------
    graphwise_error_analyse(
        graph=quantized, 
        running_device=DEVICE, 
        dataloader=CALIBRATION,
        collate_fn=collate_fn)

# ------------------------------------------------------------
# 下面我们向你展示另一种 PPQ 中提供的优化方法
# 在 PPQ 0.6.5 之后,我们将这部分扩展性的方法移出了 QuantizationSetting
# 现在,扩展性方法需要手动调用
# ------------------------------------------------------------
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)

QSetting = QuantizationSettingFactory.default_setting()
# ------------------------------------------------------------
# baking_parameter 将会在网络量化之后,将网络中所有参数静态量化
# 参数静态量化将会显著提高 PPQ 的运行速度,但是一旦参数被静态量化,则其将无法被修改
# 也无法参与后续的训练过程
# ------------------------------------------------------------
QSetting.quantize_parameter_setting.baking_parameter = False

with ENABLE_CUDA_KERNEL():
    quantized = quantize_torch_model(
        model=model, calib_dataloader=CALIBRATION,
        calib_steps=32, input_shape=INPUT_SHAPE,
        setting=QSetting, collate_fn=collate_fn, platform=PLATFORM,
        onnx_export_file='./model.onnx', device=DEVICE, verbose=0)

    # ------------------------------------------------------------
    # 让我们手动调用 AdaroundPass 优化过程
    # 这一过程需要训练更多步数,同时你应当注意,训练过程应该放在网络量化过程之后
    # 并且不允许使用 QSetting.quantize_parameter_setting.baking_parameter = True
    # ------------------------------------------------------------
    from ppq.quantization.optim import AdaroundPass, ParameterBakingPass
    executor = TorchExecutor(graph=quantized, device=DEVICE)
    AdaroundPass(steps=5000).optimize(
        graph=quantized, dataloader=CALIBRATION, 
        executor=executor, collate_fn=collate_fn)
    ParameterBakingPass().optimize(
        graph=quantized, dataloader=CALIBRATION, 
        executor=executor, collate_fn=collate_fn)

    graphwise_error_analyse(
        graph=quantized, 
        running_device=DEVICE, 
        dataloader=CALIBRATION,
        collate_fn=collate_fn)
  • PPQ 提供基于 LSQ 的网络微调过程,这是推荐的做法
    将使用 Quant Setting 来调用微调过程,并调整微调参数
  • 另一种 PPQ 中提供的优化方法
    在 PPQ 0.6.5 之后,我们将这部分扩展性的方法移出了 QuantizationSetting
    现在,扩展性方法需要手动调用
    baking_parameter 将会在网络量化之后,将网络中所有参数静态量化
    参数静态量化将会显著提高 PPQ 的运行速度,但是一旦参数被静态量化,则其将无法被修改,也无法参与后续的训练过程
相关推荐
LNTON羚通43 分钟前
摄像机视频分析软件下载LiteAIServer视频智能分析平台玩手机打电话检测算法技术的实现
算法·目标检测·音视频·监控·视频监控
哭泣的眼泪4082 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
Microsoft Word3 小时前
c++基础语法
开发语言·c++·算法
天才在此3 小时前
汽车加油行驶问题-动态规划算法(已在洛谷AC)
算法·动态规划
莫叫石榴姐4 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
茶猫_5 小时前
力扣面试题 - 25 二进制数转字符串
c语言·算法·leetcode·职场和发展
肥猪猪爸7 小时前
使用卡尔曼滤波器估计pybullet中的机器人位置
数据结构·人工智能·python·算法·机器人·卡尔曼滤波·pybullet
readmancynn7 小时前
二分基本实现
数据结构·算法
萝卜兽编程7 小时前
优先级队列
c++·算法
盼海7 小时前
排序算法(四)--快速排序
数据结构·算法·排序算法