模型优化之剪枝

文章目录

什么是神经网络剪枝

神经网络剪枝

  • 在训练期间删除连接
  • 密集张量将变得稀疏(用零填充)
  • 可以通过结构化块( n m nm nm)或( 11 11 11)删除连接

剪枝的好处

  • 减少过拟合
  • 稀疏性优势
  • 文件中有大量的0,如果有适当的稀疏张量表示方法,模型二进制文件尺寸减小。
  • 模型更小,可以减少内存带宽消耗量。
  • 对于特定模式的稀疏模型,可以开发优化算子,实现加速推理。

不同粒度的剪枝

什么时候做剪枝?

one-shot pruning : 一次性修剪,包括三个步骤训练模型、剪枝、再训练

剪枝:通常根据某种标准(如权重的大小、梯度的大小等)一次性去除大量权重。

再训练:剪枝后,模型通常需要进行一定数量的额外训练(称为fine-tuning或再训练)来恢复剪枝过程中可能损失的性能。

iterative pruning: 迭代式训练,特点如下:

初始训练:首先,对未剪枝的完整模型进行训练,直到达到满意的性能水平。

剪枝:然后,根据某种剪枝策略(例如基于权重的大小或敏感度)剪除模型的部分组件(如权重、神经元或通道)。

再训练:剪枝后,重新训练模型以恢复因剪枝而丢失的性能。

迭代:重复剪枝和再训练的过程,直到达到所需的剪枝率或性能标准。

automated gradual pruning: 自动化渐进剪枝,特点如下:

剪枝策略:采用一种预定义的剪枝策略,例如基于权重阈值、敏感度分析等,该策略在整个剪枝过程中保持一致。

渐进剪枝:在整个训练过程中逐渐增加剪枝率,通常从较低的剪枝率开始,逐步增加到目标剪枝率。

无需再训练:在整个剪枝过程中,模型持续被训练,而不是在剪枝后重新训练。

自动化:整个过程高度自动化,可以减少人为干预的需求

剪枝的分类

结构化剪枝(Structured Pruning)和非结构化剪枝(Unstructured Pruning)是两种常见的神经网络剪枝方法,它们的主要区别在于剪枝后网络结构的变化以及剪枝操作的粒度。

非结构化剪枝

不改变网络结构或者参数数量,把连接上的参数置0即为剪枝。

基于某种度量(如权重的绝对值大小)对所有权重进行排序,然后根据预先设定的剪枝比例(例如去除50%的最小权重)来决定哪些权重被设置为零。这种剪枝方法不会考虑权重在模型中的位置或结构,只关注权重本身的价值。示例代码:

python 复制代码
# 导入剪枝函数
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# 计算两轮之后完成剪枝时对应的迭代次数end_step
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# 定义剪枝模型参数,开始模型从50%稀疏度(权重为0的参数数量百分比),到80%稀疏度
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# 当使用函数`prune_low_magnitude`包装了一下模型后,需要重新编译一下
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

model_for_pruning.summary()

logdir = "./logs/mnist_pruning"

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)
# --------------------------------------------------
# 评估模型,对比剪枝前后模型的准确率变化
# 经过剪枝,这里有一个小的准确率下降,和没有进行剪枝相比的话
# --------------------------------------------------

_, model_for_pruning_accuracy = model_for_pruning.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

结构化剪枝

结构化剪枝改变了网络结构,即网络层输出元素个数,比如卷积核的减少会影响特征图数量。

在下面的例子中是基于选择的模型层做剪枝,所以需要指出哪些层去做结构化剪枝。比如剪枝第二个卷积层和第一个全连接层,剪枝策略为pruning_params_2_by_4,表示该层剪枝比例为2 / 4,即该层保留一半(2/4)的权重,而将另一半设为零。

注意:第一个卷积层不能被结构化剪枝。要是结构化剪枝的话,应该至少大于一个input channels(本例所用图片为单通道灰度图),所以我们对第一个卷积层使用随机剪枝。

python 复制代码
model = keras.Sequential([
    prune_low_magnitude(
        keras.layers.Conv2D(
            32, 5, padding='same', activation='relu',
            input_shape=(28, 28, 1),
            name="pruning_sparsity_0_5"),
        **pruning_params_sparsity_0_5),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    prune_low_magnitude(
        keras.layers.Conv2D(
            64, 5, padding='same',
            name="structural_pruning"),
        **pruning_params_2_by_4),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    keras.layers.Flatten(),
    prune_low_magnitude(
        keras.layers.Dense(
            1024, activation='relu',
            name="structural_pruning_dense"),
        **pruning_params_2_by_4),
    keras.layers.Dropout(0.4),
    keras.layers.Dense(10)
])

哪些层的参数更容易被剪掉

因为卷积层(conv)中的参数相比全连接层(fc)来说参数量少,所以卷积层参数的压缩比没有全连接层参数的压缩比大。换句话说,就是卷积层参数更加敏感,剪掉对准确率影响相对更大。越靠后的卷积层或卷积层之后的那些全连接层往往参数越容易被剪掉。

剪枝效果

  • 一般50%-70%左右的稀疏性,准确率降低幅度并不大
  • 剪枝是独立于量化技巧,通常与量化配合效果不错
  • 可以通过微调尝试不同的参数组合
相关推荐
axxy20009 分钟前
leetcode之hot100---240搜索二维矩阵II(C++)
数据结构·算法
苏言の狗11 分钟前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
黑客Ash21 分钟前
安全算法基础(一)
算法·安全
bastgia1 小时前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
AI莫大猫1 小时前
(6)YOLOv4算法基本原理以及和YOLOv3 的差异
算法·yolo
taoyong0011 小时前
代码随想录算法训练营第十一天-239.滑动窗口最大值
c++·算法
Uu_05kkq1 小时前
【C语言1】C语言常见概念(总结复习篇)——库函数、ASCII码、转义字符
c语言·数据结构·算法
清梦20203 小时前
经典问题---跳跃游戏II(贪心算法)
算法·游戏·贪心算法
paixiaoxin3 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
Dream_Snowar3 小时前
速通Python 第四节——函数
开发语言·python·算法